This commit is contained in:
Qing
2023-12-16 13:34:56 +08:00
parent cbe6577890
commit 24e95daac1
16 changed files with 381 additions and 323 deletions

View File

@@ -66,9 +66,9 @@ class InstructPix2Pix(DiffusionInpaintModel):
image=PIL.Image.fromarray(image),
prompt=config.prompt,
negative_prompt=config.negative_prompt,
num_inference_steps=config.p2p_steps,
num_inference_steps=config.sd_steps,
image_guidance_scale=config.p2p_image_guidance_scale,
guidance_scale=config.p2p_guidance_scale,
guidance_scale=config.sd_guidance_scale,
output_type="np",
generator=torch.manual_seed(config.sd_seed),
).images[0]

View File

@@ -19,7 +19,7 @@ class PaintByExample(DiffusionInpaintModel):
fp16 = not kwargs.get("no_half", False)
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
model_kwargs = {"local_files_only": kwargs.get("local_files_only", False)}
model_kwargs = {}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Paint By Example Model NSFW checker")
@@ -58,24 +58,17 @@ class PaintByExample(DiffusionInpaintModel):
image=PIL.Image.fromarray(image),
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
example_image=config.paint_by_example_example_image,
num_inference_steps=config.paint_by_example_steps,
num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale,
negative_prompt="out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature",
output_type="np.array",
generator=torch.manual_seed(config.paint_by_example_seed),
generator=torch.manual_seed(config.sd_seed),
).images[0]
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
def forward_post_process(self, result, image, mask, config):
if config.paint_by_example_match_histograms:
result = self._match_histograms(result, image[:, :, ::-1], mask)
if config.paint_by_example_mask_blur != 0:
k = 2 * config.paint_by_example_mask_blur + 1
mask = cv2.GaussianBlur(mask, (k, k), 0)
return result, image, mask
@staticmethod
def is_downloaded() -> bool:
# model will be downloaded when app start, and can't switch in frontend settings

View File

@@ -1,3 +1,4 @@
import gc
import math
import random
from typing import Any
@@ -913,6 +914,7 @@ def torch_gc():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
gc.collect()
def set_seed(seed: int):