wip: controlnet
This commit is contained in:
@@ -4,9 +4,7 @@ import PIL.Image
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import (
|
||||
ControlNetModel,
|
||||
)
|
||||
from diffusers import ControlNetModel
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||
@@ -75,7 +73,7 @@ def load_from_local_model(
|
||||
num_in_channels=4 if is_native_control_inpaint else 9,
|
||||
from_safetensors=local_model_path.endswith("safetensors"),
|
||||
device="cpu",
|
||||
load_safety_checker=False
|
||||
load_safety_checker=False,
|
||||
)
|
||||
|
||||
inpaint_pipe = pipe_class(
|
||||
@@ -92,7 +90,7 @@ def load_from_local_model(
|
||||
|
||||
del pipe
|
||||
gc.collect()
|
||||
return inpaint_pipe.to(torch_dtype)
|
||||
return inpaint_pipe.to(torch_dtype=torch_dtype)
|
||||
|
||||
|
||||
class ControlNet(DiffusionInpaintModel):
|
||||
@@ -120,6 +118,7 @@ class ControlNet(DiffusionInpaintModel):
|
||||
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
||||
|
||||
sd_controlnet_method = kwargs["sd_controlnet_method"]
|
||||
self.sd_controlnet_method = sd_controlnet_method
|
||||
|
||||
if sd_controlnet_method == "control_v11p_sd15_inpaint":
|
||||
from diffusers import StableDiffusionControlNetPipeline as PipeClass
|
||||
@@ -206,18 +205,30 @@ class ControlNet(DiffusionInpaintModel):
|
||||
output_type="np.array",
|
||||
).images[0]
|
||||
else:
|
||||
canny_image = cv2.Canny(image, 100, 200)
|
||||
canny_image = canny_image[:, :, None]
|
||||
canny_image = np.concatenate(
|
||||
[canny_image, canny_image, canny_image], axis=2
|
||||
)
|
||||
canny_image = PIL.Image.fromarray(canny_image)
|
||||
if "canny" in self.sd_controlnet_method:
|
||||
canny_image = cv2.Canny(image, 100, 200)
|
||||
canny_image = canny_image[:, :, None]
|
||||
canny_image = np.concatenate(
|
||||
[canny_image, canny_image, canny_image], axis=2
|
||||
)
|
||||
canny_image = PIL.Image.fromarray(canny_image)
|
||||
control_image = canny_image
|
||||
elif "openpose" in self.sd_controlnet_method:
|
||||
from controlnet_aux import OpenposeDetector
|
||||
|
||||
processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
|
||||
control_image = processor(image, hand_and_face=True)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"{self.sd_controlnet_method} not implemented"
|
||||
)
|
||||
|
||||
mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
|
||||
image = PIL.Image.fromarray(image)
|
||||
|
||||
output = self.model(
|
||||
image=image,
|
||||
control_image=canny_image,
|
||||
control_image=control_image,
|
||||
prompt=config.prompt,
|
||||
negative_prompt=config.negative_prompt,
|
||||
mask_image=mask_image,
|
||||
|
||||
Reference in New Issue
Block a user