add remove bg model selection

This commit is contained in:
Qing
2024-02-08 16:49:54 +08:00
parent cf9ceea4e6
commit 8060e16c70
19 changed files with 915 additions and 222 deletions

View File

@@ -42,10 +42,10 @@ from iopaint.helper import (
adjust_mask,
)
from iopaint.model.utils import torch_gc
from iopaint.model_info import ModelInfo
from iopaint.model_manager import ModelManager
from iopaint.plugins import build_plugins
from iopaint.plugins.base_plugin import BasePlugin
from iopaint.plugins.remove_bg import RemoveBG
from iopaint.schema import (
GenInfoResponse,
ApiConfig,
@@ -56,6 +56,9 @@ from iopaint.schema import (
SDSampler,
PluginInfo,
AdjustMaskRequest,
RemoveBGModel,
SwitchPluginModelRequest,
ModelInfo,
)
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
@@ -154,11 +157,11 @@ class Api:
# fmt: off
self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
self.add_api_route("/api/v1/models", self.api_models, methods=["GET"], response_model=List[ModelInfo])
self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
@@ -175,9 +178,6 @@ class Api:
def add_api_route(self, path: str, endpoint, **kwargs):
return self.app.add_api_route(path, endpoint, **kwargs)
def api_models(self) -> List[ModelInfo]:
return self.model_manager.scan_models()
def api_current_model(self) -> ModelInfo:
return self.model_manager.current_model
@@ -187,16 +187,28 @@ class Api:
self.model_manager.switch(req.name)
return self.model_manager.current_model
def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
if req.plugin_name in self.plugins:
self.plugins[req.plugin_name].switch_model(req.model_name)
if req.plugin_name == RemoveBG.name:
self.config.remove_bg_model = req.model_name
def api_server_config(self) -> ServerConfigResponse:
return ServerConfigResponse(
plugins=[
plugins = []
for it in self.plugins.values():
plugins.append(
PluginInfo(
name=it.name,
support_gen_image=it.support_gen_image,
support_gen_mask=it.support_gen_mask,
)
for it in self.plugins.values()
],
)
return ServerConfigResponse(
plugins=plugins,
modelInfos=self.model_manager.scan_models(),
removeBGModel=self.config.remove_bg_model,
removeBGModels=RemoveBGModel.values(),
enableFileManager=self.file_manager is not None,
enableAutoSaving=self.config.output_dir is not None,
enableControlnet=self.model_manager.enable_controlnet,
@@ -340,6 +352,7 @@ class Api:
self.config.interactive_seg_model,
self.config.interactive_seg_device,
self.config.enable_remove_bg,
self.config.remove_bg_model,
self.config.enable_anime_seg,
self.config.enable_realesrgan,
self.config.realesrgan_device,