add switch interactiveSegModel
This commit is contained in:
@@ -43,7 +43,7 @@ from iopaint.helper import (
|
||||
)
|
||||
from iopaint.model.utils import torch_gc
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.plugins import build_plugins, RealESRGANUpscaler
|
||||
from iopaint.plugins import build_plugins, RealESRGANUpscaler, InteractiveSeg
|
||||
from iopaint.plugins.base_plugin import BasePlugin
|
||||
from iopaint.plugins.remove_bg import RemoveBG
|
||||
from iopaint.schema import (
|
||||
@@ -59,6 +59,7 @@ from iopaint.schema import (
|
||||
RemoveBGModel,
|
||||
SwitchPluginModelRequest,
|
||||
ModelInfo,
|
||||
InteractiveSegModel,
|
||||
RealESRGANModel,
|
||||
)
|
||||
|
||||
@@ -202,6 +203,9 @@ class Api:
|
||||
self.config.remove_bg_model = req.model_name
|
||||
if req.plugin_name == RealESRGANUpscaler.name:
|
||||
self.config.realesrgan_model = req.model_name
|
||||
if req.plugin_name == InteractiveSeg.name:
|
||||
self.config.interactive_seg_model = req.model_name
|
||||
torch_gc()
|
||||
|
||||
def api_server_config(self) -> ServerConfigResponse:
|
||||
plugins = []
|
||||
@@ -221,6 +225,8 @@ class Api:
|
||||
removeBGModels=RemoveBGModel.values(),
|
||||
realesrganModel=self.config.realesrgan_model,
|
||||
realesrganModels=RealESRGANModel.values(),
|
||||
interactiveSegModel=self.config.interactive_seg_model,
|
||||
interactiveSegModels=InteractiveSegModel.values(),
|
||||
enableFileManager=self.file_manager is not None,
|
||||
enableAutoSaving=self.config.output_dir is not None,
|
||||
enableControlnet=self.model_manager.enable_controlnet,
|
||||
@@ -388,38 +394,3 @@ class Api:
|
||||
cpu_offload=self.config.cpu_offload,
|
||||
callback=diffuser_callback,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from iopaint.schema import InteractiveSegModel, RealESRGANModel
|
||||
|
||||
app = FastAPI()
|
||||
api = Api(
|
||||
app,
|
||||
ApiConfig(
|
||||
host="127.0.0.1",
|
||||
port=8080,
|
||||
model="lama",
|
||||
no_half=False,
|
||||
cpu_offload=False,
|
||||
disable_nsfw_checker=False,
|
||||
cpu_textencoder=False,
|
||||
device="cpu",
|
||||
input="/Users/cwq/code/github/MI-GAN/examples/places2_512_object/images",
|
||||
output_dir="/Users/cwq/code/github/lama-cleaner/tmp",
|
||||
quality=100,
|
||||
enable_interactive_seg=False,
|
||||
interactive_seg_model=InteractiveSegModel.vit_b,
|
||||
interactive_seg_device="cpu",
|
||||
enable_remove_bg=False,
|
||||
enable_anime_seg=False,
|
||||
enable_realesrgan=False,
|
||||
realesrgan_device="cpu",
|
||||
realesrgan_model=RealESRGANModel.realesr_general_x4v3,
|
||||
enable_gfpgan=False,
|
||||
gfpgan_device="cpu",
|
||||
enable_restoreformer=False,
|
||||
restoreformer_device="cpu",
|
||||
),
|
||||
)
|
||||
api.launch()
|
||||
|
||||
@@ -37,16 +37,31 @@ class InteractiveSeg(BasePlugin):
|
||||
|
||||
def __init__(self, model_name, device):
|
||||
super().__init__()
|
||||
self.model_name = model_name
|
||||
self.device = device
|
||||
self._init_session(model_name)
|
||||
|
||||
def _init_session(self, model_name: str):
|
||||
model_path = download_model(
|
||||
SEGMENT_ANYTHING_MODELS[model_name]["url"],
|
||||
SEGMENT_ANYTHING_MODELS[model_name]["md5"],
|
||||
)
|
||||
logger.info(f"SegmentAnything model path: {model_path}")
|
||||
self.predictor = SamPredictor(
|
||||
sam_model_registry[model_name](checkpoint=model_path).to(device)
|
||||
sam_model_registry[model_name](checkpoint=model_path).to(self.device)
|
||||
)
|
||||
self.prev_img_md5 = None
|
||||
|
||||
def switch_model(self, new_model_name):
|
||||
if self.model_name == new_model_name:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Switching InteractiveSeg model from {self.model_name} to {new_model_name}"
|
||||
)
|
||||
self._init_session(new_model_name)
|
||||
self.model_name = new_model_name
|
||||
|
||||
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
|
||||
img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
|
||||
return self.forward(rgb_np_img, req.clicks, img_md5)
|
||||
|
||||
@@ -427,6 +427,8 @@ class ServerConfigResponse(BaseModel):
|
||||
removeBGModels: List[RemoveBGModel]
|
||||
realesrganModel: RealESRGANModel
|
||||
realesrganModels: List[RealESRGANModel]
|
||||
interactiveSegModel: InteractiveSegModel
|
||||
interactiveSegModels: List[InteractiveSegModel]
|
||||
enableFileManager: bool
|
||||
enableAutoSaving: bool
|
||||
enableControlnet: bool
|
||||
|
||||
Reference in New Issue
Block a user