Match stable diffusion result's histogram to image's

This commit is contained in:
Anders Haglund
2022-11-23 14:50:58 -08:00
parent 0b00fffe13
commit 8e408640a4
6 changed files with 62 additions and 0 deletions

View File

@@ -56,6 +56,9 @@ class InpaintModel:
result = self.forward(pad_image, pad_mask, config)
result = result[0:origin_height, 0:origin_width, :]
if config.sd_match_histograms:
result = self._match_histograms(result, image[:, :, ::-1], mask)
if config.sd_mask_blur != 0:
k = 2 * config.sd_mask_blur + 1
mask = cv2.GaussianBlur(mask, (k, k), 0)
@@ -172,6 +175,44 @@ class InpaintModel:
return crop_img, crop_mask, [l, t, r, b]
def _calculate_cdf(self, histogram):
cdf = histogram.cumsum()
normalized_cdf = cdf / float(cdf.max())
return normalized_cdf
def _calculate_lookup(self, source_cdf, reference_cdf):
lookup_table = np.zeros(256)
lookup_val = 0
for source_index, source_val in enumerate(source_cdf):
for reference_index, reference_val in enumerate(reference_cdf):
if reference_val >= source_val:
lookup_val = reference_index
break
lookup_table[source_index] = lookup_val
return lookup_table
def _match_histograms(self, source, reference, mask):
transformed_channels = []
for channel in range(source.shape[-1]):
source_channel = source[:, :, channel]
reference_channel = reference[:, :, channel]
# only calculate histograms for non-masked parts
source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0,256])
reference_histogram, _ = np.histogram(reference_channel[mask == 0], 256, [0,256])
source_cdf = self._calculate_cdf(source_histogram)
reference_cdf = self._calculate_cdf(reference_histogram)
lookup = self._calculate_lookup(source_cdf, reference_cdf)
transformed_channels.append(cv2.LUT(source_channel, lookup))
result = cv2.merge(transformed_channels)
result = cv2.convertScaleAbs(result)
return result
def _run_box(self, image, mask, box, config: Config):
"""