add forward_post_process function
This commit is contained in:
@@ -56,17 +56,15 @@ class InpaintModel:
|
||||
result = self.forward(pad_image, pad_mask, config)
|
||||
result = result[0:origin_height, 0:origin_width, :]
|
||||
|
||||
if config.sd_match_histograms:
|
||||
result = self._match_histograms(result, image[:, :, ::-1], mask)
|
||||
|
||||
if config.sd_mask_blur != 0:
|
||||
k = 2 * config.sd_mask_blur + 1
|
||||
mask = cv2.GaussianBlur(mask, (k, k), 0)
|
||||
result, image, mask = self.forward_post_process(result, image, mask, config)
|
||||
|
||||
mask = mask[:, :, np.newaxis]
|
||||
result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
|
||||
return result
|
||||
|
||||
def forward_post_process(self, result, image, mask, config):
|
||||
return result, image, mask
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, image, mask, config: Config):
|
||||
"""
|
||||
|
||||
@@ -200,17 +200,21 @@ class SD(InpaintModel):
|
||||
|
||||
return inpaint_result
|
||||
|
||||
def forward_post_process(self, result, image, mask, config):
|
||||
if config.sd_match_histograms:
|
||||
result = self._match_histograms(result, image[:, :, ::-1], mask)
|
||||
|
||||
if config.sd_mask_blur != 0:
|
||||
k = 2 * config.sd_mask_blur + 1
|
||||
mask = cv2.GaussianBlur(mask, (k, k), 0)
|
||||
return result, image, mask
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
# model will be downloaded when app start, and can't switch in frontend settings
|
||||
return True
|
||||
|
||||
|
||||
class SD14(SD):
|
||||
model_id_or_path = "CompVis/stable-diffusion-v1-4"
|
||||
image_key = "init_image"
|
||||
|
||||
|
||||
class SD15(SD):
|
||||
model_id_or_path = "runwayml/stable-diffusion-inpainting"
|
||||
image_key = "image"
|
||||
|
||||
@@ -109,7 +109,8 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
|
||||
sd_steps=sd_steps,
|
||||
prompt='Face of a fox, high resolution, sitting on a park bench',
|
||||
negative_prompt='orange, yellow, small',
|
||||
sd_sampler=sampler
|
||||
sd_sampler=sampler,
|
||||
sd_match_histograms=True
|
||||
)
|
||||
|
||||
name = f"{sampler}_negative_prompt"
|
||||
|
||||
Reference in New Issue
Block a user