add remove_bg_device

This commit is contained in:
Qing
2024-11-23 15:51:05 +08:00
parent d29fe6ecb5
commit b7699a0f26
10 changed files with 64 additions and 20 deletions

View File

@@ -10,7 +10,7 @@ def create_briarmbg2_session():
return birefnet
def briarmbg2_process(bgr_np_image, session, only_mask=False):
def briarmbg2_process(device, bgr_np_image, session, only_mask=False):
from torchvision import transforms
from PIL import Image
@@ -25,6 +25,7 @@ def briarmbg2_process(bgr_np_image, session, only_mask=False):
image = Image.fromarray(bgr_np_image)
image_size = image.size
input_images = transform_image(image).unsqueeze(0)
input_images = input_images.to(device)
# Prediction
preds = session(input_images)[-1].sigmoid().cpu()