This commit is contained in:
Qing
2023-12-15 12:40:29 +08:00
parent 142aa64cc6
commit cbe6577890
9 changed files with 35 additions and 16 deletions

View File

@@ -8,7 +8,7 @@ from diffusers import AutoencoderKL
from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.schema import Config
from lama_cleaner.schema import Config, ModelType
class SDXL(DiffusionInpaintModel):
@@ -26,9 +26,16 @@ class SDXL(DiffusionInpaintModel):
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
if self.model_info.model_type == ModelType.DIFFUSERS_SDXL:
num_in_channels = 4
else:
num_in_channels = 9
if os.path.isfile(self.model_id_or_path):
self.model = StableDiffusionXLInpaintPipeline.from_single_file(
self.model_id_or_path, torch_dtype=torch_dtype
self.model_id_or_path,
torch_dtype=torch_dtype,
num_in_channels=num_in_channels,
)
else:
vae = AutoencoderKL.from_pretrained(