Qing
2024-04-12 13:09:37 +08:00
parent f71e9cfb26
commit 35f12d5b9b
4 changed files with 13 additions and 9 deletions

View File

@@ -71,6 +71,7 @@ class ControlNet(DiffusionInpaintModel):
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
self.torch_dtype = torch_dtype
original_config_file_name = "v1"
if model_info.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SD_INPAINT,
@@ -78,6 +79,8 @@ class ControlNet(DiffusionInpaintModel):
from diffusers import (
StableDiffusionControlNetInpaintPipeline as PipeClass,
)
original_config_file_name = "v1"
elif model_info.model_type in [
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SDXL_INPAINT,
@@ -85,6 +88,7 @@ class ControlNet(DiffusionInpaintModel):
from diffusers import (
StableDiffusionXLControlNetInpaintPipeline as PipeClass,
)
original_config_file_name = "xl"
controlnet = ControlNetModel.from_pretrained(
pretrained_model_name_or_path=controlnet_method,
@@ -103,7 +107,7 @@ class ControlNet(DiffusionInpaintModel):
controlnet=controlnet,
load_safety_checker=not disable_nsfw_checker,
torch_dtype=torch_dtype,
config_files=get_config_files(),
original_config_file=get_config_files()[original_config_file_name],
**model_kwargs,
)
else:

View File

@@ -52,7 +52,7 @@ class SD(DiffusionInpaintModel):
self.model_id_or_path,
torch_dtype=torch_dtype,
load_safety_checker=not disable_nsfw_checker,
config_files=get_config_files(),
original_config_file=get_config_files()['v1'],
**model_kwargs,
)
else:

View File

@@ -42,7 +42,7 @@ class SDXL(DiffusionInpaintModel):
torch_dtype=torch_dtype,
num_in_channels=num_in_channels,
load_safety_checker=False,
config_files=get_config_files()
original_config_file=get_config_files()['xl'],
)
else:
model_kwargs = {