This commit is contained in:
Qing
2023-12-15 12:40:29 +08:00
parent 142aa64cc6
commit cbe6577890
9 changed files with 35 additions and 16 deletions

View File

@@ -27,25 +27,26 @@ class ModelManager:
if name not in self.available_models:
raise NotImplementedError(f"Unsupported model: {name}")
sd_controlnet_enabled = kwargs.get("sd_controlnet", False)
model_info = self.available_models[name]
kwargs = {**kwargs, "model_info": model_info}
sd_controlnet_enabled = kwargs.get("sd_controlnet", False)
if model_info.model_type in [ModelType.INPAINT, ModelType.DIFFUSERS_OTHER]:
return models[name](device, **kwargs)
if sd_controlnet_enabled:
return ControlNet(device, **{**kwargs, "model_info": model_info})
return ControlNet(device, **kwargs)
else:
if model_info.model_type in [
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SD,
]:
return SD(device, model_id_or_path=model_info.path, **kwargs)
return SD(device, **kwargs)
if model_info.model_type in [
ModelType.DIFFUSERS_SDXL_INPAINT,
ModelType.DIFFUSERS_SDXL,
]:
return SDXL(device, model_id_or_path=model_info.path, **kwargs)
return SDXL(device, **kwargs)
raise NotImplementedError(f"Unsupported model: {name}")