diff --git a/lama_cleaner/model/base.py b/lama_cleaner/model/base.py index 45f468f..882aa42 100644 --- a/lama_cleaner/model/base.py +++ b/lama_cleaner/model/base.py @@ -56,17 +56,15 @@ class InpaintModel: result = self.forward(pad_image, pad_mask, config) result = result[0:origin_height, 0:origin_width, :] - if config.sd_match_histograms: - result = self._match_histograms(result, image[:, :, ::-1], mask) - - if config.sd_mask_blur != 0: - k = 2 * config.sd_mask_blur + 1 - mask = cv2.GaussianBlur(mask, (k, k), 0) + result, image, mask = self.forward_post_process(result, image, mask, config) mask = mask[:, :, np.newaxis] result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255)) return result + def forward_post_process(self, result, image, mask, config): + return result, image, mask + @torch.no_grad() def __call__(self, image, mask, config: Config): """ @@ -179,7 +177,7 @@ class InpaintModel: cdf = histogram.cumsum() normalized_cdf = cdf / float(cdf.max()) return normalized_cdf - + def _calculate_lookup(self, source_cdf, reference_cdf): lookup_table = np.zeros(256) lookup_val = 0 @@ -190,27 +188,27 @@ class InpaintModel: break lookup_table[source_index] = lookup_val return lookup_table - + def _match_histograms(self, source, reference, mask): transformed_channels = [] for channel in range(source.shape[-1]): source_channel = source[:, :, channel] reference_channel = reference[:, :, channel] - + # only calculate histograms for non-masked parts - source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0,256]) - reference_histogram, _ = np.histogram(reference_channel[mask == 0], 256, [0,256]) - + source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256]) + reference_histogram, _ = np.histogram(reference_channel[mask == 0], 256, [0, 256]) + source_cdf = self._calculate_cdf(source_histogram) reference_cdf = self._calculate_cdf(reference_histogram) - + lookup = self._calculate_lookup(source_cdf, reference_cdf) - + transformed_channels.append(cv2.LUT(source_channel, lookup)) - + result = cv2.merge(transformed_channels) result = cv2.convertScaleAbs(result) - + return result def _run_box(self, image, mask, box, config: Config): diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py index 8eaf2c5..1b0abba 100644 --- a/lama_cleaner/model/sd.py +++ b/lama_cleaner/model/sd.py @@ -200,17 +200,21 @@ class SD(InpaintModel): return inpaint_result + def forward_post_process(self, result, image, mask, config): + if config.sd_match_histograms: + result = self._match_histograms(result, image[:, :, ::-1], mask) + + if config.sd_mask_blur != 0: + k = 2 * config.sd_mask_blur + 1 + mask = cv2.GaussianBlur(mask, (k, k), 0) + return result, image, mask + @staticmethod def is_downloaded() -> bool: # model will be downloaded when app start, and can't switch in frontend settings return True -class SD14(SD): - model_id_or_path = "CompVis/stable-diffusion-v1-4" - image_key = "init_image" - - class SD15(SD): model_id_or_path = "runwayml/stable-diffusion-inpainting" image_key = "image" diff --git a/lama_cleaner/tests/test_sd_model.py b/lama_cleaner/tests/test_sd_model.py index 267a60d..6abc261 100644 --- a/lama_cleaner/tests/test_sd_model.py +++ b/lama_cleaner/tests/test_sd_model.py @@ -109,7 +109,8 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler): sd_steps=sd_steps, prompt='Face of a fox, high resolution, sitting on a park bench', negative_prompt='orange, yellow, small', - sd_sampler=sampler + sd_sampler=sampler, + sd_match_histograms=True ) name = f"{sampler}_negative_prompt"