wip
This commit is contained in:
@@ -12,7 +12,7 @@ from lama_cleaner.helper import (
|
||||
pad_img_to_modulo,
|
||||
switch_mps_device,
|
||||
)
|
||||
from lama_cleaner.model.g_diffuser_bot import expand_image, np_img_grey_to_rgb
|
||||
from lama_cleaner.model.helper.g_diffuser_bot import expand_image
|
||||
from lama_cleaner.model.utils import get_scheduler
|
||||
from lama_cleaner.schema import Config, HDStrategy, SDSampler
|
||||
|
||||
@@ -22,6 +22,7 @@ class InpaintModel:
|
||||
min_size: Optional[int] = None
|
||||
pad_mod = 8
|
||||
pad_to_square = False
|
||||
is_erase_model = False
|
||||
|
||||
def __init__(self, device, **kwargs):
|
||||
"""
|
||||
@@ -264,6 +265,12 @@ class InpaintModel:
|
||||
|
||||
|
||||
class DiffusionInpaintModel(InpaintModel):
|
||||
def __init__(self, device, **kwargs):
|
||||
if kwargs.get("model_id_or_path"):
|
||||
# 用于自定义 diffusers 模型
|
||||
self.model_id_or_path = kwargs["model_id_or_path"]
|
||||
super().__init__(device, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, image, mask, config: Config):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user