add remove bg model selection
This commit is contained in:
@@ -42,10 +42,10 @@ from iopaint.helper import (
|
||||
adjust_mask,
|
||||
)
|
||||
from iopaint.model.utils import torch_gc
|
||||
from iopaint.model_info import ModelInfo
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.plugins import build_plugins
|
||||
from iopaint.plugins.base_plugin import BasePlugin
|
||||
from iopaint.plugins.remove_bg import RemoveBG
|
||||
from iopaint.schema import (
|
||||
GenInfoResponse,
|
||||
ApiConfig,
|
||||
@@ -56,6 +56,9 @@ from iopaint.schema import (
|
||||
SDSampler,
|
||||
PluginInfo,
|
||||
AdjustMaskRequest,
|
||||
RemoveBGModel,
|
||||
SwitchPluginModelRequest,
|
||||
ModelInfo,
|
||||
)
|
||||
|
||||
CURRENT_DIR = Path(__file__).parent.absolute().resolve()
|
||||
@@ -154,11 +157,11 @@ class Api:
|
||||
# fmt: off
|
||||
self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
|
||||
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
|
||||
self.add_api_route("/api/v1/models", self.api_models, methods=["GET"], response_model=List[ModelInfo])
|
||||
self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
|
||||
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
|
||||
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
|
||||
self.add_api_route("/api/v1/inpaint", self.api_inpaint, methods=["POST"])
|
||||
self.add_api_route("/api/v1/switch_plugin_model", self.api_switch_plugin_model, methods=["POST"])
|
||||
self.add_api_route("/api/v1/run_plugin_gen_mask", self.api_run_plugin_gen_mask, methods=["POST"])
|
||||
self.add_api_route("/api/v1/run_plugin_gen_image", self.api_run_plugin_gen_image, methods=["POST"])
|
||||
self.add_api_route("/api/v1/samplers", self.api_samplers, methods=["GET"])
|
||||
@@ -175,9 +178,6 @@ class Api:
|
||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||
return self.app.add_api_route(path, endpoint, **kwargs)
|
||||
|
||||
def api_models(self) -> List[ModelInfo]:
|
||||
return self.model_manager.scan_models()
|
||||
|
||||
def api_current_model(self) -> ModelInfo:
|
||||
return self.model_manager.current_model
|
||||
|
||||
@@ -187,16 +187,28 @@ class Api:
|
||||
self.model_manager.switch(req.name)
|
||||
return self.model_manager.current_model
|
||||
|
||||
def api_switch_plugin_model(self, req: SwitchPluginModelRequest):
|
||||
if req.plugin_name in self.plugins:
|
||||
self.plugins[req.plugin_name].switch_model(req.model_name)
|
||||
if req.plugin_name == RemoveBG.name:
|
||||
self.config.remove_bg_model = req.model_name
|
||||
|
||||
def api_server_config(self) -> ServerConfigResponse:
|
||||
return ServerConfigResponse(
|
||||
plugins=[
|
||||
plugins = []
|
||||
for it in self.plugins.values():
|
||||
plugins.append(
|
||||
PluginInfo(
|
||||
name=it.name,
|
||||
support_gen_image=it.support_gen_image,
|
||||
support_gen_mask=it.support_gen_mask,
|
||||
)
|
||||
for it in self.plugins.values()
|
||||
],
|
||||
)
|
||||
|
||||
return ServerConfigResponse(
|
||||
plugins=plugins,
|
||||
modelInfos=self.model_manager.scan_models(),
|
||||
removeBGModel=self.config.remove_bg_model,
|
||||
removeBGModels=RemoveBGModel.values(),
|
||||
enableFileManager=self.file_manager is not None,
|
||||
enableAutoSaving=self.config.output_dir is not None,
|
||||
enableControlnet=self.model_manager.enable_controlnet,
|
||||
@@ -340,6 +352,7 @@ class Api:
|
||||
self.config.interactive_seg_model,
|
||||
self.config.interactive_seg_device,
|
||||
self.config.enable_remove_bg,
|
||||
self.config.remove_bg_model,
|
||||
self.config.enable_anime_seg,
|
||||
self.config.enable_realesrgan,
|
||||
self.config.realesrgan_device,
|
||||
|
||||
Reference in New Issue
Block a user