fix torch_dtype & cpu_text_encoder

This commit is contained in:
Qing
2024-02-10 10:59:06 +08:00
parent 180b7d6c70
commit 9aa5a7e0ba
4 changed files with 13 additions and 4 deletions

View File

@@ -50,7 +50,7 @@ class SD(DiffusionInpaintModel):
self.model = StableDiffusionInpaintPipeline.from_single_file(
self.model_id_or_path,
dtype=torch_dtype,
torch_dtype=torch_dtype,
load_safety_checker=not disable_nsfw_checker,
config_files=get_config_files(),
**model_kwargs,
@@ -60,7 +60,7 @@ class SD(DiffusionInpaintModel):
StableDiffusionInpaintPipeline.from_pretrained,
pretrained_model_name_or_path=self.model_id_or_path,
variant="fp16",
dtype=torch_dtype,
torch_dtype=torch_dtype,
**model_kwargs,
)