commit 141936a937
parent f27fc51e34
Author: Qing
Date: 2023-12-19 13:16:30 +08:00

18 changed files with 479 additions and 358 deletions

View File

@@ -279,15 +279,12 @@ class DiffusionInpaintModel(InpaintModel):
         """
         # boxes = boxes_from_mask(mask)
         if config.use_croper:
-            if config.croper_is_outpainting:
-                inpaint_result = self._do_outpainting(image, config)
-            else:
-                crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(
-                    image, mask, config
-                )
-                crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
-                inpaint_result = image[:, :, ::-1]
-                inpaint_result[t:b, l:r, :] = crop_image
+            crop_img, crop_mask, (l, t, r, b) = self._apply_cropper(image, mask, config)
+            crop_image = self._scaled_pad_forward(crop_img, crop_mask, config)
+            inpaint_result = image[:, :, ::-1]
+            inpaint_result[t:b, l:r, :] = crop_image
+        elif config.use_extender:
+            inpaint_result = self._do_outpainting(image, config)
         else:
             inpaint_result = self._scaled_pad_forward(image, mask, config)
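
The reshaped branching separates the two features: use_croper now always means crop, inpaint, and paste back, while outpainting is requested through the new use_extender flag. A minimal sketch of a caller driving the new path (the Config construction and the model(...) call are assumptions for illustration; only the field names appear in this diff):

    # Hypothetical caller; field names are from the diff, the rest is assumed.
    config = Config(
        use_croper=False,
        use_extender=True,             # outpaint instead of crop-inpaint
        extender_x=-128,               # negative: canvas extends 128px left of the image
        extender_y=0,
        extender_width=image_w + 128,  # canvas spans the original width plus the extension
        extender_height=image_h,
    )
    inpaint_result = model(image, mask, config)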
@@ -297,10 +294,10 @@ class DiffusionInpaintModel(InpaintModel):
         # cropper and image share the same coordinate system; croper_x/y may be negative
         # crop the outpainting region out of image
         image_h, image_w = image.shape[:2]
-        cropper_l = config.croper_x
-        cropper_t = config.croper_y
-        cropper_r = config.croper_x + config.croper_width
-        cropper_b = config.croper_y + config.croper_height
+        cropper_l = config.extender_x
+        cropper_t = config.extender_y
+        cropper_r = config.extender_x + config.extender_width
+        cropper_b = config.extender_y + config.extender_height
         image_l = 0
         image_t = 0
         image_r = image_w
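
The extender rectangle lives in image coordinates and may overhang any edge, which is why extender_x and extender_y can be negative. The crop region is presumably the intersection of the extender rectangle with the image rectangle built from the bounds above (a sketch of that step; this hunk does not show it):

    # Rectangle intersection (illustrative; names follow the hunk above).
    image_b = image_h
    l = max(cropper_l, image_l)
    t = max(cropper_t, image_t)
    r = min(cropper_r, image_r)
    b = min(cropper_b, image_b)
    crop_img = image[t:b, l:r, :]  # the part of the image inside the extender rect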
@@ -356,8 +353,8 @@ class DiffusionInpaintModel(InpaintModel):
         )[:, :, ::-1]
         # paste cropped_result_image onto outpainting_image; no blending is needed here
-        paste_t = 0 if config.croper_y < 0 else config.croper_y
-        paste_l = 0 if config.croper_x < 0 else config.croper_x
+        paste_t = 0 if config.extender_y < 0 else config.extender_y
+        paste_l = 0 if config.extender_x < 0 else config.extender_x
         outpainting_image[
             paste_t : paste_t + expanded_cropped_result_image.shape[0],
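
The paste origin follows the same convention: a negative extender offset means the canvas extends beyond the image on that side, so the result is pasted at 0; otherwise it lands at the offset itself. A worked example (values invented for illustration):

    def paste_origin(extender_x: int, extender_y: int):
        # Mirrors the two lines above: clamp negative offsets to the canvas edge.
        paste_t = 0 if extender_y < 0 else extender_y
        paste_l = 0 if extender_x < 0 else extender_x
        return paste_t, paste_l

    assert paste_origin(-128, -64) == (0, 0)  # canvas grows up and to the left
    assert paste_origin(64, 32) == (32, 64)   # extender rect starts inside the image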
@@ -397,8 +394,6 @@ class DiffusionInpaintModel(InpaintModel):
     def set_scheduler(self, config: Config):
         scheduler_config = self.model.scheduler.config
         sd_sampler = config.sd_sampler
-        if config.sd_lcm_lora:
-            sd_sampler = SDSampler.lcm
         scheduler = get_scheduler(sd_sampler, scheduler_config)
         self.model.scheduler = scheduler
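
With the override gone, set_scheduler reduces to the standard diffusers scheduler swap, and the sampler comes solely from config.sd_sampler. The underlying idiom, for reference (pipe stands in for self.model):

    from diffusers import LCMScheduler

    # Rebuild a scheduler from the current scheduler's config and assign it back.
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)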

View File

@@ -31,6 +31,20 @@ class ControlNet(DiffusionInpaintModel):
     pad_mod = 8
     min_size = 512
 
+    @property
+    def lcm_lora_id(self):
+        if self.model_info.model_type in [
+            ModelType.DIFFUSERS_SD,
+            ModelType.DIFFUSERS_SD_INPAINT,
+        ]:
+            return "latent-consistency/lcm-lora-sdv1-5"
+        if self.model_info.model_type in [
+            ModelType.DIFFUSERS_SDXL,
+            ModelType.DIFFUSERS_SDXL_INPAINT,
+        ]:
+            return "latent-consistency/lcm-lora-sdxl"
+        raise NotImplementedError(f"Unsupported controlnet lcm model {self.model_info}")
+
     def init_model(self, device: torch.device, **kwargs):
         fp16 = not kwargs.get("no_half", False)
         model_info: ModelInfo = kwargs["model_info"]
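
The new lcm_lora_id property only selects the LCM LoRA checkpoint matching the base model family; it is presumably consumed by diffusers' standard LoRA loader somewhere outside this hunk (illustrative call, not shown in the diff):

    # Illustrative: attach the LCM LoRA to the loaded pipeline.
    if config.sd_lcm_lora:
        self.model.load_lora_weights(self.lcm_lora_id)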
@@ -72,7 +86,7 @@ class ControlNet(DiffusionInpaintModel):
             )
         controlnet = ControlNetModel.from_pretrained(
-            sd_controlnet_method, torch_dtype=torch_dtype
+            sd_controlnet_method, torch_dtype=torch_dtype, resume_download=True
         )
         if model_info.is_single_file_diffusers:
             if self.model_info.model_type == ModelType.DIFFUSERS_SD:
@@ -81,7 +95,7 @@ class ControlNet(DiffusionInpaintModel):
                 model_kwargs["num_in_channels"] = 9
             self.model = PipeClass.from_single_file(
-                model_info.path, controlnet=controlnet
+                model_info.path, controlnet=controlnet, **model_kwargs
             ).to(torch_dtype)
         else:
             self.model = PipeClass.from_pretrained(
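
Forwarding **model_kwargs matters for single-file inpainting checkpoints: num_in_channels=9 makes from_single_file build a UNet with the inpainting input layout, i.e. 4 latent channels, 1 mask channel, and 4 masked-image latent channels, which diffusers inpaint pipelines concatenate on every denoising step:

    # What the 9 input channels correspond to (shapes illustrative):
    # (B,4,h,w) + (B,1,h,w) + (B,4,h,w) -> (B,9,h,w)
    latent_model_input = torch.cat([latents, mask, masked_image_latents], dim=1)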

View File

@@ -39,7 +39,7 @@ class SDXL(DiffusionInpaintModel):
             )
         else:
             vae = AutoencoderKL.from_pretrained(
-                "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
+                "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
             )
         self.model = StableDiffusionXLInpaintPipeline.from_pretrained(
             self.model_id_or_path,

View File

@@ -16,6 +16,7 @@ from diffusers import (
     EulerAncestralDiscreteScheduler,
     DPMSolverMultistepScheduler,
     UniPCMultistepScheduler,
+    LCMScheduler,
 )
 from lama_cleaner.schema import SDSampler
@@ -939,5 +940,7 @@ def get_scheduler(sd_sampler, scheduler_config):
         return DPMSolverMultistepScheduler.from_config(scheduler_config)
     elif sd_sampler == SDSampler.uni_pc:
         return UniPCMultistepScheduler.from_config(scheduler_config)
+    elif sd_sampler == SDSampler.lcm:
+        return LCMScheduler.from_config(scheduler_config)
     else:
         raise ValueError(sd_sampler)
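
With the new branch, selecting SDSampler.lcm yields diffusers' LCMScheduler. LCM is typically run with very few steps and guidance close to 1.0 (illustrative values, not part of this commit):

    scheduler = get_scheduler(SDSampler.lcm, pipe.scheduler.config)
    pipe.scheduler = scheduler
    image = pipe(prompt, num_inference_steps=4, guidance_scale=1.0).images[0]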