This commit is contained in:
Qing
2023-12-11 22:28:07 +08:00
parent fecf4beef0
commit 354a1280a4
13 changed files with 531 additions and 747 deletions

View File

@@ -36,16 +36,15 @@ class ModelManager:
return ControlNet(device, **{**kwargs, "model_info": model_info})
else:
if model_info.model_type in [
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
]:
raise NotImplementedError(
f"When using non inpaint Stable Diffusion model, you must enable controlnet"
)
if model_info.model_type == ModelType.DIFFUSERS_SD_INPAINT:
return SD(device, model_id_or_path=model_info.path, **kwargs)
if model_info.model_type == ModelType.DIFFUSERS_SDXL_INPAINT:
if model_info.model_type in [
ModelType.DIFFUSERS_SDXL_INPAINT,
ModelType.DIFFUSERS_SDXL,
]:
return SDXL(device, model_id_or_path=model_info.path, **kwargs)
raise NotImplementedError(f"Unsupported model: {name}")
@@ -88,7 +87,7 @@ class ModelManager:
if self.kwargs["sd_controlnet_method"] == control_method:
return
if not self.available_models[self.name].support_controlnet():
if not self.available_models[self.name].support_controlnet:
return
del self.model
@@ -105,7 +104,7 @@ class ModelManager:
if str(self.model.device) == "mps":
return
if self.available_models[self.name].support_freeu():
if self.available_models[self.name].support_freeu:
if config.sd_freeu:
freeu_config = config.sd_freeu_config
self.model.model.enable_freeu(
@@ -118,7 +117,7 @@ class ModelManager:
self.model.model.disable_freeu()
def enable_disable_lcm_lora(self, config: Config):
if self.available_models[self.name].support_lcm_lora():
if self.available_models[self.name].support_lcm_lora:
if config.sd_lcm_lora:
if not self.model.model.pipe.get_list_adapters():
self.model.model.load_lora_weights(self.model.lcm_lora_id)