add diffusion progress

This commit is contained in:
Qing
2024-01-02 17:13:11 +08:00
parent f38be37f8c
commit 6253016019
17 changed files with 239 additions and 42 deletions

View File

@@ -6,6 +6,9 @@ from pathlib import Path
from typing import Optional, Dict, List
import cv2
import socketio
import asyncio
from socketio import AsyncServer
import torch
import numpy as np
from loguru import logger
@@ -109,6 +112,19 @@ def api_middleware(app: FastAPI):
app.add_middleware(CORSMiddleware, **cors_options)
global_sio: AsyncServer = None
def diffuser_callback(pipe, step: int, timestep: int, callback_kwargs: Dict):
# self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict
# logger.info(f"diffusion callback: step={step}, timestep={timestep}")
# We use asyncio loos for task processing. Perhaps in the future, we can add a processing queue similar to InvokeAI,
# but for now let's just start a separate event loop. It shouldn't make a difference for single person use
asyncio.run(global_sio.emit("diffusion_progress", {"step": step}))
return {}
class Api:
def __init__(self, app: FastAPI, config: ApiConfig):
self.app = app
@@ -134,6 +150,12 @@ class Api:
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
# fmt: on
global global_sio
self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
self.app.mount("/ws", self.combined_asgi_app)
global_sio = self.sio
def add_api_route(self, path: str, endpoint, **kwargs):
return self.app.add_api_route(path, endpoint, **kwargs)
@@ -206,6 +228,9 @@ class Api:
quality=self.config.quality,
infos=infos,
)
asyncio.run(self.sio.emit("diffusion_finish"))
return Response(
content=res_img_bytes,
media_type=f"image/{ext}",
@@ -246,10 +271,10 @@ class Api:
def launch(self):
self.app.include_router(self.router)
uvicorn.run(
self.app,
self.combined_asgi_app,
host=self.config.host,
port=self.config.port,
timeout_keep_alive=60,
timeout_keep_alive=999999999,
)
def _build_file_manager(self) -> Optional[FileManager]:
@@ -290,6 +315,7 @@ class Api:
disable_nsfw=self.config.disable_nsfw_checker,
sd_cpu_textencoder=self.config.cpu_textencoder,
cpu_offload=self.config.cpu_offload,
callback=diffuser_callback,
)