diff --git a/iopaint/plugins/briarmbg2.py b/iopaint/plugins/briarmbg2.py new file mode 100644 index 0000000..0b3eaef --- /dev/null +++ b/iopaint/plugins/briarmbg2.py @@ -0,0 +1,39 @@ +import numpy as np + + +def create_briarmbg2_session(): + from transformers import AutoModelForImageSegmentation + + birefnet = AutoModelForImageSegmentation.from_pretrained( + "briaai/RMBG-2.0", trust_remote_code=True + ) + return birefnet + + +def briarmbg2_process(bgr_np_image, session, only_mask=False): + from torchvision import transforms + from PIL import Image + + transform_image = transforms.Compose( + [ + transforms.Resize((1024, 1024)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + + image = Image.fromarray(bgr_np_image) + image_size = image.size + input_images = transform_image(image).unsqueeze(0) + + # Prediction + preds = session(input_images)[-1].sigmoid().cpu() + pred = preds[0].squeeze() + pred_pil = transforms.ToPILImage()(pred) + mask = pred_pil.resize(image_size) + + if only_mask: + return np.array(mask) + + image.putalpha(mask) + return np.array(image) diff --git a/iopaint/plugins/remove_bg.py b/iopaint/plugins/remove_bg.py index e066662..18a60f6 100644 --- a/iopaint/plugins/remove_bg.py +++ b/iopaint/plugins/remove_bg.py @@ -2,6 +2,7 @@ import os import cv2 import numpy as np from loguru import logger +import torch from torch.hub import get_dir from iopaint.plugins.base_plugin import BasePlugin @@ -40,6 +41,14 @@ class RemoveBG(BasePlugin): self.session = create_briarmbg_session() self.remove = briarmbg_process + elif model_name == RemoveBGModel.briaai_rmbg_2_0: + from iopaint.plugins.briarmbg2 import ( + create_briarmbg2_session, + briarmbg2_process, + ) + + self.session = create_briarmbg2_session() + self.remove = briarmbg2_process else: from rembg import new_session, remove @@ -56,6 +65,7 @@ class RemoveBG(BasePlugin): self._init_session(new_model_name) self.model_name = new_model_name + @torch.inference_mode() def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) @@ -63,6 +73,7 @@ class RemoveBG(BasePlugin): output = self.remove(bgr_np_img, session=self.session) return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA) + @torch.inference_mode() def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray: bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR) diff --git a/iopaint/schema.py b/iopaint/schema.py index 5bbb896..655119a 100644 --- a/iopaint/schema.py +++ b/iopaint/schema.py @@ -145,6 +145,7 @@ class RealESRGANModel(Choices): class RemoveBGModel(Choices): briaai_rmbg_1_4 = "briaai/RMBG-1.4" + briaai_rmbg_2_0 = "briaai/RMBG-2.0" # models from https://github.com/danielgatis/rembg u2net = "u2net" u2netp = "u2netp" diff --git a/iopaint/tests/test_plugins.py b/iopaint/tests/test_plugins.py index ffa35e6..7857e9e 100644 --- a/iopaint/tests/test_plugins.py +++ b/iopaint/tests/test_plugins.py @@ -33,6 +33,7 @@ person_rgb_img = cv2.resize(person_rgb_img, (512, 512)) def _save(img, name): + name = name.replace("/", "_") cv2.imwrite(str(save_dir / name), img) diff --git a/requirements.txt b/requirements.txt index f4bca9b..9f27a10 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ diffusers==0.27.2 huggingface_hub==0.25.2 accelerate peft==0.7.1 -transformers>=4.35.1 +transformers>=4.39.1 safetensors controlnet-aux==0.0.3 fastapi==0.108.0