update
This commit is contained in:
@@ -9,7 +9,7 @@ from loguru import logger
|
||||
from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
|
||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||
from lama_cleaner.schema import Config
|
||||
from lama_cleaner.schema import Config, ModelType
|
||||
|
||||
|
||||
class SD(DiffusionInpaintModel):
|
||||
@@ -36,7 +36,12 @@ class SD(DiffusionInpaintModel):
|
||||
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
|
||||
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
||||
|
||||
if os.path.isfile(self.model_id_or_path):
|
||||
if self.model_info.is_single_file_diffusers:
|
||||
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
||||
model_kwargs["num_in_channels"] = 4
|
||||
else:
|
||||
model_kwargs["num_in_channels"] = 9
|
||||
|
||||
self.model = StableDiffusionInpaintPipeline.from_single_file(
|
||||
self.model_id_or_path, torch_dtype=torch_dtype, **model_kwargs
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user