add diffusion progress

This commit is contained in:
Qing
2024-01-02 17:13:11 +08:00
parent f38be37f8c
commit 6253016019
17 changed files with 239 additions and 42 deletions

View File

@@ -38,7 +38,7 @@ class ControlNet(DiffusionInpaintModel):
def init_model(self, device: torch.device, **kwargs):
fp16 = not kwargs.get("no_half", False)
model_info = kwargs["model_info"]
model_info = kwargs["model_info"]
controlnet_method = kwargs["controlnet_method"]
self.model_info = model_info
@@ -154,7 +154,7 @@ class ControlNet(DiffusionInpaintModel):
num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
callback_on_step_end=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),

View File

@@ -52,9 +52,8 @@ class Kandinsky(DiffusionInpaintModel):
num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
callback_on_step_end=self.callback,
generator=generator,
callback_steps=1,
).images[0]
output = (output * 255).round().astype("uint8")

View File

@@ -83,11 +83,10 @@ class SD(DiffusionInpaintModel):
strength=config.sd_strength,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
callback_on_step_end=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),
callback_steps=1,
).images[0]
output = (output * 255).round().astype("uint8")

View File

@@ -2,7 +2,6 @@ import os
import PIL.Image
import cv2
import numpy as np
import torch
from diffusers import AutoencoderKL
from loguru import logger
@@ -79,11 +78,10 @@ class SDXL(DiffusionInpaintModel):
strength=0.999 if config.sd_strength == 1.0 else config.sd_strength,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
callback_on_step_end=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),
callback_steps=1,
).images[0]
output = (output * 255).round().astype("uint8")

View File

@@ -977,7 +977,6 @@ def handle_from_pretrained_exceptions(func, **kwargs):
try:
return func(**kwargs)
except ValueError as e:
# 处理异常的逻辑
if "You are trying to load the model files of the `variant=fp16`" in str(e):
logger.info("variant=fp16 not found, try revision=fp16")
return func(**{**kwargs, "variant": None, "revision": "fp16"})