add download command

This commit is contained in:
Qing
2023-11-16 21:12:06 +08:00
parent 20e660aa4a
commit 1d145d1cd6
17 changed files with 233 additions and 67 deletions

View File

@@ -4,7 +4,6 @@ import torch
from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import set_seed
from lama_cleaner.schema import Config
@@ -15,18 +14,21 @@ class InstructPix2Pix(DiffusionInpaintModel):
def init_model(self, device: torch.device, **kwargs):
from diffusers import StableDiffusionInstructPix2PixPipeline
fp16 = not kwargs.get('no_half', False)
model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
fp16 = not kwargs.get("no_half", False)
model_kwargs = {"local_files_only": kwargs.get("local_files_only", False)}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(dict(
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False
))
model_kwargs.update(
dict(
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
)
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
"timbrooks/instruct-pix2pix",
@@ -36,15 +38,23 @@ class InstructPix2Pix(DiffusionInpaintModel):
)
self.model.enable_attention_slicing()
if kwargs.get('enable_xformers', False):
if kwargs.get("enable_xformers", False):
self.model.enable_xformers_memory_efficient_attention()
if kwargs.get('cpu_offload', False) and use_gpu:
if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
self.model = self.model.to(device)
@staticmethod
def download():
from diffusers import StableDiffusionInstructPix2PixPipeline
StableDiffusionInstructPix2PixPipeline.from_pretrained(
"timbrooks/instruct-pix2pix", revision="fp16"
)
def forward(self, image, mask, config: Config):
"""Input image and output image have same size
image: [H, W, C] RGB
@@ -60,7 +70,7 @@ class InstructPix2Pix(DiffusionInpaintModel):
image_guidance_scale=config.p2p_image_guidance_scale,
guidance_scale=config.p2p_guidance_scale,
output_type="np",
generator=torch.manual_seed(config.sd_seed)
generator=torch.manual_seed(config.sd_seed),
).images[0]
output = (output * 255).round().astype("uint8")