add Segment Anything

This commit is contained in:
Qing
2023-04-06 21:55:20 +08:00
parent ed36744339
commit a6aec566d9
20 changed files with 1885 additions and 299 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.7 KiB

View File

@@ -1,43 +0,0 @@
from pathlib import Path
import cv2
import numpy as np
from lama_cleaner.plugins import InteractiveSeg, Click
current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / "result"
save_dir.mkdir(exist_ok=True, parents=True)
img_p = current_dir / "overture-creations-5sI6fQgYIuo.png"
def test_interactive_seg():
interactive_seg_model = InteractiveSeg()
img = cv2.imread(str(img_p))
pred = interactive_seg_model.forward(
img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)]
)
cv2.imwrite(str(save_dir / "test_interactive_seg.png"), pred)
def test_interactive_seg_with_negative_click():
interactive_seg_model = InteractiveSeg()
img = cv2.imread(str(img_p))
pred = interactive_seg_model.forward(
img,
clicks=[
Click(coords=(256, 256), indx=0, is_positive=True),
Click(coords=(384, 256), indx=1, is_positive=False),
],
)
cv2.imwrite(str(save_dir / "test_interactive_seg_negative.png"), pred)
def test_interactive_seg_with_prev_mask():
interactive_seg_model = InteractiveSeg()
img = cv2.imread(str(img_p))
mask = np.zeros_like(img)[:, :, 0]
pred = interactive_seg_model.forward(
img, clicks=[Click(coords=(256, 256), indx=0, is_positive=True)], prev_mask=mask
)
cv2.imwrite(str(save_dir / "test_interactive_seg_with_mask.png"), pred)

View File

@@ -1,3 +1,8 @@
import hashlib
import os
import time
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from pathlib import Path
import cv2
@@ -9,12 +14,14 @@ from lama_cleaner.plugins import (
RealESRGANUpscaler,
GFPGANPlugin,
RestoreFormerPlugin,
InteractiveSeg,
)
current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / "result"
save_dir.mkdir(exist_ok=True, parents=True)
img_p = current_dir / "bunny.jpeg"
img_bytes = open(img_p, "rb").read()
bgr_img = cv2.imread(str(img_p))
rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
@@ -64,3 +71,21 @@ def test_restoreformer(device):
model = RestoreFormerPlugin(device)
res = model(rgb_img, None, None)
_save(res, f"test_restoreformer_{device}.png")
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
def test_segment_anything(device):
if device == "cuda" and not torch.cuda.is_available():
return
if device == "mps" and not torch.backends.mps.is_available():
return
img_md5 = hashlib.md5(img_bytes).hexdigest()
model = InteractiveSeg("vit_l", device)
new_mask = model.forward(rgb_img, [[448 // 2, 394 // 2, 1]], img_md5)
save_name = f"test_segment_anything_{device}.png"
_save(new_mask, save_name)
start = time.time()
model.forward(rgb_img, [[448 // 2, 394 // 2, 1]], img_md5)
print(f"Time for {save_name}: {time.time() - start:.2f}s")