controlnet support load local ckpt

Qing
2023-04-01 09:23:37 +08:00
parent 5fd253b07e
commit 65f12b490a
3 changed files with 141 additions and 7 deletions


@@ -1,3 +1,5 @@
+import gc
+
 import PIL.Image
 import cv2
 import numpy as np
@@ -41,6 +43,38 @@ NAMES_MAP = {
 }


+def load_from_local_model(local_model_path, torch_dtype, controlnet):
+    from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
+        load_pipeline_from_original_stable_diffusion_ckpt,
+    )
+
+    from .pipeline import StableDiffusionControlNetInpaintPipeline
+
+    logger.info(f"Converting {local_model_path} to diffusers controlnet pipeline")
+    pipe = load_pipeline_from_original_stable_diffusion_ckpt(
+        local_model_path,
+        num_in_channels=9,
+        from_safetensors=local_model_path.endswith("safetensors"),
+        device="cpu",
+    )
+
+    inpaint_pipe = StableDiffusionControlNetInpaintPipeline(
+        vae=pipe.vae,
+        text_encoder=pipe.text_encoder,
+        tokenizer=pipe.tokenizer,
+        unet=pipe.unet,
+        controlnet=controlnet,
+        scheduler=pipe.scheduler,
+        safety_checker=None,
+        feature_extractor=None,
+        requires_safety_checker=False,
+    )
+
+    del pipe
+    gc.collect()
+    return inpaint_pipe.to(torch_dtype)
+
+
 class ControlNet(DiffusionInpaintModel):
     name = "controlnet"
     pad_mod = 8
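The new load_from_local_model helper converts a single-file Stable Diffusion checkpoint (.ckpt or .safetensors) on disk into the project's StableDiffusionControlNetInpaintPipeline: the converted VAE, text encoder, tokenizer, UNet and scheduler are reused, the supplied ControlNet weights are attached, and the intermediate pipeline is released before returning. A minimal sketch of calling it directly; the checkpoint path is a placeholder (with num_in_channels=9 the file should be an inpainting checkpoint whose UNet takes 9 input channels), and the ControlNet repo id mirrors the one used further down:

import torch
from diffusers import ControlNetModel

# Hypothetical local path; any SD 1.5 inpainting checkpoint should convert.
ckpt_path = "/models/sd-v1-5-inpainting.safetensors"

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = load_from_local_model(ckpt_path, torch_dtype=torch.float16, controlnet=controlnet)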
@@ -71,13 +105,20 @@ class ControlNet(DiffusionInpaintModel):
         controlnet = ControlNetModel.from_pretrained(
             f"lllyasviel/sd-controlnet-canny", torch_dtype=torch_dtype
         )
-        self.model = StableDiffusionControlNetInpaintPipeline.from_pretrained(
-            model_id,
-            controlnet=controlnet,
-            revision="fp16" if use_gpu and fp16 else "main",
-            torch_dtype=torch_dtype,
-            **model_kwargs,
-        )
+        if kwargs.get("sd_local_model_path", None):
+            self.model = load_from_local_model(
+                kwargs["sd_local_model_path"],
+                torch_dtype=torch_dtype,
+                controlnet=controlnet,
+            )
+        else:
+            self.model = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+                model_id,
+                controlnet=controlnet,
+                revision="fp16" if use_gpu and fp16 else "main",
+                torch_dtype=torch_dtype,
+                **model_kwargs,
+            )

         # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
         self.model.enable_attention_slicing()
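Net effect of the second hunk: when a sd_local_model_path keyword reaches init_model, the pipeline is built from that local checkpoint through load_from_local_model; otherwise the previous from_pretrained download path is used unchanged. A rough usage sketch; the constructor signature (device plus forwarded kwargs) is assumed from the surrounding class and is not part of this diff:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ControlNet is the DiffusionInpaintModel subclass shown above.
# Passing sd_local_model_path routes loading through load_from_local_model();
# leaving it out keeps the original hub-download behaviour.
model = ControlNet(
    device,
    sd_local_model_path="/models/sd-v1-5-inpainting.safetensors",  # hypothetical local path
)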