add diffusion progress
This commit is contained in:
@@ -6,6 +6,9 @@ from pathlib import Path
|
||||
from typing import Optional, Dict, List
|
||||
|
||||
import cv2
|
||||
import socketio
|
||||
import asyncio
|
||||
from socketio import AsyncServer
|
||||
import torch
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
@@ -109,6 +112,19 @@ def api_middleware(app: FastAPI):
|
||||
app.add_middleware(CORSMiddleware, **cors_options)
|
||||
|
||||
|
||||
global_sio: AsyncServer = None
|
||||
|
||||
|
||||
def diffuser_callback(pipe, step: int, timestep: int, callback_kwargs: Dict):
|
||||
# self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict
|
||||
# logger.info(f"diffusion callback: step={step}, timestep={timestep}")
|
||||
|
||||
# We use asyncio loos for task processing. Perhaps in the future, we can add a processing queue similar to InvokeAI,
|
||||
# but for now let's just start a separate event loop. It shouldn't make a difference for single person use
|
||||
asyncio.run(global_sio.emit("diffusion_progress", {"step": step}))
|
||||
return {}
|
||||
|
||||
|
||||
class Api:
|
||||
def __init__(self, app: FastAPI, config: ApiConfig):
|
||||
self.app = app
|
||||
@@ -134,6 +150,12 @@ class Api:
|
||||
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
|
||||
# fmt: on
|
||||
|
||||
global global_sio
|
||||
self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||
self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
|
||||
self.app.mount("/ws", self.combined_asgi_app)
|
||||
global_sio = self.sio
|
||||
|
||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||
return self.app.add_api_route(path, endpoint, **kwargs)
|
||||
|
||||
@@ -206,6 +228,9 @@ class Api:
|
||||
quality=self.config.quality,
|
||||
infos=infos,
|
||||
)
|
||||
|
||||
asyncio.run(self.sio.emit("diffusion_finish"))
|
||||
|
||||
return Response(
|
||||
content=res_img_bytes,
|
||||
media_type=f"image/{ext}",
|
||||
@@ -246,10 +271,10 @@ class Api:
|
||||
def launch(self):
|
||||
self.app.include_router(self.router)
|
||||
uvicorn.run(
|
||||
self.app,
|
||||
self.combined_asgi_app,
|
||||
host=self.config.host,
|
||||
port=self.config.port,
|
||||
timeout_keep_alive=60,
|
||||
timeout_keep_alive=999999999,
|
||||
)
|
||||
|
||||
def _build_file_manager(self) -> Optional[FileManager]:
|
||||
@@ -290,6 +315,7 @@ class Api:
|
||||
disable_nsfw=self.config.disable_nsfw_checker,
|
||||
sd_cpu_textencoder=self.config.cpu_textencoder,
|
||||
cpu_offload=self.config.cpu_offload,
|
||||
callback=diffuser_callback,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user