update
This commit is contained in:
@@ -66,9 +66,9 @@ class InstructPix2Pix(DiffusionInpaintModel):
|
||||
image=PIL.Image.fromarray(image),
|
||||
prompt=config.prompt,
|
||||
negative_prompt=config.negative_prompt,
|
||||
num_inference_steps=config.p2p_steps,
|
||||
num_inference_steps=config.sd_steps,
|
||||
image_guidance_scale=config.p2p_image_guidance_scale,
|
||||
guidance_scale=config.p2p_guidance_scale,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
output_type="np",
|
||||
generator=torch.manual_seed(config.sd_seed),
|
||||
).images[0]
|
||||
|
||||
@@ -19,7 +19,7 @@ class PaintByExample(DiffusionInpaintModel):
|
||||
fp16 = not kwargs.get("no_half", False)
|
||||
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
|
||||
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
||||
model_kwargs = {"local_files_only": kwargs.get("local_files_only", False)}
|
||||
model_kwargs = {}
|
||||
|
||||
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
||||
logger.info("Disable Paint By Example Model NSFW checker")
|
||||
@@ -58,24 +58,17 @@ class PaintByExample(DiffusionInpaintModel):
|
||||
image=PIL.Image.fromarray(image),
|
||||
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
|
||||
example_image=config.paint_by_example_example_image,
|
||||
num_inference_steps=config.paint_by_example_steps,
|
||||
num_inference_steps=config.sd_steps,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
negative_prompt="out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature",
|
||||
output_type="np.array",
|
||||
generator=torch.manual_seed(config.paint_by_example_seed),
|
||||
generator=torch.manual_seed(config.sd_seed),
|
||||
).images[0]
|
||||
|
||||
output = (output * 255).round().astype("uint8")
|
||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
return output
|
||||
|
||||
def forward_post_process(self, result, image, mask, config):
|
||||
if config.paint_by_example_match_histograms:
|
||||
result = self._match_histograms(result, image[:, :, ::-1], mask)
|
||||
|
||||
if config.paint_by_example_mask_blur != 0:
|
||||
k = 2 * config.paint_by_example_mask_blur + 1
|
||||
mask = cv2.GaussianBlur(mask, (k, k), 0)
|
||||
return result, image, mask
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
# model will be downloaded when app start, and can't switch in frontend settings
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import gc
|
||||
import math
|
||||
import random
|
||||
from typing import Any
|
||||
@@ -913,6 +914,7 @@ def torch_gc():
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
gc.collect()
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
|
||||
Reference in New Issue
Block a user