wip
@@ -1,7 +1,10 @@
+import os
+
 import PIL.Image
 import cv2
 import numpy as np
 import torch
+from diffusers import AutoencoderKL
 from loguru import logger
 
 from lama_cleaner.model.base import DiffusionInpaintModel
@@ -13,26 +16,31 @@ class SDXL(DiffusionInpaintModel):
     pad_mod = 8
     min_size = 512
     lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
     model_id_or_path = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
 
     def init_model(self, device: torch.device, **kwargs):
-        from diffusers.pipelines import AutoPipelineForInpainting
+        from diffusers.pipelines import StableDiffusionXLInpaintPipeline
 
         fp16 = not kwargs.get("no_half", False)
 
         model_kwargs = {
             "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"])
         }
 
         use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
         torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
-        self.model = AutoPipelineForInpainting.from_pretrained(
-            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
-            revision="main",
-            torch_dtype=torch_dtype,
-            use_auth_token=kwargs["hf_access_token"],
-            **model_kwargs,
-        )
+        if os.path.isfile(self.model_id_or_path):
+            self.model = StableDiffusionXLInpaintPipeline.from_single_file(
+                self.model_id_or_path, torch_dtype=torch_dtype
+            )
+        else:
+            vae = AutoencoderKL.from_pretrained(
+                "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
+            )
+            self.model = StableDiffusionXLInpaintPipeline.from_pretrained(
+                self.model_id_or_path,
+                revision="main",
+                torch_dtype=torch_dtype,
+                use_auth_token=kwargs["hf_access_token"],
+                vae=vae,
+            )
 
         # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
         self.model.enable_attention_slicing()
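For context, a minimal usage sketch (not part of the commit) of the kind of pipeline init_model now builds on the Hub-repo branch. It assumes a CUDA device and a diffusers release with SDXL inpainting support; input.png, mask.png, and the prompt are hypothetical placeholders.

# Hedged usage sketch: exercising a StableDiffusionXLInpaintPipeline like the
# one init_model constructs above. File paths and prompt are placeholders.
import PIL.Image
import torch
from diffusers import AutoencoderKL
from diffusers.pipelines import StableDiffusionXLInpaintPipeline

# Hub-repo branch of the loading logic: swap in the fp16-safe VAE so latents
# can be decoded in half precision without NaN artifacts.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    torch_dtype=torch.float16,
    vae=vae,
).to("cuda")
pipe.enable_attention_slicing()  # lower peak VRAM at some speed cost

# SDXL inpainting works best at 1024x1024; white mask pixels get repainted.
image = PIL.Image.open("input.png").convert("RGB").resize((1024, 1024))
mask = PIL.Image.open("mask.png").convert("L").resize((1024, 1024))

result = pipe(
    prompt="a wooden park bench",
    image=image,
    mask_image=mask,
    num_inference_steps=30,
).images[0]
result.save("inpainted.png")

The from_single_file branch in the diff covers the other case: model_id_or_path pointing at a local single-file checkpoint (e.g. a .safetensors file) instead of a Hub repo id.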