This commit is contained in:
Qing
2023-12-11 22:28:07 +08:00
parent fecf4beef0
commit 354a1280a4
13 changed files with 531 additions and 747 deletions

View File

@@ -36,16 +36,15 @@ class ModelManager:
return ControlNet(device, **{**kwargs, "model_info": model_info})
else:
if model_info.model_type in [
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
]:
raise NotImplementedError(
f"When using non inpaint Stable Diffusion model, you must enable controlnet"
)
if model_info.model_type == ModelType.DIFFUSERS_SD_INPAINT:
return SD(device, model_id_or_path=model_info.path, **kwargs)
if model_info.model_type == ModelType.DIFFUSERS_SDXL_INPAINT:
if model_info.model_type in [
ModelType.DIFFUSERS_SDXL_INPAINT,
ModelType.DIFFUSERS_SDXL,
]:
return SDXL(device, model_id_or_path=model_info.path, **kwargs)
raise NotImplementedError(f"Unsupported model: {name}")
@@ -88,7 +87,7 @@ class ModelManager:
if self.kwargs["sd_controlnet_method"] == control_method:
return
if not self.available_models[self.name].support_controlnet():
if not self.available_models[self.name].support_controlnet:
return
del self.model
@@ -105,7 +104,7 @@ class ModelManager:
if str(self.model.device) == "mps":
return
if self.available_models[self.name].support_freeu():
if self.available_models[self.name].support_freeu:
if config.sd_freeu:
freeu_config = config.sd_freeu_config
self.model.model.enable_freeu(
@@ -118,7 +117,7 @@ class ModelManager:
self.model.model.disable_freeu()
def enable_disable_lcm_lora(self, config: Config):
if self.available_models[self.name].support_lcm_lora():
if self.available_models[self.name].support_lcm_lora:
if config.sd_lcm_lora:
if not self.model.model.pipe.get_list_adapters():
self.model.model.load_lora_weights(self.model.lcm_lora_id)

View File

@@ -1,8 +1,14 @@
from typing import Optional
from typing import Optional, List
from enum import Enum
from PIL.Image import Image
from pydantic import BaseModel
from pydantic import BaseModel, computed_field
from lama_cleaner.const import (
SDXL_CONTROLNET_CHOICES,
SD2_CONTROLNET_CHOICES,
SD_CONTROLNET_CHOICES,
)
DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
@@ -31,6 +37,36 @@ class ModelInfo(BaseModel):
model_type: ModelType
is_single_file_diffusers: bool = False
@computed_field
@property
def need_prompt(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
] or self.name in [
"timbrooks/instruct-pix2pix",
"kandinsky-community/kandinsky-2-2-decoder-inpaint",
]
@computed_field
@property
def controlnets(self) -> List[str]:
if self.model_type in [
ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SDXL_INPAINT,
]:
return SDXL_CONTROLNET_CHOICES
if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
if self.name in ["stabilityai/stable-diffusion-2-inpainting"]:
return SD2_CONTROLNET_CHOICES
else:
return SD_CONTROLNET_CHOICES
return []
@computed_field
@property
def support_lcm_lora(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
@@ -39,6 +75,8 @@ class ModelInfo(BaseModel):
ModelType.DIFFUSERS_SDXL_INPAINT,
]
@computed_field
@property
def support_controlnet(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
@@ -47,6 +85,8 @@ class ModelInfo(BaseModel):
ModelType.DIFFUSERS_SDXL_INPAINT,
]
@computed_field
@property
def support_freeu(self) -> bool:
return (
self.model_type
@@ -56,7 +96,7 @@ class ModelInfo(BaseModel):
ModelType.DIFFUSERS_SD_INPAINT,
ModelType.DIFFUSERS_SDXL_INPAINT,
]
or "instruct-pix2pix" in self.name
or "timbrooks/instruct-pix2pix" in self.name
)

View File

@@ -419,14 +419,8 @@ def run_plugin():
@app.route("/server_config", methods=["GET"])
def get_server_config():
controlnet = {
"SD": SD_CONTROLNET_CHOICES,
"SD2": SD2_CONTROLNET_CHOICES,
"SDXL": SDXL_CONTROLNET_CHOICES,
}
return {
"plugins": list(plugins.keys()),
"availableControlNet": controlnet,
"enableFileManager": enable_file_manager,
"enableAutoSaving": enable_auto_saving,
}, 200
@@ -434,20 +428,12 @@ def get_server_config():
@app.route("/models", methods=["GET"])
def get_models():
return [
{
**it.dict(),
"support_lcm_lora": it.support_lcm_lora(),
"support_controlnet": it.support_controlnet(),
"support_freeu": it.support_freeu(),
}
for it in model.scan_models()
]
return [it.model_dump() for it in model.scan_models()]
@app.route("/model")
def current_model():
return model.available_models[model.name].dict(), 200
return model.available_models[model.name].model_dump(), 200
@app.route("/is_desktop")
@@ -600,8 +586,20 @@ def main(args):
else:
input_image_path = args.input
# 为了兼容性
model_name_map = {
"sd1.5": "runwayml/stable-diffusion-inpainting",
"anything4": "Sanster/anything-4.0-inpainting",
"realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting",
"sd2": "stabilityai/stable-diffusion-2-inpainting",
"sdxl": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
"kandinsky2.2": "kandinsky-community/kandinsky-2-2-decoder-inpaint",
"paint_by_example": "Fantasy-Studio/Paint-by-Example",
"instruct_pix2pix": "timbrooks/instruct-pix2pix",
}
model = ModelManager(
name=args.model,
name=model_name_map.get(args.model, args.model),
sd_controlnet=args.sd_controlnet,
sd_controlnet_method=args.sd_controlnet_method,
device=device,