add remove_bg_device
This commit is contained in:
@@ -10,7 +10,7 @@ def create_briarmbg2_session():
|
||||
return birefnet
|
||||
|
||||
|
||||
def briarmbg2_process(bgr_np_image, session, only_mask=False):
|
||||
def briarmbg2_process(device, bgr_np_image, session, only_mask=False):
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
|
||||
@@ -25,6 +25,7 @@ def briarmbg2_process(bgr_np_image, session, only_mask=False):
|
||||
image = Image.fromarray(bgr_np_image)
|
||||
image_size = image.size
|
||||
input_images = transform_image(image).unsqueeze(0)
|
||||
input_images = input_images.to(device)
|
||||
|
||||
# Prediction
|
||||
preds = session(input_images)[-1].sigmoid().cpu()
|
||||
|
||||
Reference in New Issue
Block a user