update
This commit is contained in:
@@ -27,25 +27,26 @@ class ModelManager:
|
||||
if name not in self.available_models:
|
||||
raise NotImplementedError(f"Unsupported model: {name}")
|
||||
|
||||
sd_controlnet_enabled = kwargs.get("sd_controlnet", False)
|
||||
model_info = self.available_models[name]
|
||||
kwargs = {**kwargs, "model_info": model_info}
|
||||
sd_controlnet_enabled = kwargs.get("sd_controlnet", False)
|
||||
if model_info.model_type in [ModelType.INPAINT, ModelType.DIFFUSERS_OTHER]:
|
||||
return models[name](device, **kwargs)
|
||||
|
||||
if sd_controlnet_enabled:
|
||||
return ControlNet(device, **{**kwargs, "model_info": model_info})
|
||||
return ControlNet(device, **kwargs)
|
||||
else:
|
||||
if model_info.model_type in [
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SD,
|
||||
]:
|
||||
return SD(device, model_id_or_path=model_info.path, **kwargs)
|
||||
return SD(device, **kwargs)
|
||||
|
||||
if model_info.model_type in [
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
]:
|
||||
return SDXL(device, model_id_or_path=model_info.path, **kwargs)
|
||||
return SDXL(device, **kwargs)
|
||||
|
||||
raise NotImplementedError(f"Unsupported model: {name}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user