add lcm lora
This commit is contained in:
@@ -13,7 +13,8 @@ from lama_cleaner.helper import (
|
||||
switch_mps_device,
|
||||
)
|
||||
from lama_cleaner.model.g_diffuser_bot import expand_image, np_img_grey_to_rgb
|
||||
from lama_cleaner.schema import Config, HDStrategy
|
||||
from lama_cleaner.model.utils import get_scheduler
|
||||
from lama_cleaner.schema import Config, HDStrategy, SDSampler
|
||||
|
||||
|
||||
class InpaintModel:
|
||||
@@ -381,3 +382,11 @@ class DiffusionInpaintModel(InpaintModel):
|
||||
# original_pixel_indices
|
||||
# ]
|
||||
return inpaint_result
|
||||
|
||||
def set_scheduler(self, config: Config):
|
||||
scheduler_config = self.model.scheduler.config
|
||||
sd_sampler = config.sd_sampler
|
||||
if config.sd_lcm_lora:
|
||||
sd_sampler = SDSampler.lcm
|
||||
scheduler = get_scheduler(sd_sampler, scheduler_config)
|
||||
self.model.scheduler = scheduler
|
||||
|
||||
@@ -8,7 +8,7 @@ from loguru import logger
|
||||
|
||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||
from lama_cleaner.model.utils import torch_gc, get_scheduler
|
||||
from lama_cleaner.schema import Config
|
||||
from lama_cleaner.schema import Config, SDSampler
|
||||
|
||||
|
||||
class CPUTextEncoderWrapper:
|
||||
@@ -67,6 +67,7 @@ def load_from_local_model(local_model_path, torch_dtype, disable_nsfw=True):
|
||||
class SD(DiffusionInpaintModel):
|
||||
pad_mod = 8
|
||||
min_size = 512
|
||||
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
|
||||
@@ -129,10 +130,7 @@ class SD(DiffusionInpaintModel):
|
||||
mask: [H, W, 1] 255 means area to repaint
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
|
||||
scheduler_config = self.model.scheduler.config
|
||||
scheduler = get_scheduler(config.sd_sampler, scheduler_config)
|
||||
self.model.scheduler = scheduler
|
||||
self.set_scheduler(config)
|
||||
|
||||
if config.sd_mask_blur != 0:
|
||||
k = 2 * config.sd_mask_blur + 1
|
||||
|
||||
@@ -13,6 +13,7 @@ class SDXL(DiffusionInpaintModel):
|
||||
name = "sdxl"
|
||||
pad_mod = 8
|
||||
min_size = 512
|
||||
lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
from diffusers.pipelines import AutoPipelineForInpainting
|
||||
@@ -56,10 +57,7 @@ class SDXL(DiffusionInpaintModel):
|
||||
mask: [H, W, 1] 255 means area to repaint
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
|
||||
scheduler_config = self.model.scheduler.config
|
||||
scheduler = get_scheduler(config.sd_sampler, scheduler_config)
|
||||
self.model.scheduler = scheduler
|
||||
self.set_scheduler(config)
|
||||
|
||||
if config.sd_mask_blur != 0:
|
||||
k = 2 * config.sd_mask_blur + 1
|
||||
@@ -80,7 +78,7 @@ class SDXL(DiffusionInpaintModel):
|
||||
height=img_h,
|
||||
width=img_w,
|
||||
generator=torch.manual_seed(config.sd_seed),
|
||||
callback_steps=1
|
||||
callback_steps=1,
|
||||
).images[0]
|
||||
|
||||
output = (output * 255).round().astype("uint8")
|
||||
|
||||
Reference in New Issue
Block a user