From b7699a0f261c0c6dcb60c5ceadac61cb65cce298 Mon Sep 17 00:00:00 2001 From: Qing Date: Sat, 23 Nov 2024 15:51:05 +0800 Subject: [PATCH] add remove_bg_device --- iopaint/api.py | 1 + iopaint/cli.py | 10 +++++++++- iopaint/const.py | 3 ++- iopaint/plugins/__init__.py | 3 ++- iopaint/plugins/briarmbg.py | 3 ++- iopaint/plugins/briarmbg2.py | 3 ++- iopaint/plugins/remove_bg.py | 36 +++++++++++++++++++++++++++-------- iopaint/schema.py | 1 + iopaint/tests/test_plugins.py | 16 +++++++++------- iopaint/web_config.py | 8 ++++++++ 10 files changed, 64 insertions(+), 20 deletions(-) diff --git a/iopaint/api.py b/iopaint/api.py index 1c4a73f..29b4f92 100644 --- a/iopaint/api.py +++ b/iopaint/api.py @@ -388,6 +388,7 @@ class Api: self.config.interactive_seg_model, self.config.interactive_seg_device, self.config.enable_remove_bg, + self.config.remove_bg_device, self.config.remove_bg_model, self.config.enable_anime_seg, self.config.enable_realesrgan, diff --git a/iopaint/cli.py b/iopaint/cli.py index ae0b0c4..95d7b25 100644 --- a/iopaint/cli.py +++ b/iopaint/cli.py @@ -133,6 +133,7 @@ def start( ), interactive_seg_device: Device = Option(Device.cpu), enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP), + remove_bg_device: Device = Option(Device.cpu, help=REMOVE_BG_DEVICE_HELP), remove_bg_model: RemoveBGModel = Option(RemoveBGModel.briaai_rmbg_1_4), enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP), enable_realesrgan: bool = Option(False), @@ -145,6 +146,10 @@ def start( ): dump_environment_info() device = check_device(device) + remove_bg_device = check_device(remove_bg_device) + realesrgan_device = check_device(realesrgan_device) + gfpgan_device = check_device(gfpgan_device) + if input and not input.exists(): logger.error(f"invalid --input: {input} not exists") exit(-1) @@ -152,7 +157,9 @@ def start( logger.error(f"invalid --mask-dir: {mask_dir} not exists") exit(-1) if input and input.is_dir() and not output_dir: - logger.error("invalid --output-dir: --output-dir must be set when --input is a directory") + logger.error( + "invalid --output-dir: --output-dir must be set when --input is a directory" + ) exit(-1) if output_dir: output_dir = output_dir.expanduser().absolute() @@ -207,6 +214,7 @@ def start( interactive_seg_model=interactive_seg_model, interactive_seg_device=interactive_seg_device, enable_remove_bg=enable_remove_bg, + remove_bg_device=remove_bg_device, remove_bg_model=remove_bg_model, enable_anime_seg=enable_anime_seg, enable_realesrgan=enable_realesrgan, diff --git a/iopaint/const.py b/iopaint/const.py index b18254b..5d272b8 100644 --- a/iopaint/const.py +++ b/iopaint/const.py @@ -118,7 +118,8 @@ Quality of image encoding, 0-100. Default is 95, higher quality will generate la INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything." INTERACTIVE_SEG_MODEL_HELP = "Model size: mobile_sam < vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed." -REMOVE_BG_HELP = "Enable remove background plugin. Always run on CPU" +REMOVE_BG_HELP = "Enable remove background plugin." +REMOVE_BG_DEVICE_HELP = "Device for remove background plugin. 'cuda' only supports briaai models(briaai/RMBG-1.4 and briaai/RMBG-2.0)" ANIMESEG_HELP = "Enable anime segmentation plugin. Always run on CPU" REALESRGAN_HELP = "Enable realesrgan super resolution" GFPGAN_HELP = "Enable GFPGAN face restore. To also enhance background, use with --enable-realesrgan" diff --git a/iopaint/plugins/__init__.py b/iopaint/plugins/__init__.py index 8128025..379111b 100644 --- a/iopaint/plugins/__init__.py +++ b/iopaint/plugins/__init__.py @@ -16,6 +16,7 @@ def build_plugins( interactive_seg_model: InteractiveSegModel, interactive_seg_device: Device, enable_remove_bg: bool, + remove_bg_device: Device, remove_bg_model: str, enable_anime_seg: bool, enable_realesrgan: bool, @@ -36,7 +37,7 @@ def build_plugins( if enable_remove_bg: logger.info(f"Initialize {RemoveBG.name} plugin") - plugins[RemoveBG.name] = RemoveBG(remove_bg_model) + plugins[RemoveBG.name] = RemoveBG(remove_bg_model, remove_bg_device) if enable_anime_seg: logger.info(f"Initialize {AnimeSeg.name} plugin") diff --git a/iopaint/plugins/briarmbg.py b/iopaint/plugins/briarmbg.py index 880f530..77928cb 100644 --- a/iopaint/plugins/briarmbg.py +++ b/iopaint/plugins/briarmbg.py @@ -480,7 +480,7 @@ def create_briarmbg_session(): return net -def briarmbg_process(bgr_np_image, session, only_mask=False): +def briarmbg_process(device, bgr_np_image, session, only_mask=False): # prepare input orig_bgr_image = Image.fromarray(bgr_np_image) w, h = orig_im_size = orig_bgr_image.size @@ -490,6 +490,7 @@ def briarmbg_process(bgr_np_image, session, only_mask=False): im_tensor = torch.unsqueeze(im_tensor, 0) im_tensor = torch.divide(im_tensor, 255.0) im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]) + im_tensor = im_tensor.to(device) # inference result = session(im_tensor) # post process diff --git a/iopaint/plugins/briarmbg2.py b/iopaint/plugins/briarmbg2.py index 0b3eaef..4d037f7 100644 --- a/iopaint/plugins/briarmbg2.py +++ b/iopaint/plugins/briarmbg2.py @@ -10,7 +10,7 @@ def create_briarmbg2_session(): return birefnet -def briarmbg2_process(bgr_np_image, session, only_mask=False): +def briarmbg2_process(device, bgr_np_image, session, only_mask=False): from torchvision import transforms from PIL import Image @@ -25,6 +25,7 @@ def briarmbg2_process(bgr_np_image, session, only_mask=False): image = Image.fromarray(bgr_np_image) image_size = image.size input_images = transform_image(image).unsqueeze(0) + input_images = input_images.to(device) # Prediction preds = session(input_images)[-1].sigmoid().cpu() diff --git a/iopaint/plugins/remove_bg.py b/iopaint/plugins/remove_bg.py index 18a60f6..bd03575 100644 --- a/iopaint/plugins/remove_bg.py +++ b/iopaint/plugins/remove_bg.py @@ -6,7 +6,13 @@ import torch from torch.hub import get_dir from iopaint.plugins.base_plugin import BasePlugin -from iopaint.schema import RunPluginRequest, RemoveBGModel +from iopaint.schema import Device, RunPluginRequest, RemoveBGModel + + +def _rmbg_remove(device, *args, **kwargs): + from rembg import remove + + return remove(*args, **kwargs) class RemoveBG(BasePlugin): @@ -14,9 +20,10 @@ class RemoveBG(BasePlugin): support_gen_mask = True support_gen_image = True - def __init__(self, model_name): + def __init__(self, model_name, device): super().__init__() self.model_name = model_name + self.device = device if model_name.startswith("birefnet"): import rembg @@ -33,13 +40,15 @@ class RemoveBG(BasePlugin): self._init_session(model_name) def _init_session(self, model_name: str): + self.device_warning() + if model_name == RemoveBGModel.briaai_rmbg_1_4: from iopaint.plugins.briarmbg import ( create_briarmbg_session, briarmbg_process, ) - self.session = create_briarmbg_session() + self.session = create_briarmbg_session().to(self.device) self.remove = briarmbg_process elif model_name == RemoveBGModel.briaai_rmbg_2_0: from iopaint.plugins.briarmbg2 import ( @@ -47,13 +56,13 @@ class RemoveBG(BasePlugin): briarmbg2_process, ) - self.session = create_briarmbg2_session() + self.session = create_briarmbg2_session().to(self.device) self.remove = briarmbg2_process else: - from rembg import new_session, remove + from rembg import new_session self.session = new_session(model_name=model_name) - self.remove = remove + self.remove = _rmbg_remove def switch_model(self, new_model_name): if self.model_name == new_model_name: @@ -70,7 +79,7 @@ class RemoveBG(BasePlugin): bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) # return BGRA image - output = self.remove(bgr_np_img, session=self.session) + output = self.remove(self.device, bgr_np_img, session=self.session) return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA) @torch.inference_mode() @@ -78,7 +87,9 @@ class RemoveBG(BasePlugin): bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) # return BGR image, 255 means foreground, 0 means background - output = self.remove(bgr_np_img, session=self.session, only_mask=True) + output = self.remove( + self.device, bgr_np_img, session=self.session, only_mask=True + ) return output def check_dep(self): @@ -86,3 +97,12 @@ class RemoveBG(BasePlugin): import rembg except ImportError: return "RemoveBG is not installed, please install it first. pip install -U rembg" + + def device_warning(self): + if self.device == Device.cuda and self.model_name not in [ + RemoveBGModel.briaai_rmbg_1_4, + RemoveBGModel.briaai_rmbg_2_0, + ]: + logger.warning( + f"remove_bg_device=cuda only supports briaai models({RemoveBGModel.briaai_rmbg_1_4.value}/{RemoveBGModel.briaai_rmbg_2_0.value})" + ) diff --git a/iopaint/schema.py b/iopaint/schema.py index 655119a..f363a81 100644 --- a/iopaint/schema.py +++ b/iopaint/schema.py @@ -266,6 +266,7 @@ class ApiConfig(BaseModel): interactive_seg_model: InteractiveSegModel interactive_seg_device: Device enable_remove_bg: bool + remove_bg_device: Device remove_bg_model: str enable_anime_seg: bool enable_realesrgan: bool diff --git a/iopaint/tests/test_plugins.py b/iopaint/tests/test_plugins.py index 7857e9e..3970bb3 100644 --- a/iopaint/tests/test_plugins.py +++ b/iopaint/tests/test_plugins.py @@ -3,7 +3,7 @@ from PIL import Image from iopaint.helper import encode_pil_to_base64, gen_frontend_mask from iopaint.plugins.anime_seg import AnimeSeg -from iopaint.schema import RunPluginRequest, RemoveBGModel, InteractiveSegModel +from iopaint.schema import Device, RunPluginRequest, RemoveBGModel, InteractiveSegModel from iopaint.tests.utils import check_device, current_dir, save_dir os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" @@ -38,24 +38,26 @@ def _save(img, name): @pytest.mark.parametrize("model_name", RemoveBGModel.values()) -def test_remove_bg(model_name): - print(f"Testing {model_name}") - model = RemoveBG(model_name) +@pytest.mark.parametrize("device", Device.values()) +def test_remove_bg(model_name, device): + check_device(device) + print(f"Testing {model_name} on {device}") + model = RemoveBG(model_name, device) rgba_np_img = model.gen_image( rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64) ) res = cv2.cvtColor(rgba_np_img, cv2.COLOR_RGBA2BGRA) - _save(res, f"test_remove_bg_{model_name}.png") + _save(res, f"test_remove_bg_{model_name}_{device}.png") bgr_np_img = model.gen_mask( rgb_img, RunPluginRequest(name=RemoveBG.name, image=rgb_img_base64) ) res_mask = gen_frontend_mask(bgr_np_img) - _save(res_mask, f"test_remove_bg_frontend_mask_{model_name}.png") + _save(res_mask, f"test_remove_bg_frontend_mask_{model_name}_{device}.png") assert len(bgr_np_img.shape) == 2 - _save(bgr_np_img, f"test_remove_bg_mask_{model_name}.jpeg") + _save(bgr_np_img, f"test_remove_bg_mask_{model_name}_{device}.jpeg") def test_anime_seg(): diff --git a/iopaint/web_config.py b/iopaint/web_config.py index 3b3f8d1..c08ccad 100644 --- a/iopaint/web_config.py +++ b/iopaint/web_config.py @@ -51,6 +51,7 @@ default_configs = dict( interactive_seg_model=InteractiveSegModel.sam2_1_tiny, interactive_seg_device=Device.cpu, enable_remove_bg=False, + remove_bg_device=Device.cpu, remove_bg_model=RemoveBGModel.briaai_rmbg_1_4, enable_anime_seg=False, enable_realesrgan=False, @@ -99,6 +100,7 @@ def save_config( interactive_seg_model, interactive_seg_device, enable_remove_bg, + remove_bg_device, remove_bg_model, enable_anime_seg, enable_realesrgan, @@ -236,6 +238,11 @@ def main(config_file: Path): enable_remove_bg = gr.Checkbox( init_config.enable_remove_bg, label=REMOVE_BG_HELP ) + remove_bg_device = gr.Radio( + Device.values(), + label=REMOVE_BG_DEVICE_HELP, + value=init_config.remove_bg_device, + ) remove_bg_model = gr.Radio( RemoveBGModel.values(), label="Remove bg model", @@ -304,6 +311,7 @@ def main(config_file: Path): interactive_seg_model, interactive_seg_device, enable_remove_bg, + remove_bg_device, remove_bg_model, enable_anime_seg, enable_realesrgan,