add diffusion progress

This commit is contained in:
Qing
2024-01-02 17:13:11 +08:00
parent f38be37f8c
commit 6253016019
17 changed files with 239 additions and 42 deletions

View File

@@ -6,6 +6,9 @@ from pathlib import Path
from typing import Optional, Dict, List
import cv2
import socketio
import asyncio
from socketio import AsyncServer
import torch
import numpy as np
from loguru import logger
@@ -109,6 +112,19 @@ def api_middleware(app: FastAPI):
app.add_middleware(CORSMiddleware, **cors_options)
global_sio: AsyncServer = None
def diffuser_callback(pipe, step: int, timestep: int, callback_kwargs: Dict) -> Dict:
    """Per-step progress callback passed to the diffusers pipeline.

    Matches the ``callback_on_step_end`` signature:
    (pipe: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict).

    Emits a ``diffusion_progress`` Socket.IO event with the current step so the
    web UI can render a progress bar.

    Returns:
        An empty dict — diffusers merges the returned mapping back into
        ``callback_kwargs``, and we don't want to override anything.
    """
    # We use an asyncio loop for task processing. Perhaps in the future we can
    # add a processing queue similar to InvokeAI, but for now just spin up a
    # separate event loop per step. It shouldn't make a difference for
    # single-person use.
    if global_sio is not None:
        # Guard: global_sio stays None until Api.__init__ has run; emitting
        # before that would raise AttributeError inside the pipeline loop.
        asyncio.run(global_sio.emit("diffusion_progress", {"step": step}))
    return {}
class Api:
def __init__(self, app: FastAPI, config: ApiConfig):
self.app = app
@@ -134,6 +150,12 @@ class Api:
self.app.mount("/", StaticFiles(directory=WEB_APP_DIR, html=True), name="assets")
# fmt: on
global global_sio
self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.combined_asgi_app = socketio.ASGIApp(self.sio, self.app)
self.app.mount("/ws", self.combined_asgi_app)
global_sio = self.sio
def add_api_route(self, path: str, endpoint, **kwargs):
return self.app.add_api_route(path, endpoint, **kwargs)
@@ -206,6 +228,9 @@ class Api:
quality=self.config.quality,
infos=infos,
)
asyncio.run(self.sio.emit("diffusion_finish"))
return Response(
content=res_img_bytes,
media_type=f"image/{ext}",
@@ -246,10 +271,10 @@ class Api:
def launch(self):
self.app.include_router(self.router)
uvicorn.run(
self.app,
self.combined_asgi_app,
host=self.config.host,
port=self.config.port,
timeout_keep_alive=60,
timeout_keep_alive=999999999,
)
def _build_file_manager(self) -> Optional[FileManager]:
@@ -290,6 +315,7 @@ class Api:
disable_nsfw=self.config.disable_nsfw_checker,
sd_cpu_textencoder=self.config.cpu_textencoder,
cpu_offload=self.config.cpu_offload,
callback=diffuser_callback,
)

View File

@@ -38,7 +38,7 @@ class ControlNet(DiffusionInpaintModel):
def init_model(self, device: torch.device, **kwargs):
fp16 = not kwargs.get("no_half", False)
model_info = kwargs["model_info"]
model_info = kwargs["model_info"]
controlnet_method = kwargs["controlnet_method"]
self.model_info = model_info
@@ -154,7 +154,7 @@ class ControlNet(DiffusionInpaintModel):
num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
callback_on_step_end=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),

View File

@@ -52,9 +52,8 @@ class Kandinsky(DiffusionInpaintModel):
num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
callback_on_step_end=self.callback,
generator=generator,
callback_steps=1,
).images[0]
output = (output * 255).round().astype("uint8")

View File

@@ -83,11 +83,10 @@ class SD(DiffusionInpaintModel):
strength=config.sd_strength,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
callback_on_step_end=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),
callback_steps=1,
).images[0]
output = (output * 255).round().astype("uint8")

View File

@@ -2,7 +2,6 @@ import os
import PIL.Image
import cv2
import numpy as np
import torch
from diffusers import AutoencoderKL
from loguru import logger
@@ -79,11 +78,10 @@ class SDXL(DiffusionInpaintModel):
strength=0.999 if config.sd_strength == 1.0 else config.sd_strength,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
callback_on_step_end=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),
callback_steps=1,
).images[0]
output = (output * 255).round().astype("uint8")

View File

@@ -977,7 +977,6 @@ def handle_from_pretrained_exceptions(func, **kwargs):
try:
return func(**kwargs)
except ValueError as e:
# Fallback handling: if the fp16 variant files are missing, retry the load with revision="fp16"
if "You are trying to load the model files of the `variant=fp16`" in str(e):
logger.info("variant=fp16 not found, try revision=fp16")
return func(**{**kwargs, "variant": None, "revision": "fp16"})

View File

@@ -18,7 +18,6 @@ def test_model_switch():
disable_nsfw=True,
sd_cpu_textencoder=True,
cpu_offload=False,
callback=None,
)
model.switch("lama")
@@ -34,7 +33,6 @@ def test_controlnet_switch_onoff(caplog):
disable_nsfw=True,
sd_cpu_textencoder=True,
cpu_offload=False,
callback=None,
)
model.switch_controlnet_method(
@@ -59,7 +57,6 @@ def test_switch_controlnet_method(caplog):
disable_nsfw=True,
sd_cpu_textencoder=True,
cpu_offload=False,
callback=None,
)
model.switch_controlnet_method(

View File

@@ -30,15 +30,11 @@ from lama_cleaner.tests.test_model import get_config, assert_equal
def test_outpainting(name, device, rect):
sd_steps = check_device(device)
def callback(i, t, latents):
pass
model = ModelManager(
name=name,
device=torch.device(device),
disable_nsfw=True,
sd_cpu_textencoder=False,
callback=callback,
)
cfg = get_config(
prompt="a dog sitting on a bench in the park",
@@ -72,15 +68,11 @@ def test_outpainting(name, device, rect):
def test_kandinsky_outpainting(name, device, rect):
sd_steps = check_device(device)
def callback(i, t, latents):
pass
model = ModelManager(
name=name,
device=torch.device(device),
disable_nsfw=True,
sd_cpu_textencoder=False,
callback=callback,
)
cfg = get_config(
prompt="a cat",
@@ -117,15 +109,11 @@ def test_kandinsky_outpainting(name, device, rect):
def test_powerpaint_outpainting(name, device, rect):
sd_steps = check_device(device)
def callback(i, t, latents):
pass
model = ModelManager(
name=name,
device=torch.device(device),
disable_nsfw=True,
sd_cpu_textencoder=False,
callback=callback,
)
cfg = get_config(
prompt="a dog sitting on a bench in the park",

View File

@@ -18,15 +18,11 @@ from lama_cleaner.tests.test_model import get_config, assert_equal
def test_sdxl(device, strategy, sampler):
sd_steps = check_device(device)
def callback(i, t, latents):
pass
model = ModelManager(
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
device=torch.device(device),
disable_nsfw=True,
sd_cpu_textencoder=False,
callback=callback,
)
cfg = get_config(
strategy=strategy,
@@ -54,15 +50,11 @@ def test_sdxl(device, strategy, sampler):
def test_sdxl_lcm_lora_and_freeu(device, strategy, sampler):
sd_steps = check_device(device)
def callback(i, t, latents):
pass
model = ModelManager(
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
device=torch.device(device),
disable_nsfw=True,
sd_cpu_textencoder=False,
callback=callback,
)
cfg = get_config(
strategy=strategy,