This commit is contained in:
Qing
2023-12-30 23:36:44 +08:00
parent 85c3397b97
commit c4abda3942
35 changed files with 969 additions and 854 deletions

View File

@@ -14,7 +14,7 @@ from lama_cleaner.helper import (
)
from lama_cleaner.model.helper.g_diffuser_bot import expand_image
from lama_cleaner.model.utils import get_scheduler
from lama_cleaner.schema import Config, HDStrategy, SDSampler
from lama_cleaner.schema import InpaintRequest, HDStrategy, SDSampler
class InpaintModel:
@@ -44,7 +44,7 @@ class InpaintModel:
return False
@abc.abstractmethod
def forward(self, image, mask, config: Config):
def forward(self, image, mask, config: InpaintRequest):
"""Input images and output images have same size
images: [H, W, C] RGB
masks: [H, W, 1] 255 为 masks 区域
@@ -56,7 +56,7 @@ class InpaintModel:
def download():
...
def _pad_forward(self, image, mask, config: Config):
def _pad_forward(self, image, mask, config: InpaintRequest):
origin_height, origin_width = image.shape[:2]
pad_image = pad_img_to_modulo(
image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size
@@ -74,7 +74,7 @@ class InpaintModel:
result, image, mask = self.forward_post_process(result, image, mask, config)
if config.sd_prevent_unmasked_area:
if config.sd_keep_unmasked_area:
mask = mask[:, :, np.newaxis]
result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
return result
@@ -86,7 +86,7 @@ class InpaintModel:
return result, image, mask
@torch.no_grad()
def __call__(self, image, mask, config: Config):
def __call__(self, image, mask, config: InpaintRequest):
"""
images: [H, W, C] RGB, not normalized
masks: [H, W]
@@ -141,7 +141,7 @@ class InpaintModel:
return inpaint_result
def _crop_box(self, image, mask, box, config: Config):
def _crop_box(self, image, mask, box, config: InpaintRequest):
"""
Args:
@@ -233,7 +233,7 @@ class InpaintModel:
return result
def _apply_cropper(self, image, mask, config: Config):
def _apply_cropper(self, image, mask, config: InpaintRequest):
img_h, img_w = image.shape[:2]
l, t, w, h = (
config.croper_x,
@@ -253,7 +253,7 @@ class InpaintModel:
crop_mask = mask[t:b, l:r]
return crop_img, crop_mask, (l, t, r, b)
def _run_box(self, image, mask, box, config: Config):
def _run_box(self, image, mask, box, config: InpaintRequest):
"""
Args:
@@ -276,7 +276,7 @@ class DiffusionInpaintModel(InpaintModel):
super().__init__(device, **kwargs)
@torch.no_grad()
def __call__(self, image, mask, config: Config):
def __call__(self, image, mask, config: InpaintRequest):
"""
images: [H, W, C] RGB, not normalized
masks: [H, W]
@@ -295,7 +295,7 @@ class DiffusionInpaintModel(InpaintModel):
return inpaint_result
def _do_outpainting(self, image, config: Config):
def _do_outpainting(self, image, config: InpaintRequest):
# cropper 和 image 在同一个坐标系下croper_x/y 可能为负数
# 从 image 中 crop 出 outpainting 区域
image_h, image_w = image.shape[:2]
@@ -368,7 +368,7 @@ class DiffusionInpaintModel(InpaintModel):
] = expanded_cropped_result_image
return outpainting_image
def _scaled_pad_forward(self, image, mask, config: Config):
def _scaled_pad_forward(self, image, mask, config: InpaintRequest):
longer_side_length = int(config.sd_scale * max(image.shape[:2]))
origin_size = image.shape[:2]
downsize_image = resize_max_size(image, size_limit=longer_side_length)
@@ -396,7 +396,7 @@ class DiffusionInpaintModel(InpaintModel):
# ]
return inpaint_result
def set_scheduler(self, config: Config):
def set_scheduler(self, config: InpaintRequest):
scheduler_config = self.model.scheduler.config
sd_sampler = config.sd_sampler
if config.sd_lcm_lora: