This commit is contained in:
Qing
2023-12-19 13:16:30 +08:00
parent f27fc51e34
commit 141936a937
18 changed files with 479 additions and 358 deletions

View File

@@ -31,6 +31,20 @@ class ControlNet(DiffusionInpaintModel):
pad_mod = 8
min_size = 512
@property
def lcm_lora_id(self):
if self.model_info.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SD_INPAINT,
]:
return "latent-consistency/lcm-lora-sdv1-5"
if self.model_info.model_type in [
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SDXL_INPAINT,
]:
return "latent-consistency/lcm-lora-sdxl"
raise NotImplementedError(f"Unsupported controlnet lcm model {self.model_info}")
def init_model(self, device: torch.device, **kwargs):
fp16 = not kwargs.get("no_half", False)
model_info: ModelInfo = kwargs["model_info"]
@@ -72,7 +86,7 @@ class ControlNet(DiffusionInpaintModel):
)
controlnet = ControlNetModel.from_pretrained(
sd_controlnet_method, torch_dtype=torch_dtype
sd_controlnet_method, torch_dtype=torch_dtype, resume_download=True
)
if model_info.is_single_file_diffusers:
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
@@ -81,7 +95,7 @@ class ControlNet(DiffusionInpaintModel):
model_kwargs["num_in_channels"] = 9
self.model = PipeClass.from_single_file(
model_info.path, controlnet=controlnet
model_info.path, controlnet=controlnet, **model_kwargs
).to(torch_dtype)
else:
self.model = PipeClass.from_pretrained(