add forward_post_process function

This commit is contained in:
Qing
2022-11-25 09:29:20 +08:00
parent f7d7e89197
commit af87cca643
3 changed files with 25 additions and 22 deletions

View File

@@ -200,17 +200,21 @@ class SD(InpaintModel):
return inpaint_result
def forward_post_process(self, result, image, mask, config):
if config.sd_match_histograms:
result = self._match_histograms(result, image[:, :, ::-1], mask)
if config.sd_mask_blur != 0:
k = 2 * config.sd_mask_blur + 1
mask = cv2.GaussianBlur(mask, (k, k), 0)
return result, image, mask
@staticmethod
def is_downloaded() -> bool:
# model will be downloaded when app start, and can't switch in frontend settings
return True
class SD14(SD):
model_id_or_path = "CompVis/stable-diffusion-v1-4"
image_key = "init_image"
class SD15(SD):
model_id_or_path = "runwayml/stable-diffusion-inpainting"
image_key = "image"