add diffusion progress
This commit is contained in:
@@ -6,6 +6,9 @@ from pathlib import Path
|
||||
from typing import Optional, Dict, List
|
||||
|
||||
import cv2
|
||||
import socketio
|
||||
import asyncio
|
||||
from socketio import AsyncServer
|
||||
import torch
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
@@ -109,6 +112,19 @@ def api_middleware(app: FastAPI):
|
||||
app.add_middleware(CORSMiddleware, **cors_options)
|
||||
|
||||
|
||||
global_sio: AsyncServer = None
|
||||
|
||||
|
||||
def diffuser_callback(pipe, step: int, timestep: int, callback_kwargs: Dict):
|
||||
# self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict
|
||||
# logger.info(f"diffusion callback: step={step}, timestep={timestep}")
|
||||
|
||||
# We use asyncio loos for task processing. Perhaps in the future, we can add a processing queue similar to InvokeAI,
|
||||
# but for now let's just start a separate event loop. It shouldn't make a difference for single person use
|
||||
asyncio.run(global_sio.emit("diffusion_progress", {"step": step}))
|
||||
return {}
|
||||
|
||||
|
||||
class Api:
|
||||
def __init__(self, app: FastAPI, config: ApiConfig):
|
||||
self.app = app
|
||||
@@ -134,6 +150,12 @@ class Api:
|
||||
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
|
||||
# fmt: on
|
||||
|
||||
global global_sio
|
||||
self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||
self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
|
||||
self.app.mount("/ws", self.combined_asgi_app)
|
||||
global_sio = self.sio
|
||||
|
||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||
return self.app.add_api_route(path, endpoint, **kwargs)
|
||||
|
||||
@@ -206,6 +228,9 @@ class Api:
|
||||
quality=self.config.quality,
|
||||
infos=infos,
|
||||
)
|
||||
|
||||
asyncio.run(self.sio.emit("diffusion_finish"))
|
||||
|
||||
return Response(
|
||||
content=res_img_bytes,
|
||||
media_type=f"image/{ext}",
|
||||
@@ -246,10 +271,10 @@ class Api:
|
||||
def launch(self):
|
||||
self.app.include_router(self.router)
|
||||
uvicorn.run(
|
||||
self.app,
|
||||
self.combined_asgi_app,
|
||||
host=self.config.host,
|
||||
port=self.config.port,
|
||||
timeout_keep_alive=60,
|
||||
timeout_keep_alive=999999999,
|
||||
)
|
||||
|
||||
def _build_file_manager(self) -> Optional[FileManager]:
|
||||
@@ -290,6 +315,7 @@ class Api:
|
||||
disable_nsfw=self.config.disable_nsfw_checker,
|
||||
sd_cpu_textencoder=self.config.cpu_textencoder,
|
||||
cpu_offload=self.config.cpu_offload,
|
||||
callback=diffuser_callback,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ class ControlNet(DiffusionInpaintModel):
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
fp16 = not kwargs.get("no_half", False)
|
||||
model_info = kwargs["model_info"]
|
||||
model_info = kwargs["model_info"]
|
||||
controlnet_method = kwargs["controlnet_method"]
|
||||
|
||||
self.model_info = model_info
|
||||
@@ -154,7 +154,7 @@ class ControlNet(DiffusionInpaintModel):
|
||||
num_inference_steps=config.sd_steps,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
output_type="np",
|
||||
callback=self.callback,
|
||||
callback_on_step_end=self.callback,
|
||||
height=img_h,
|
||||
width=img_w,
|
||||
generator=torch.manual_seed(config.sd_seed),
|
||||
|
||||
@@ -52,9 +52,8 @@ class Kandinsky(DiffusionInpaintModel):
|
||||
num_inference_steps=config.sd_steps,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
output_type="np",
|
||||
callback=self.callback,
|
||||
callback_on_step_end=self.callback,
|
||||
generator=generator,
|
||||
callback_steps=1,
|
||||
).images[0]
|
||||
|
||||
output = (output * 255).round().astype("uint8")
|
||||
|
||||
@@ -83,11 +83,10 @@ class SD(DiffusionInpaintModel):
|
||||
strength=config.sd_strength,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
output_type="np",
|
||||
callback=self.callback,
|
||||
callback_on_step_end=self.callback,
|
||||
height=img_h,
|
||||
width=img_w,
|
||||
generator=torch.manual_seed(config.sd_seed),
|
||||
callback_steps=1,
|
||||
).images[0]
|
||||
|
||||
output = (output * 255).round().astype("uint8")
|
||||
|
||||
@@ -2,7 +2,6 @@ import os
|
||||
|
||||
import PIL.Image
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import AutoencoderKL
|
||||
from loguru import logger
|
||||
@@ -79,11 +78,10 @@ class SDXL(DiffusionInpaintModel):
|
||||
strength=0.999 if config.sd_strength == 1.0 else config.sd_strength,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
output_type="np",
|
||||
callback=self.callback,
|
||||
callback_on_step_end=self.callback,
|
||||
height=img_h,
|
||||
width=img_w,
|
||||
generator=torch.manual_seed(config.sd_seed),
|
||||
callback_steps=1,
|
||||
).images[0]
|
||||
|
||||
output = (output * 255).round().astype("uint8")
|
||||
|
||||
@@ -977,7 +977,6 @@ def handle_from_pretrained_exceptions(func, **kwargs):
|
||||
try:
|
||||
return func(**kwargs)
|
||||
except ValueError as e:
|
||||
# 处理异常的逻辑
|
||||
if "You are trying to load the model files of the `variant=fp16`" in str(e):
|
||||
logger.info("variant=fp16 not found, try revision=fp16")
|
||||
return func(**{**kwargs, "variant": None, "revision": "fp16"})
|
||||
|
||||
@@ -18,7 +18,6 @@ def test_model_switch():
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=True,
|
||||
cpu_offload=False,
|
||||
callback=None,
|
||||
)
|
||||
|
||||
model.switch("lama")
|
||||
@@ -34,7 +33,6 @@ def test_controlnet_switch_onoff(caplog):
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=True,
|
||||
cpu_offload=False,
|
||||
callback=None,
|
||||
)
|
||||
|
||||
model.switch_controlnet_method(
|
||||
@@ -59,7 +57,6 @@ def test_switch_controlnet_method(caplog):
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=True,
|
||||
cpu_offload=False,
|
||||
callback=None,
|
||||
)
|
||||
|
||||
model.switch_controlnet_method(
|
||||
|
||||
@@ -30,15 +30,11 @@ from lama_cleaner.tests.test_model import get_config, assert_equal
|
||||
def test_outpainting(name, device, rect):
|
||||
sd_steps = check_device(device)
|
||||
|
||||
def callback(i, t, latents):
|
||||
pass
|
||||
|
||||
model = ModelManager(
|
||||
name=name,
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
callback=callback,
|
||||
)
|
||||
cfg = get_config(
|
||||
prompt="a dog sitting on a bench in the park",
|
||||
@@ -72,15 +68,11 @@ def test_outpainting(name, device, rect):
|
||||
def test_kandinsky_outpainting(name, device, rect):
|
||||
sd_steps = check_device(device)
|
||||
|
||||
def callback(i, t, latents):
|
||||
pass
|
||||
|
||||
model = ModelManager(
|
||||
name=name,
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
callback=callback,
|
||||
)
|
||||
cfg = get_config(
|
||||
prompt="a cat",
|
||||
@@ -117,15 +109,11 @@ def test_kandinsky_outpainting(name, device, rect):
|
||||
def test_powerpaint_outpainting(name, device, rect):
|
||||
sd_steps = check_device(device)
|
||||
|
||||
def callback(i, t, latents):
|
||||
pass
|
||||
|
||||
model = ModelManager(
|
||||
name=name,
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
callback=callback,
|
||||
)
|
||||
cfg = get_config(
|
||||
prompt="a dog sitting on a bench in the park",
|
||||
|
||||
@@ -18,15 +18,11 @@ from lama_cleaner.tests.test_model import get_config, assert_equal
|
||||
def test_sdxl(device, strategy, sampler):
|
||||
sd_steps = check_device(device)
|
||||
|
||||
def callback(i, t, latents):
|
||||
pass
|
||||
|
||||
model = ModelManager(
|
||||
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
callback=callback,
|
||||
)
|
||||
cfg = get_config(
|
||||
strategy=strategy,
|
||||
@@ -54,15 +50,11 @@ def test_sdxl(device, strategy, sampler):
|
||||
def test_sdxl_lcm_lora_and_freeu(device, strategy, sampler):
|
||||
sd_steps = check_device(device)
|
||||
|
||||
def callback(i, t, latents):
|
||||
pass
|
||||
|
||||
model = ModelManager(
|
||||
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
callback=callback,
|
||||
)
|
||||
cfg = get_config(
|
||||
strategy=strategy,
|
||||
|
||||
Reference in New Issue
Block a user