add download command

This commit is contained in:
Qing
2023-11-16 21:12:06 +08:00
parent 20e660aa4a
commit 1d145d1cd6
17 changed files with 233 additions and 67 deletions

View File

@@ -2,11 +2,9 @@ import PIL
import PIL.Image
import cv2
import torch
from diffusers import DiffusionPipeline
from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import set_seed
from lama_cleaner.schema import Config
@@ -16,35 +14,40 @@ class PaintByExample(DiffusionInpaintModel):
min_size = 512
def init_model(self, device: torch.device, **kwargs):
fp16 = not kwargs.get('no_half', False)
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
from diffusers import DiffusionPipeline
if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
fp16 = not kwargs.get("no_half", False)
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
model_kwargs = {"local_files_only": kwargs.get("local_files_only", False)}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Paint By Example Model NSFW checker")
model_kwargs.update(dict(
safety_checker=None,
requires_safety_checker=False
))
model_kwargs.update(
dict(safety_checker=None, requires_safety_checker=False)
)
self.model = DiffusionPipeline.from_pretrained(
"Fantasy-Studio/Paint-by-Example",
torch_dtype=torch_dtype,
**model_kwargs
"Fantasy-Studio/Paint-by-Example", torch_dtype=torch_dtype, **model_kwargs
)
self.model.enable_attention_slicing()
if kwargs.get('enable_xformers', False):
if kwargs.get("enable_xformers", False):
self.model.enable_xformers_memory_efficient_attention()
# TODO: gpu_id
if kwargs.get('cpu_offload', False) and use_gpu:
if kwargs.get("cpu_offload", False) and use_gpu:
self.model.image_encoder = self.model.image_encoder.to(device)
self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
self.model = self.model.to(device)
@staticmethod
def download():
from diffusers import DiffusionPipeline
DiffusionPipeline.from_pretrained("Fantasy-Studio/Paint-by-Example")
def forward(self, image, mask, config: Config):
"""Input image and output image have same size
image: [H, W, C] RGB
@@ -56,8 +59,8 @@ class PaintByExample(DiffusionInpaintModel):
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
example_image=config.paint_by_example_example_image,
num_inference_steps=config.paint_by_example_steps,
output_type='np.array',
generator=torch.manual_seed(config.paint_by_example_seed)
output_type="np.array",
generator=torch.manual_seed(config.paint_by_example_seed),
).images[0]
output = (output * 255).round().astype("uint8")