wip: controlnet

This commit is contained in:
Qing
2023-05-11 21:51:58 +08:00
parent e5ac6a105a
commit 87f54bb87e
10 changed files with 117 additions and 29 deletions

View File

@@ -4,9 +4,7 @@ import PIL.Image
import cv2
import numpy as np
import torch
from diffusers import (
ControlNetModel,
)
from diffusers import ControlNetModel
from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel
@@ -75,7 +73,7 @@ def load_from_local_model(
num_in_channels=4 if is_native_control_inpaint else 9,
from_safetensors=local_model_path.endswith("safetensors"),
device="cpu",
load_safety_checker=False
load_safety_checker=False,
)
inpaint_pipe = pipe_class(
@@ -92,7 +90,7 @@ def load_from_local_model(
del pipe
gc.collect()
return inpaint_pipe.to(torch_dtype)
return inpaint_pipe.to(torch_dtype=torch_dtype)
class ControlNet(DiffusionInpaintModel):
@@ -120,6 +118,7 @@ class ControlNet(DiffusionInpaintModel):
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
sd_controlnet_method = kwargs["sd_controlnet_method"]
self.sd_controlnet_method = sd_controlnet_method
if sd_controlnet_method == "control_v11p_sd15_inpaint":
from diffusers import StableDiffusionControlNetPipeline as PipeClass
@@ -206,18 +205,30 @@ class ControlNet(DiffusionInpaintModel):
output_type="np.array",
).images[0]
else:
canny_image = cv2.Canny(image, 100, 200)
canny_image = canny_image[:, :, None]
canny_image = np.concatenate(
[canny_image, canny_image, canny_image], axis=2
)
canny_image = PIL.Image.fromarray(canny_image)
if "canny" in self.sd_controlnet_method:
canny_image = cv2.Canny(image, 100, 200)
canny_image = canny_image[:, :, None]
canny_image = np.concatenate(
[canny_image, canny_image, canny_image], axis=2
)
canny_image = PIL.Image.fromarray(canny_image)
control_image = canny_image
elif "openpose" in self.sd_controlnet_method:
from controlnet_aux import OpenposeDetector
processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
control_image = processor(image, hand_and_face=True)
else:
raise NotImplementedError(
f"{self.sd_controlnet_method} not implemented"
)
mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
image = PIL.Image.fromarray(image)
output = self.model(
image=image,
control_image=canny_image,
control_image=control_image,
prompt=config.prompt,
negative_prompt=config.negative_prompt,
mask_image=mask_image,