add download command
This commit is contained in:
@@ -4,7 +4,6 @@ import torch
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||
from lama_cleaner.model.utils import set_seed
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
|
||||
@@ -15,18 +14,21 @@ class InstructPix2Pix(DiffusionInpaintModel):
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
from diffusers import StableDiffusionInstructPix2PixPipeline
|
||||
fp16 = not kwargs.get('no_half', False)
|
||||
|
||||
model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
|
||||
if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
|
||||
fp16 = not kwargs.get("no_half", False)
|
||||
|
||||
model_kwargs = {"local_files_only": kwargs.get("local_files_only", False)}
|
||||
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||
model_kwargs.update(dict(
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False
|
||||
))
|
||||
model_kwargs.update(
|
||||
dict(
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
)
|
||||
|
||||
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
|
||||
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
|
||||
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
||||
self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
||||
"timbrooks/instruct-pix2pix",
|
||||
@@ -36,15 +38,23 @@ class InstructPix2Pix(DiffusionInpaintModel):
|
||||
)
|
||||
|
||||
self.model.enable_attention_slicing()
|
||||
if kwargs.get('enable_xformers', False):
|
||||
if kwargs.get("enable_xformers", False):
|
||||
self.model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
if kwargs.get('cpu_offload', False) and use_gpu:
|
||||
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||
logger.info("Enable sequential cpu offload")
|
||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||
else:
|
||||
self.model = self.model.to(device)
|
||||
|
||||
@staticmethod
|
||||
def download():
|
||||
from diffusers import StableDiffusionInstructPix2PixPipeline
|
||||
|
||||
StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
||||
"timbrooks/instruct-pix2pix", revision="fp16"
|
||||
)
|
||||
|
||||
def forward(self, image, mask, config: Config):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
@@ -60,7 +70,7 @@ class InstructPix2Pix(DiffusionInpaintModel):
|
||||
image_guidance_scale=config.p2p_image_guidance_scale,
|
||||
guidance_scale=config.p2p_guidance_scale,
|
||||
output_type="np",
|
||||
generator=torch.manual_seed(config.sd_seed)
|
||||
generator=torch.manual_seed(config.sd_seed),
|
||||
).images[0]
|
||||
|
||||
output = (output * 255).round().astype("uint8")
|
||||
|
||||
Reference in New Issue
Block a user