add ZITS
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
@@ -9,7 +10,9 @@ from lama_cleaner.schema import Config, HDStrategy
|
||||
|
||||
|
||||
class InpaintModel:
|
||||
min_size: Optional[int] = None
|
||||
pad_mod = 8
|
||||
pad_to_square = False
|
||||
|
||||
def __init__(self, device):
|
||||
"""
|
||||
@@ -31,18 +34,21 @@ class InpaintModel:
|
||||
|
||||
@abc.abstractmethod
|
||||
def forward(self, image, mask, config: Config):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W]
|
||||
"""Input images and output images have same size
|
||||
images: [H, W, C] RGB
|
||||
masks: [H, W] 255 为 masks 区域
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
...
|
||||
|
||||
def _pad_forward(self, image, mask, config: Config):
|
||||
origin_height, origin_width = image.shape[:2]
|
||||
padd_image = pad_img_to_modulo(image, mod=self.pad_mod)
|
||||
padd_mask = pad_img_to_modulo(mask, mod=self.pad_mod)
|
||||
result = self.forward(padd_image, padd_mask, config)
|
||||
pad_image = pad_img_to_modulo(image, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size)
|
||||
pad_mask = pad_img_to_modulo(mask, mod=self.pad_mod, square=self.pad_to_square, min_size=self.min_size)
|
||||
|
||||
logger.info(f"final forward pad size: {pad_image.shape}")
|
||||
|
||||
result = self.forward(pad_image, pad_mask, config)
|
||||
result = result[0:origin_height, 0:origin_width, :]
|
||||
|
||||
original_pixel_indices = mask != 255
|
||||
@@ -52,8 +58,8 @@ class InpaintModel:
|
||||
@torch.no_grad()
|
||||
def __call__(self, image, mask, config: Config):
|
||||
"""
|
||||
image: [H, W, C] RGB, not normalized
|
||||
mask: [H, W]
|
||||
images: [H, W, C] RGB, not normalized
|
||||
masks: [H, W]
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
inpaint_result = None
|
||||
|
||||
Reference in New Issue
Block a user