This commit is contained in:
Qing
2022-07-14 16:49:03 +08:00
parent 0f70ab58a7
commit a94f7e4ffe
16 changed files with 487 additions and 45 deletions

View File

@@ -1,4 +1,5 @@
import abc
from typing import Optional
import cv2
import torch
@@ -9,7 +10,9 @@ from lama_cleaner.schema import Config, HDStrategy
class InpaintModel:
min_size: Optional[int] = None
pad_mod = 8
pad_to_square = False
def __init__(self, device):
"""
@@ -31,18 +34,21 @@ class InpaintModel:
@abc.abstractmethod
def forward(self, image, mask, config: Config):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W]
"""Input images and output images have same size
images: [H, W, C] RGB
masks: [H, W] 255 为 masks 区域
return: BGR IMAGE
"""
...
def _pad_forward(self, image, mask, config: Config):
origin_height, origin_width = image.shape[:2]
padd_image = pad_img_to_modulo(image, mod=self.pad_mod)
padd_mask = pad_img_to_modulo(mask, mod=self.pad_mod)
result = self.forward(padd_image, padd_mask, config)
pad_image = pad_img_to_modulo(image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size)
pad_mask = pad_img_to_modulo(mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size)
logger.info(f"final forward pad size: {pad_image.shape}")
result = self.forward(pad_image, pad_mask, config)
result = result[0:origin_height, 0:origin_width, :]
original_pixel_indices = mask != 255
@@ -52,8 +58,8 @@ class InpaintModel:
@torch.no_grad()
def __call__(self, image, mask, config: Config):
"""
image: [H, W, C] RGB, not normalized
mask: [H, W]
images: [H, W, C] RGB, not normalized
masks: [H, W]
return: BGR IMAGE
"""
inpaint_result = None