add forward_post_process function
This commit is contained in:
@@ -56,17 +56,15 @@ class InpaintModel:
|
|||||||
result = self.forward(pad_image, pad_mask, config)
|
result = self.forward(pad_image, pad_mask, config)
|
||||||
result = result[0:origin_height, 0:origin_width, :]
|
result = result[0:origin_height, 0:origin_width, :]
|
||||||
|
|
||||||
if config.sd_match_histograms:
|
result, image, mask = self.forward_post_process(result, image, mask, config)
|
||||||
result = self._match_histograms(result, image[:, :, ::-1], mask)
|
|
||||||
|
|
||||||
if config.sd_mask_blur != 0:
|
|
||||||
k = 2 * config.sd_mask_blur + 1
|
|
||||||
mask = cv2.GaussianBlur(mask, (k, k), 0)
|
|
||||||
|
|
||||||
mask = mask[:, :, np.newaxis]
|
mask = mask[:, :, np.newaxis]
|
||||||
result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
|
result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def forward_post_process(self, result, image, mask, config):
|
||||||
|
return result, image, mask
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(self, image, mask, config: Config):
|
def __call__(self, image, mask, config: Config):
|
||||||
"""
|
"""
|
||||||
@@ -179,7 +177,7 @@ class InpaintModel:
|
|||||||
cdf = histogram.cumsum()
|
cdf = histogram.cumsum()
|
||||||
normalized_cdf = cdf / float(cdf.max())
|
normalized_cdf = cdf / float(cdf.max())
|
||||||
return normalized_cdf
|
return normalized_cdf
|
||||||
|
|
||||||
def _calculate_lookup(self, source_cdf, reference_cdf):
|
def _calculate_lookup(self, source_cdf, reference_cdf):
|
||||||
lookup_table = np.zeros(256)
|
lookup_table = np.zeros(256)
|
||||||
lookup_val = 0
|
lookup_val = 0
|
||||||
@@ -190,27 +188,27 @@ class InpaintModel:
|
|||||||
break
|
break
|
||||||
lookup_table[source_index] = lookup_val
|
lookup_table[source_index] = lookup_val
|
||||||
return lookup_table
|
return lookup_table
|
||||||
|
|
||||||
def _match_histograms(self, source, reference, mask):
|
def _match_histograms(self, source, reference, mask):
|
||||||
transformed_channels = []
|
transformed_channels = []
|
||||||
for channel in range(source.shape[-1]):
|
for channel in range(source.shape[-1]):
|
||||||
source_channel = source[:, :, channel]
|
source_channel = source[:, :, channel]
|
||||||
reference_channel = reference[:, :, channel]
|
reference_channel = reference[:, :, channel]
|
||||||
|
|
||||||
# only calculate histograms for non-masked parts
|
# only calculate histograms for non-masked parts
|
||||||
source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0,256])
|
source_histogram, _ = np.histogram(source_channel[mask == 0], 256, [0, 256])
|
||||||
reference_histogram, _ = np.histogram(reference_channel[mask == 0], 256, [0,256])
|
reference_histogram, _ = np.histogram(reference_channel[mask == 0], 256, [0, 256])
|
||||||
|
|
||||||
source_cdf = self._calculate_cdf(source_histogram)
|
source_cdf = self._calculate_cdf(source_histogram)
|
||||||
reference_cdf = self._calculate_cdf(reference_histogram)
|
reference_cdf = self._calculate_cdf(reference_histogram)
|
||||||
|
|
||||||
lookup = self._calculate_lookup(source_cdf, reference_cdf)
|
lookup = self._calculate_lookup(source_cdf, reference_cdf)
|
||||||
|
|
||||||
transformed_channels.append(cv2.LUT(source_channel, lookup))
|
transformed_channels.append(cv2.LUT(source_channel, lookup))
|
||||||
|
|
||||||
result = cv2.merge(transformed_channels)
|
result = cv2.merge(transformed_channels)
|
||||||
result = cv2.convertScaleAbs(result)
|
result = cv2.convertScaleAbs(result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _run_box(self, image, mask, box, config: Config):
|
def _run_box(self, image, mask, box, config: Config):
|
||||||
|
|||||||
@@ -200,17 +200,21 @@ class SD(InpaintModel):
|
|||||||
|
|
||||||
return inpaint_result
|
return inpaint_result
|
||||||
|
|
||||||
|
def forward_post_process(self, result, image, mask, config):
|
||||||
|
if config.sd_match_histograms:
|
||||||
|
result = self._match_histograms(result, image[:, :, ::-1], mask)
|
||||||
|
|
||||||
|
if config.sd_mask_blur != 0:
|
||||||
|
k = 2 * config.sd_mask_blur + 1
|
||||||
|
mask = cv2.GaussianBlur(mask, (k, k), 0)
|
||||||
|
return result, image, mask
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_downloaded() -> bool:
|
def is_downloaded() -> bool:
|
||||||
# model will be downloaded when app start, and can't switch in frontend settings
|
# model will be downloaded when app start, and can't switch in frontend settings
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class SD14(SD):
|
|
||||||
model_id_or_path = "CompVis/stable-diffusion-v1-4"
|
|
||||||
image_key = "init_image"
|
|
||||||
|
|
||||||
|
|
||||||
class SD15(SD):
|
class SD15(SD):
|
||||||
model_id_or_path = "runwayml/stable-diffusion-inpainting"
|
model_id_or_path = "runwayml/stable-diffusion-inpainting"
|
||||||
image_key = "image"
|
image_key = "image"
|
||||||
|
|||||||
@@ -109,7 +109,8 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
|
|||||||
sd_steps=sd_steps,
|
sd_steps=sd_steps,
|
||||||
prompt='Face of a fox, high resolution, sitting on a park bench',
|
prompt='Face of a fox, high resolution, sitting on a park bench',
|
||||||
negative_prompt='orange, yellow, small',
|
negative_prompt='orange, yellow, small',
|
||||||
sd_sampler=sampler
|
sd_sampler=sampler,
|
||||||
|
sd_match_histograms=True
|
||||||
)
|
)
|
||||||
|
|
||||||
name = f"{sampler}_negative_prompt"
|
name = f"{sampler}_negative_prompt"
|
||||||
|
|||||||
Reference in New Issue
Block a user