rename to iopaint

This commit is contained in:
Qing
2024-01-05 15:19:23 +08:00
parent f1f18aa6cd
commit a73e2a531f
101 changed files with 180 additions and 253 deletions

View File

@@ -0,0 +1,73 @@
from typing import Dict
from loguru import logger
from .anime_seg import AnimeSeg
from .gfpgan_plugin import GFPGANPlugin
from .interactive_seg import InteractiveSeg
from .realesrgan import RealESRGANUpscaler
from .remove_bg import RemoveBG
from .restoreformer import RestoreFormerPlugin
from ..const import InteractiveSegModel, Device, RealESRGANModel
def build_plugins(
enable_interactive_seg: bool,
interactive_seg_model: InteractiveSegModel,
interactive_seg_device: Device,
enable_remove_bg: bool,
enable_anime_seg: bool,
enable_realesrgan: bool,
realesrgan_device: Device,
realesrgan_model: RealESRGANModel,
enable_gfpgan: bool,
gfpgan_device: Device,
enable_restoreformer: bool,
restoreformer_device: Device,
no_half: bool,
) -> Dict:
plugins = {}
if enable_interactive_seg:
logger.info(f"Initialize {InteractiveSeg.name} plugin")
plugins[InteractiveSeg.name] = InteractiveSeg(
interactive_seg_model, interactive_seg_device
)
if enable_remove_bg:
logger.info(f"Initialize {RemoveBG.name} plugin")
plugins[RemoveBG.name] = RemoveBG()
if enable_anime_seg:
logger.info(f"Initialize {AnimeSeg.name} plugin")
plugins[AnimeSeg.name] = AnimeSeg()
if enable_realesrgan:
logger.info(
f"Initialize {RealESRGANUpscaler.name} plugin: {realesrgan_model}, {realesrgan_device}"
)
plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
realesrgan_model,
realesrgan_device,
no_half=no_half,
)
if enable_gfpgan:
logger.info(f"Initialize {GFPGANPlugin.name} plugin")
if enable_realesrgan:
logger.info("Use realesrgan as GFPGAN background upscaler")
else:
logger.info(
f"GFPGAN no background upscaler, use --enable-realesrgan to enable it"
)
plugins[GFPGANPlugin.name] = GFPGANPlugin(
gfpgan_device,
upscaler=plugins.get(RealESRGANUpscaler.name, None),
)
if enable_restoreformer:
logger.info(f"Initialize {RestoreFormerPlugin.name} plugin")
plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
restoreformer_device,
upscaler=plugins.get(RealESRGANUpscaler.name, None),
)
return plugins

View File

@@ -0,0 +1,462 @@
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from iopaint.helper import load_model
from iopaint.plugins.base_plugin import BasePlugin
from iopaint.schema import RunPluginRequest
class REBNCONV(nn.Module):
def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
super(REBNCONV, self).__init__()
self.conv_s1 = nn.Conv2d(
in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
)
self.bn_s1 = nn.BatchNorm2d(out_ch)
self.relu_s1 = nn.ReLU(inplace=True)
def forward(self, x):
hx = x
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
return xout
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src, tar):
src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
return src
### RSU-7 ###
class RSU7(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
super(RSU7, self).__init__()
self.in_ch = in_ch
self.mid_ch = mid_ch
self.out_ch = out_ch
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
b, c, h, w = x.shape
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx = self.pool5(hx5)
hx6 = self.rebnconv6(hx)
hx7 = self.rebnconv7(hx6)
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
hx6dup = _upsample_like(hx6d, hx5)
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-6 ###
class RSU6(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU6, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx6 = self.rebnconv6(hx5)
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-5 ###
class RSU5(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU5, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx5 = self.rebnconv5(hx4)
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-4 ###
class RSU4(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-4F ###
class RSU4F(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4F, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx2 = self.rebnconv2(hx1)
hx3 = self.rebnconv3(hx2)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
return hx1d + hxin
class ISNetDIS(nn.Module):
def __init__(self, in_ch=3, out_ch=1):
super(ISNetDIS, self).__init__()
self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage1 = RSU7(64, 32, 64)
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage2 = RSU6(64, 32, 128)
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage3 = RSU5(128, 64, 256)
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage4 = RSU4(256, 128, 512)
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage5 = RSU4F(512, 256, 512)
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage6 = RSU4F(512, 256, 512)
# decoder
self.stage5d = RSU4F(1024, 256, 512)
self.stage4d = RSU4(1024, 128, 256)
self.stage3d = RSU5(512, 64, 128)
self.stage2d = RSU6(256, 32, 64)
self.stage1d = RSU7(128, 16, 64)
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
def forward(self, x):
hx = x
hxin = self.conv_in(hx)
hx = self.pool_in(hxin)
# stage 1
hx1 = self.stage1(hxin)
hx = self.pool12(hx1)
# stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
# stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
# stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
# stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
# stage 6
hx6 = self.stage6(hx)
hx6up = _upsample_like(hx6, hx5)
# -------------------- decoder --------------------
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
# side output
d1 = self.side1(hx1d)
d1 = _upsample_like(d1, x)
return d1.sigmoid()
# 从小到大
ANIME_SEG_MODELS = {
"url": "https://github.com/Sanster/models/releases/download/isnetis/isnetis.pth",
"md5": "5f25479076b73074730ab8de9e8f2051",
}
class AnimeSeg(BasePlugin):
# Model from: https://github.com/SkyTNT/anime-segmentation
name = "AnimeSeg"
support_gen_image = True
support_gen_mask = True
def __init__(self):
super().__init__()
self.model = load_model(
ISNetDIS(),
ANIME_SEG_MODELS["url"],
"cpu",
ANIME_SEG_MODELS["md5"],
)
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
mask = self.forward(rgb_np_img)
mask = Image.fromarray(mask, mode="L")
h0, w0 = rgb_np_img.shape[0], rgb_np_img.shape[1]
empty = Image.new("RGBA", (w0, h0), 0)
img = Image.fromarray(rgb_np_img)
cutout = Image.composite(img, empty, mask)
return np.asarray(cutout)
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
return self.forward(rgb_np_img)
@torch.inference_mode()
def forward(self, rgb_np_img):
s = 1024
h0, w0 = h, w = rgb_np_img.shape[0], rgb_np_img.shape[1]
if h > w:
h, w = s, int(s * w / h)
else:
h, w = int(s * h / w), s
ph, pw = s - h, s - w
tmpImg = np.zeros([s, s, 3], dtype=np.float32)
tmpImg[ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w] = (
cv2.resize(rgb_np_img, (w, h)) / 255
)
tmpImg = tmpImg.transpose((2, 0, 1))
tmpImg = torch.from_numpy(tmpImg).unsqueeze(0).type(torch.FloatTensor)
mask = self.model(tmpImg)
mask = mask[0, :, ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w]
mask = cv2.resize(mask.cpu().numpy().transpose((1, 2, 0)), (w0, h0))
return (mask * 255).astype("uint8")

View File

@@ -0,0 +1,27 @@
from loguru import logger
import numpy as np
from iopaint.schema import RunPluginRequest
class BasePlugin:
name: str
support_gen_image: bool = False
support_gen_mask: bool = False
def __init__(self):
err_msg = self.check_dep()
if err_msg:
logger.error(err_msg)
exit(-1)
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
# return RGBA np image or BGR np image
...
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
# return GRAY or BGR np image, 255 means foreground, 0 means background
...
def check_dep(self):
...

View File

@@ -0,0 +1,74 @@
import cv2
import numpy as np
from loguru import logger
from iopaint.helper import download_model
from iopaint.plugins.base_plugin import BasePlugin
from iopaint.schema import RunPluginRequest
class GFPGANPlugin(BasePlugin):
name = "GFPGAN"
support_gen_image = True
def __init__(self, device, upscaler=None):
super().__init__()
from .gfpganer import MyGFPGANer
url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
model_md5 = "94d735072630ab734561130a47bc44f8"
model_path = download_model(url, model_md5)
logger.info(f"GFPGAN model path: {model_path}")
import facexlib
if hasattr(facexlib.detection.retinaface, "device"):
facexlib.detection.retinaface.device = device
# Use GFPGAN for face enhancement
self.face_enhancer = MyGFPGANer(
model_path=model_path,
upscale=1,
arch="clean",
channel_multiplier=2,
device=device,
bg_upsampler=upscaler.model if upscaler is not None else None,
)
self.face_enhancer.face_helper.face_det.mean_tensor.to(device)
self.face_enhancer.face_helper.face_det = (
self.face_enhancer.face_helper.face_det.to(device)
)
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
weight = 0.5
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
logger.info(f"GFPGAN input shape: {bgr_np_img.shape}")
_, _, bgr_output = self.face_enhancer.enhance(
bgr_np_img,
has_aligned=False,
only_center_face=False,
paste_back=True,
weight=weight,
)
logger.info(f"GFPGAN output shape: {bgr_output.shape}")
# try:
# if scale != 2:
# interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
# h, w = img.shape[0:2]
# output = cv2.resize(
# output,
# (int(w * scale / 2), int(h * scale / 2)),
# interpolation=interpolation,
# )
# except Exception as error:
# print("wrong scale input.", error)
return bgr_output
def check_dep(self):
try:
import gfpgan
except ImportError:
return (
"gfpgan is not installed, please install it first. pip install gfpgan"
)

View File

@@ -0,0 +1,84 @@
import os
import torch
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from gfpgan import GFPGANv1Clean, GFPGANer
from torch.hub import get_dir
class MyGFPGANer(GFPGANer):
"""Helper for restoration with GFPGAN.
It will detect and crop faces, and then resize the faces to 512x512.
GFPGAN is used to restored the resized faces.
The background is upsampled with the bg_upsampler.
Finally, the faces will be pasted back to the upsample background image.
Args:
model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
upscale (float): The upscale of the final output. Default: 2.
arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
bg_upsampler (nn.Module): The upsampler for the background. Default: None.
"""
def __init__(
self,
model_path,
upscale=2,
arch="clean",
channel_multiplier=2,
bg_upsampler=None,
device=None,
):
self.upscale = upscale
self.bg_upsampler = bg_upsampler
# initialize model
self.device = (
torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device is None
else device
)
# initialize the GFP-GAN
if arch == "clean":
self.gfpgan = GFPGANv1Clean(
out_size=512,
num_style_feat=512,
channel_multiplier=channel_multiplier,
decoder_load_path=None,
fix_decoder=False,
num_mlp=8,
input_is_latent=True,
different_w=True,
narrow=1,
sft_half=True,
)
elif arch == "RestoreFormer":
from gfpgan.archs.restoreformer_arch import RestoreFormer
self.gfpgan = RestoreFormer()
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, "checkpoints")
# initialize face helper
self.face_helper = FaceRestoreHelper(
upscale,
face_size=512,
crop_ratio=(1, 1),
det_model="retinaface_resnet50",
save_ext="png",
use_parse=True,
device=self.device,
model_rootpath=model_dir,
)
loadnet = torch.load(model_path)
if "params_ema" in loadnet:
keyname = "params_ema"
else:
keyname = "params"
self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
self.gfpgan.eval()
self.gfpgan = self.gfpgan.to(self.device)

View File

@@ -0,0 +1,76 @@
import hashlib
import json
from typing import List
import cv2
import numpy as np
import torch
from loguru import logger
from iopaint.helper import download_model
from iopaint.plugins.base_plugin import BasePlugin
from iopaint.plugins.segment_anything import SamPredictor, sam_model_registry
from iopaint.schema import RunPluginRequest
# 从小到大
SEGMENT_ANYTHING_MODELS = {
"vit_b": {
"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
"md5": "01ec64d29a2fca3f0661936605ae66f8",
},
"vit_l": {
"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
"md5": "0b3195507c641ddb6910d2bb5adee89c",
},
"vit_h": {
"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
"md5": "4b8939a88964f0f4ff5f5b2642c598a6",
},
"mobile_sam": {
"url": "https://github.com/Sanster/models/releases/download/MobileSAM/mobile_sam.pt",
"md5": "f3c0d8cda613564d499310dab6c812cd",
},
}
class InteractiveSeg(BasePlugin):
name = "InteractiveSeg"
support_gen_mask = True
def __init__(self, model_name, device):
super().__init__()
model_path = download_model(
SEGMENT_ANYTHING_MODELS[model_name]["url"],
SEGMENT_ANYTHING_MODELS[model_name]["md5"],
)
logger.info(f"SegmentAnything model path: {model_path}")
self.predictor = SamPredictor(
sam_model_registry[model_name](checkpoint=model_path).to(device)
)
self.prev_img_md5 = None
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
return self.forward(rgb_np_img, req.clicks, img_md5)
@torch.inference_mode()
def forward(self, rgb_np_img, clicks: List[List], img_md5: str):
input_point = []
input_label = []
for click in clicks:
x = click[0]
y = click[1]
input_point.append([x, y])
input_label.append(click[2])
if img_md5 and img_md5 != self.prev_img_md5:
self.prev_img_md5 = img_md5
self.predictor.set_image(rgb_np_img)
masks, scores, _ = self.predictor.predict(
point_coords=np.array(input_point),
point_labels=np.array(input_label),
multimask_output=False,
)
mask = masks[0].astype(np.uint8) * 255
return mask

View File

@@ -0,0 +1,100 @@
from enum import Enum
import cv2
import numpy as np
import torch
from loguru import logger
from iopaint.const import RealESRGANModel
from iopaint.helper import download_model
from iopaint.plugins.base_plugin import BasePlugin
from iopaint.schema import RunPluginRequest
class RealESRGANUpscaler(BasePlugin):
name = "RealESRGAN"
support_gen_image = True
def __init__(self, name, device, no_half=False):
super().__init__()
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
REAL_ESRGAN_MODELS = {
RealESRGANModel.realesr_general_x4v3: {
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
"scale": 4,
"model": lambda: SRVGGNetCompact(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_conv=32,
upscale=4,
act_type="prelu",
),
"model_md5": "91a7644643c884ee00737db24e478156",
},
RealESRGANModel.RealESRGAN_x4plus: {
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
"scale": 4,
"model": lambda: RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
),
"model_md5": "99ec365d4afad750833258a1a24f44ca",
},
RealESRGANModel.RealESRGAN_x4plus_anime_6B: {
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
"scale": 4,
"model": lambda: RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=6,
num_grow_ch=32,
scale=4,
),
"model_md5": "d58ce384064ec1591c2ea7b79dbf47ba",
},
}
if name not in REAL_ESRGAN_MODELS:
raise ValueError(f"Unknown RealESRGAN model name: {name}")
model_info = REAL_ESRGAN_MODELS[name]
model_path = download_model(model_info["url"], model_info["model_md5"])
logger.info(f"RealESRGAN model path: {model_path}")
self.model = RealESRGANer(
scale=model_info["scale"],
model_path=model_path,
model=model_info["model"](),
half=True if "cuda" in str(device) and not no_half else False,
tile=512,
tile_pad=10,
pre_pad=10,
device=device,
)
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}")
result = self.forward(bgr_np_img, req.scale)
logger.info(f"RealESRGAN output shape: {result.shape}")
return result
@torch.inference_mode()
def forward(self, bgr_np_img, scale: float):
# 输出是 BGR
upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
return upsampled
def check_dep(self):
try:
import realesrgan
except ImportError:
return "RealESRGAN is not installed, please install it first. pip install realesrgan"

View File

@@ -0,0 +1,49 @@
import os
import cv2
import numpy as np
from torch.hub import get_dir
from iopaint.plugins.base_plugin import BasePlugin
from iopaint.schema import RunPluginRequest
class RemoveBG(BasePlugin):
name = "RemoveBG"
support_gen_mask = True
support_gen_image = True
def __init__(self):
super().__init__()
from rembg import new_session
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, "checkpoints")
os.environ["U2NET_HOME"] = model_dir
self.session = new_session(model_name="u2net")
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
from rembg import remove
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
# return BGRA image
output = remove(bgr_np_img, session=self.session)
return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
from rembg import remove
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
# return BGR image, 255 means foreground, 0 means background
output = remove(bgr_np_img, session=self.session, only_mask=True)
return output
def check_dep(self):
try:
import rembg
except ImportError:
return (
"RemoveBG is not installed, please install it first. pip install rembg"
)

View File

@@ -0,0 +1,57 @@
import cv2
import numpy as np
from loguru import logger
from iopaint.helper import download_model
from iopaint.plugins.base_plugin import BasePlugin
from iopaint.schema import RunPluginRequest
class RestoreFormerPlugin(BasePlugin):
name = "RestoreFormer"
support_gen_image = True
def __init__(self, device, upscaler=None):
super().__init__()
from .gfpganer import MyGFPGANer
url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"
model_md5 = "eaeeff6c4a1caa1673977cb374e6f699"
model_path = download_model(url, model_md5)
logger.info(f"RestoreFormer model path: {model_path}")
import facexlib
if hasattr(facexlib.detection.retinaface, "device"):
facexlib.detection.retinaface.device = device
self.face_enhancer = MyGFPGANer(
model_path=model_path,
upscale=1,
arch="RestoreFormer",
channel_multiplier=2,
device=device,
bg_upsampler=upscaler.model if upscaler is not None else None,
)
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
weight = 0.5
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
logger.info(f"RestoreFormer input shape: {bgr_np_img.shape}")
_, _, bgr_output = self.face_enhancer.enhance(
bgr_np_img,
has_aligned=False,
only_center_face=False,
paste_back=True,
weight=weight,
)
logger.info(f"RestoreFormer output shape: {bgr_output.shape}")
return bgr_output
def check_dep(self):
try:
import gfpgan
except ImportError:
return (
"gfpgan is not installed, please install it first. pip install gfpgan"
)

View File

@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .build_sam import (
build_sam,
build_sam_vit_h,
build_sam_vit_l,
build_sam_vit_b,
sam_model_registry,
)
from .predictor import SamPredictor

View File

@@ -0,0 +1,168 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from functools import partial
from iopaint.plugins.segment_anything.modeling.tiny_vit_sam import TinyViT
from .modeling import (
ImageEncoderViT,
MaskDecoder,
PromptEncoder,
Sam,
TwoWayTransformer,
)
def build_sam_vit_h(checkpoint=None):
return _build_sam(
encoder_embed_dim=1280,
encoder_depth=32,
encoder_num_heads=16,
encoder_global_attn_indexes=[7, 15, 23, 31],
checkpoint=checkpoint,
)
build_sam = build_sam_vit_h
def build_sam_vit_l(checkpoint=None):
return _build_sam(
encoder_embed_dim=1024,
encoder_depth=24,
encoder_num_heads=16,
encoder_global_attn_indexes=[5, 11, 17, 23],
checkpoint=checkpoint,
)
def build_sam_vit_b(checkpoint=None):
return _build_sam(
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indexes=[2, 5, 8, 11],
checkpoint=checkpoint,
)
def build_sam_vit_t(checkpoint=None):
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
mobile_sam = Sam(
image_encoder=TinyViT(
img_size=1024,
in_chans=3,
num_classes=1000,
embed_dims=[64, 128, 160, 320],
depths=[2, 2, 6, 2],
num_heads=[2, 4, 5, 10],
window_sizes=[7, 7, 14, 7],
mlp_ratio=4.0,
drop_rate=0.0,
drop_path_rate=0.0,
use_checkpoint=False,
mbconv_expand_ratio=4.0,
local_conv_size=3,
layer_lr_decay=0.8,
),
prompt_encoder=PromptEncoder(
embed_dim=prompt_embed_dim,
image_embedding_size=(image_embedding_size, image_embedding_size),
input_image_size=(image_size, image_size),
mask_in_chans=16,
),
mask_decoder=MaskDecoder(
num_multimask_outputs=3,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=prompt_embed_dim,
mlp_dim=2048,
num_heads=8,
),
transformer_dim=prompt_embed_dim,
iou_head_depth=3,
iou_head_hidden_dim=256,
),
pixel_mean=[123.675, 116.28, 103.53],
pixel_std=[58.395, 57.12, 57.375],
)
mobile_sam.eval()
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
mobile_sam.load_state_dict(state_dict)
return mobile_sam
sam_model_registry = {
"default": build_sam,
"vit_h": build_sam,
"vit_l": build_sam_vit_l,
"vit_b": build_sam_vit_b,
"mobile_sam": build_sam_vit_t,
}
def _build_sam(
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
):
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
sam = Sam(
image_encoder=ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
),
prompt_encoder=PromptEncoder(
embed_dim=prompt_embed_dim,
image_embedding_size=(image_embedding_size, image_embedding_size),
input_image_size=(image_size, image_size),
mask_in_chans=16,
),
mask_decoder=MaskDecoder(
num_multimask_outputs=3,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=prompt_embed_dim,
mlp_dim=2048,
num_heads=8,
),
transformer_dim=prompt_embed_dim,
iou_head_depth=3,
iou_head_hidden_dim=256,
),
pixel_mean=[123.675, 116.28, 103.53],
pixel_std=[58.395, 57.12, 57.375],
)
sam.eval()
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
sam.load_state_dict(state_dict)
return sam

View File

@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .sam import Sam
from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
from .transformer import TwoWayTransformer

View File

@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from typing import Type
class MLPBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
mlp_dim: int,
act: Type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
self.act = act()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x

View File

@@ -0,0 +1,395 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Type
from .common import LayerNorm2d, MLPBlock
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_attn_indexes (list): Indexes for blocks using global attention.
"""
super().__init__()
self.img_size = img_size
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.pos_embed: Optional[nn.Parameter] = None
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
)
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
out_chans,
out_chans,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(out_chans),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.pos_embed
for blk in self.blocks:
x = blk(x)
x = self.neck(x.permute(0, 3, 1, 2))
return x
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
input_size (int or None): Input resolution for calculating the relative positional
parameter size.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
self.window_size = window_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool: If True, add a learnable bias to query, key, value.
rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (int or None): Input resolution for calculating the relative positional
parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert (
input_size is not None
), "Input size must be provided if using relative positional encoding."
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
return x
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
attn = (
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
).view(B, q_h * q_w, k_h * k_w)
return attn
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: Tuple[int, int] = (16, 16),
stride: Tuple[int, int] = (16, 16),
padding: Tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x

View File

@@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from torch.nn import functional as F
from typing import List, Tuple, Type
from .common import LayerNorm2d
class MaskDecoder(nn.Module):
def __init__(
self,
*,
transformer_dim: int,
transformer: nn.Module,
num_multimask_outputs: int = 3,
activation: Type[nn.Module] = nn.GELU,
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
) -> None:
"""
Predicts masks given an image and prompt embeddings, using a
tranformer architecture.
Arguments:
transformer_dim (int): the channel dimension of the transformer
transformer (nn.Module): the transformer used to predict masks
num_multimask_outputs (int): the number of masks to predict
when disambiguating masks
activation (nn.Module): the type of activation to use when
upscaling masks
iou_head_depth (int): the depth of the MLP used to predict
mask quality
iou_head_hidden_dim (int): the hidden dimension of the MLP
used to predict mask quality
"""
super().__init__()
self.transformer_dim = transformer_dim
self.transformer = transformer
self.num_multimask_outputs = num_multimask_outputs
self.iou_token = nn.Embedding(1, transformer_dim)
self.num_mask_tokens = num_multimask_outputs + 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
LayerNorm2d(transformer_dim // 4),
activation(),
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
activation(),
)
self.output_hypernetworks_mlps = nn.ModuleList(
[
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
for i in range(self.num_mask_tokens)
]
)
self.iou_prediction_head = MLP(
transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
)
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Arguments:
image_embeddings (torch.Tensor): the embeddings from the image encoder
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
multimask_output (bool): Whether to return multiple masks or a single
mask.
Returns:
torch.Tensor: batched predicted masks
torch.Tensor: batched predictions of mask quality
"""
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
)
# Select the correct mask or masks for outptu
if multimask_output:
mask_slice = slice(1, None)
else:
mask_slice = slice(0, 1)
masks = masks[:, mask_slice, :, :]
iou_pred = iou_pred[:, mask_slice]
# Prepare output
return masks, iou_pred
def predict_masks(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts masks. See 'forward' for more details."""
# Concatenate output tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = src + dense_prompt_embeddings
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape
# Run the transformer
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
return masks, iou_pred
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
num_layers: int,
sigmoid_output: bool = False,
) -> None:
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
self.sigmoid_output = sigmoid_output
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
if self.sigmoid_output:
x = F.sigmoid(x)
return x

View File

@@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from torch import nn
from typing import Any, Optional, Tuple, Type
from .common import LayerNorm2d
class PromptEncoder(nn.Module):
def __init__(
self,
embed_dim: int,
image_embedding_size: Tuple[int, int],
input_image_size: Tuple[int, int],
mask_in_chans: int,
activation: Type[nn.Module] = nn.GELU,
) -> None:
"""
Encodes prompts for input to SAM's mask decoder.
Arguments:
embed_dim (int): The prompts' embedding dimension
image_embedding_size (tuple(int, int)): The spatial size of the
image embedding, as (H, W).
input_image_size (int): The padded size of the image as input
to the image encoder, as (H, W).
mask_in_chans (int): The number of hidden channels used for
encoding input masks.
activation (nn.Module): The activation to use when encoding
input masks.
"""
super().__init__()
self.embed_dim = embed_dim
self.input_image_size = input_image_size
self.image_embedding_size = image_embedding_size
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
self.point_embeddings = nn.ModuleList(point_embeddings)
self.not_a_point_embed = nn.Embedding(1, embed_dim)
self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
self.mask_downscaling = nn.Sequential(
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans // 4),
activation(),
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans),
activation(),
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
)
self.no_mask_embed = nn.Embedding(1, embed_dim)
def get_dense_pe(self) -> torch.Tensor:
"""
Returns the positional encoding used to encode point prompts,
applied to a dense set of points the shape of the image encoding.
Returns:
torch.Tensor: Positional encoding with shape
1x(embed_dim)x(embedding_h)x(embedding_w)
"""
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
def _embed_points(
self,
points: torch.Tensor,
labels: torch.Tensor,
pad: bool,
) -> torch.Tensor:
"""Embeds point prompts."""
points = points + 0.5 # Shift to center of pixel
if pad:
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
points = torch.cat([points, padding_point], dim=1)
labels = torch.cat([labels, padding_label], dim=1)
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == 0] += self.point_embeddings[0].weight
point_embedding[labels == 1] += self.point_embeddings[1].weight
return point_embedding
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
"""Embeds box prompts."""
boxes = boxes + 0.5 # Shift to center of pixel
coords = boxes.reshape(-1, 2, 2)
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
return corner_embedding
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
"""Embeds mask inputs."""
mask_embedding = self.mask_downscaling(masks)
return mask_embedding
def _get_batch_size(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> int:
"""
Gets the batch size of the output given the batch size of the input prompts.
"""
if points is not None:
return points[0].shape[0]
elif boxes is not None:
return boxes.shape[0]
elif masks is not None:
return masks.shape[0]
else:
return 1
def _get_device(self) -> torch.device:
return self.point_embeddings[0].weight.device
def forward(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Embeds different types of prompts, returning both sparse and dense
embeddings.
Arguments:
points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
and labels to embed.
boxes (torch.Tensor or none): boxes to embed
masks (torch.Tensor or none): masks to embed
Returns:
torch.Tensor: sparse embeddings for the points and boxes, with shape
BxNx(embed_dim), where N is determined by the number of input points
and boxes.
torch.Tensor: dense embeddings for the masks, in the shape
Bx(embed_dim)x(embed_H)x(embed_W)
"""
bs = self._get_batch_size(points, boxes, masks)
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
if points is not None:
coords, labels = points
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
if boxes is not None:
box_embeddings = self._embed_boxes(boxes)
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
if masks is not None:
dense_embeddings = self._embed_masks(masks)
else:
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)
return sparse_embeddings, dense_embeddings
class PositionEmbeddingRandom(nn.Module):
"""
Positional encoding using random spatial frequencies.
"""
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
super().__init__()
if scale is None or scale <= 0.0:
scale = 1.0
self.register_buffer(
"positional_encoding_gaussian_matrix",
scale * torch.randn((2, num_pos_feats)),
)
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
"""Positionally encode points that are normalized to [0,1]."""
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
coords = 2 * coords - 1
coords = coords @ self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
# outputs d_1 x ... x d_n x C shape
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
"""Generate positional encoding for a grid of the specified size."""
h, w = size
device: Any = self.positional_encoding_gaussian_matrix.device
grid = torch.ones((h, w), device=device, dtype=torch.float32)
y_embed = grid.cumsum(dim=0) - 0.5
x_embed = grid.cumsum(dim=1) - 0.5
y_embed = y_embed / h
x_embed = x_embed / w
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
return pe.permute(2, 0, 1) # C x H x W
def forward_with_coords(
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
) -> torch.Tensor:
"""Positionally encode points that are not normalized to [0,1]."""
coords = coords_input.clone()
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
return self._pe_encoding(coords.to(torch.float)) # B x N x C

View File

@@ -0,0 +1,174 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from torch.nn import functional as F
from typing import Any, Dict, List, Tuple
from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
class Sam(nn.Module):
mask_threshold: float = 0.0
image_format: str = "RGB"
def __init__(
self,
image_encoder: ImageEncoderViT,
prompt_encoder: PromptEncoder,
mask_decoder: MaskDecoder,
pixel_mean: List[float] = [123.675, 116.28, 103.53],
pixel_std: List[float] = [58.395, 57.12, 57.375],
) -> None:
"""
SAM predicts object masks from an image and input prompts.
Arguments:
image_encoder (ImageEncoderViT): The backbone used to encode the
image into image embeddings that allow for efficient mask prediction.
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
mask_decoder (MaskDecoder): Predicts masks from the image embeddings
and encoded prompts.
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
pixel_std (list(float)): Std values for normalizing pixels in the input image.
"""
super().__init__()
self.image_encoder = image_encoder
self.prompt_encoder = prompt_encoder
self.mask_decoder = mask_decoder
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
@property
def device(self) -> Any:
return self.pixel_mean.device
@torch.no_grad()
def forward(
self,
batched_input: List[Dict[str, Any]],
multimask_output: bool,
) -> List[Dict[str, torch.Tensor]]:
"""
Predicts masks end-to-end from provided images and prompts.
If prompts are not known in advance, using SamPredictor is
recommended over calling the model directly.
Arguments:
batched_input (list(dict)): A list over input images, each a
dictionary with the following keys. A prompt key can be
excluded if it is not present.
'image': The image as a torch tensor in 3xHxW format,
already transformed for input to the model.
'original_size': (tuple(int, int)) The original size of
the image before transformation, as (H, W).
'point_coords': (torch.Tensor) Batched point prompts for
this image, with shape BxNx2. Already transformed to the
input frame of the model.
'point_labels': (torch.Tensor) Batched labels for point prompts,
with shape BxN.
'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
Already transformed to the input frame of the model.
'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
in the form Bx1xHxW.
multimask_output (bool): Whether the model should predict multiple
disambiguating masks, or return a single mask.
Returns:
(list(dict)): A list over input images, where each element is
as dictionary with the following keys.
'masks': (torch.Tensor) Batched binary mask predictions,
with shape BxCxHxW, where B is the number of input promts,
C is determiend by multimask_output, and (H, W) is the
original size of the image.
'iou_predictions': (torch.Tensor) The model's predictions
of mask quality, in shape BxC.
'low_res_logits': (torch.Tensor) Low resolution logits with
shape BxCxHxW, where H=W=256. Can be passed as mask input
to subsequent iterations of prediction.
"""
input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
image_embeddings = self.image_encoder(input_images)
outputs = []
for image_record, curr_embedding in zip(batched_input, image_embeddings):
if "point_coords" in image_record:
points = (image_record["point_coords"], image_record["point_labels"])
else:
points = None
sparse_embeddings, dense_embeddings = self.prompt_encoder(
points=points,
boxes=image_record.get("boxes", None),
masks=image_record.get("mask_inputs", None),
)
low_res_masks, iou_predictions = self.mask_decoder(
image_embeddings=curr_embedding.unsqueeze(0),
image_pe=self.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
masks = self.postprocess_masks(
low_res_masks,
input_size=image_record["image"].shape[-2:],
original_size=image_record["original_size"],
)
masks = masks > self.mask_threshold
outputs.append(
{
"masks": masks,
"iou_predictions": iou_predictions,
"low_res_logits": low_res_masks,
}
)
return outputs
def postprocess_masks(
self,
masks: torch.Tensor,
input_size: Tuple[int, ...],
original_size: Tuple[int, ...],
) -> torch.Tensor:
"""
Remove padding and upscale masks to the original image size.
Arguments:
masks (torch.Tensor): Batched masks from the mask_decoder,
in BxCxHxW format.
input_size (tuple(int, int)): The size of the image input to the
model, in (H, W) format. Used to remove padding.
original_size (tuple(int, int)): The original size of the image
before resizing for input to the model, in (H, W) format.
Returns:
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
is given by original_size.
"""
masks = F.interpolate(
masks,
(self.image_encoder.img_size, self.image_encoder.img_size),
mode="bilinear",
align_corners=False,
)
masks = masks[..., : input_size[0], : input_size[1]]
masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
return masks
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
"""Normalize pixel values and pad to a square input."""
# Normalize colors
x = (x - self.pixel_mean) / self.pixel_std
# Pad
h, w = x.shape[-2:]
padh = self.image_encoder.img_size - h
padw = self.image_encoder.img_size - w
x = F.pad(x, (0, padw, 0, padh))
return x

View File

@@ -0,0 +1,822 @@
# --------------------------------------------------------
# TinyViT Model Architecture
# Copyright (c) 2022 Microsoft
# Adapted from LeViT and Swin Transformer
# LeViT: (https://github.com/facebookresearch/levit)
# Swin: (https://github.com/microsoft/swin-transformer)
# Build the TinyViT Model
# --------------------------------------------------------
import collections
import itertools
import math
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from typing import Tuple
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return x
return tuple(itertools.repeat(x, n))
return parse
to_2tuple = _ntuple(2)
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
# type: (Tensor, float, float, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
applied while sampling the normal with mean/std applied, therefore a, b args
should be adjusted to match the range of mean, std args.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
with torch.no_grad():
return _trunc_normal_(tensor, mean, std, a, b)
def drop_path(
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (
x.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
class TimmDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
super(TimmDropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def extra_repr(self):
return f"drop_prob={round(self.drop_prob,3):0.3f}"
class Conv2d_BN(torch.nn.Sequential):
def __init__(
self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1
):
super().__init__()
self.add_module(
"c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)
)
bn = torch.nn.BatchNorm2d(b)
torch.nn.init.constant_(bn.weight, bn_weight_init)
torch.nn.init.constant_(bn.bias, 0)
self.add_module("bn", bn)
@torch.no_grad()
def fuse(self):
c, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
m = torch.nn.Conv2d(
w.size(1) * self.c.groups,
w.size(0),
w.shape[2:],
stride=self.c.stride,
padding=self.c.padding,
dilation=self.c.dilation,
groups=self.c.groups,
)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
class DropPath(TimmDropPath):
def __init__(self, drop_prob=None):
super().__init__(drop_prob=drop_prob)
self.drop_prob = drop_prob
def __repr__(self):
msg = super().__repr__()
msg += f"(drop_prob={self.drop_prob})"
return msg
class PatchEmbed(nn.Module):
def __init__(self, in_chans, embed_dim, resolution, activation):
super().__init__()
img_size: Tuple[int, int] = to_2tuple(resolution)
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
n = embed_dim
self.seq = nn.Sequential(
Conv2d_BN(in_chans, n // 2, 3, 2, 1),
activation(),
Conv2d_BN(n // 2, n, 3, 2, 1),
)
def forward(self, x):
return self.seq(x)
class MBConv(nn.Module):
def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
super().__init__()
self.in_chans = in_chans
self.hidden_chans = int(in_chans * expand_ratio)
self.out_chans = out_chans
self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
self.act1 = activation()
self.conv2 = Conv2d_BN(
self.hidden_chans,
self.hidden_chans,
ks=3,
stride=1,
pad=1,
groups=self.hidden_chans,
)
self.act2 = activation()
self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
self.act3 = activation()
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
def forward(self, x):
shortcut = x
x = self.conv1(x)
x = self.act1(x)
x = self.conv2(x)
x = self.act2(x)
x = self.conv3(x)
x = self.drop_path(x)
x += shortcut
x = self.act3(x)
return x
class PatchMerging(nn.Module):
def __init__(self, input_resolution, dim, out_dim, activation):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.out_dim = out_dim
self.act = activation()
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
stride_c = 2
if out_dim == 320 or out_dim == 448 or out_dim == 576:
stride_c = 1
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
def forward(self, x):
if x.ndim == 3:
H, W = self.input_resolution
B = len(x)
# (B, C, H, W)
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
x = self.conv1(x)
x = self.act(x)
x = self.conv2(x)
x = self.act(x)
x = self.conv3(x)
x = x.flatten(2).transpose(1, 2)
return x
class ConvLayer(nn.Module):
def __init__(
self,
dim,
input_resolution,
depth,
activation,
drop_path=0.0,
downsample=None,
use_checkpoint=False,
out_dim=None,
conv_expand_ratio=4.0,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList(
[
MBConv(
dim,
dim,
conv_expand_ratio,
activation,
drop_path[i] if isinstance(drop_path, list) else drop_path,
)
for i in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(
input_resolution, dim=dim, out_dim=out_dim, activation=activation
)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
class Mlp(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.norm = nn.LayerNorm(in_features)
self.fc1 = nn.Linear(in_features, hidden_features)
self.fc2 = nn.Linear(hidden_features, out_features)
self.act = act_layer()
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.norm(x)
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(torch.nn.Module):
def __init__(
self,
dim,
key_dim,
num_heads=8,
attn_ratio=4,
resolution=(14, 14),
):
super().__init__()
# (h, w)
assert isinstance(resolution, tuple) and len(resolution) == 2
self.num_heads = num_heads
self.scale = key_dim**-0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio
h = self.dh + nh_kd * 2
self.norm = nn.LayerNorm(dim)
self.qkv = nn.Linear(dim, h)
self.proj = nn.Linear(self.dh, dim)
points = list(itertools.product(range(resolution[0]), range(resolution[1])))
N = len(points)
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(
torch.zeros(num_heads, len(attention_offsets))
)
self.register_buffer(
"attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False
)
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and hasattr(self, "ab"):
del self.ab
else:
self.register_buffer(
"ab",
self.attention_biases[:, self.attention_bias_idxs],
persistent=False,
)
def forward(self, x): # x (B,N,C)
B, N, _ = x.shape
# Normalization
x = self.norm(x)
qkv = self.qkv(x)
# (B, N, num_heads, d)
q, k, v = qkv.view(B, N, self.num_heads, -1).split(
[self.key_dim, self.key_dim, self.d], dim=3
)
# (B, num_heads, N, d)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
attn = (q @ k.transpose(-2, -1)) * self.scale + (
self.attention_biases[:, self.attention_bias_idxs]
if self.training
else self.ab
)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
x = self.proj(x)
return x
class TinyViTBlock(nn.Module):
r"""TinyViT Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int, int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float, optional): Dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
local_conv_size (int): the kernel size of the convolution between
Attention and MLP. Default: 3
activation: the activation function. Default: nn.GELU
"""
def __init__(
self,
dim,
input_resolution,
num_heads,
window_size=7,
mlp_ratio=4.0,
drop=0.0,
drop_path=0.0,
local_conv_size=3,
activation=nn.GELU,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
assert window_size > 0, "window_size must be greater than 0"
self.window_size = window_size
self.mlp_ratio = mlp_ratio
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
assert dim % num_heads == 0, "dim must be divisible by num_heads"
head_dim = dim // num_heads
window_resolution = (window_size, window_size)
self.attn = Attention(
dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution
)
mlp_hidden_dim = int(dim * mlp_ratio)
mlp_activation = activation
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=mlp_activation,
drop=drop,
)
pad = local_conv_size // 2
self.local_conv = Conv2d_BN(
dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim
)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
res_x = x
if H == self.window_size and W == self.window_size:
x = self.attn(x)
else:
x = x.view(B, H, W, C)
pad_b = (self.window_size - H % self.window_size) % self.window_size
pad_r = (self.window_size - W % self.window_size) % self.window_size
padding = pad_b > 0 or pad_r > 0
if padding:
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
pH, pW = H + pad_b, W + pad_r
nH = pH // self.window_size
nW = pW // self.window_size
# window partition
x = (
x.view(B, nH, self.window_size, nW, self.window_size, C)
.transpose(2, 3)
.reshape(B * nH * nW, self.window_size * self.window_size, C)
)
x = self.attn(x)
# window reverse
x = (
x.view(B, nH, nW, self.window_size, self.window_size, C)
.transpose(2, 3)
.reshape(B, pH, pW, C)
)
if padding:
x = x[:, :H, :W].contiguous()
x = x.view(B, L, C)
x = res_x + self.drop_path(x)
x = x.transpose(1, 2).reshape(B, C, H, W)
x = self.local_conv(x)
x = x.view(B, C, L).transpose(1, 2)
x = x + self.drop_path(self.mlp(x))
return x
def extra_repr(self) -> str:
return (
f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
)
class BasicLayer(nn.Module):
"""A basic TinyViT layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float, optional): Dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3
activation: the activation function. Default: nn.GELU
out_dim: the output dimension of the layer. Default: dim
"""
def __init__(
self,
dim,
input_resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.0,
drop=0.0,
drop_path=0.0,
downsample=None,
use_checkpoint=False,
local_conv_size=3,
activation=nn.GELU,
out_dim=None,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList(
[
TinyViTBlock(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
drop=drop,
drop_path=drop_path[i]
if isinstance(drop_path, list)
else drop_path,
local_conv_size=local_conv_size,
activation=activation,
)
for i in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(
input_resolution, dim=dim, out_dim=out_dim, activation=activation
)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
class TinyViT(nn.Module):
def __init__(
self,
img_size=224,
in_chans=3,
num_classes=1000,
embed_dims=[96, 192, 384, 768],
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_sizes=[7, 7, 14, 7],
mlp_ratio=4.0,
drop_rate=0.0,
drop_path_rate=0.1,
use_checkpoint=False,
mbconv_expand_ratio=4.0,
local_conv_size=3,
layer_lr_decay=1.0,
):
super().__init__()
self.img_size = img_size
self.num_classes = num_classes
self.depths = depths
self.num_layers = len(depths)
self.mlp_ratio = mlp_ratio
activation = nn.GELU
self.patch_embed = PatchEmbed(
in_chans=in_chans,
embed_dim=embed_dims[0],
resolution=img_size,
activation=activation,
)
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# stochastic depth
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
kwargs = dict(
dim=embed_dims[i_layer],
input_resolution=(
patches_resolution[0]
// (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
patches_resolution[1]
// (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
),
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
# patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
activation=activation,
)
if i_layer == 0:
layer = ConvLayer(
conv_expand_ratio=mbconv_expand_ratio,
**kwargs,
)
else:
layer = BasicLayer(
num_heads=num_heads[i_layer],
window_size=window_sizes[i_layer],
mlp_ratio=self.mlp_ratio,
drop=drop_rate,
local_conv_size=local_conv_size,
**kwargs,
)
self.layers.append(layer)
# Classifier head
self.norm_head = nn.LayerNorm(embed_dims[-1])
self.head = (
nn.Linear(embed_dims[-1], num_classes)
if num_classes > 0
else torch.nn.Identity()
)
# init weights
self.apply(self._init_weights)
self.set_layer_lr_decay(layer_lr_decay)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dims[-1],
256,
kernel_size=1,
bias=False,
),
LayerNorm2d(256),
nn.Conv2d(
256,
256,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(256),
)
def set_layer_lr_decay(self, layer_lr_decay):
decay_rate = layer_lr_decay
# layers -> blocks (depth)
depth = sum(self.depths)
lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
# print("LR SCALES:", lr_scales)
def _set_lr_scale(m, scale):
for p in m.parameters():
p.lr_scale = scale
self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
i = 0
for layer in self.layers:
for block in layer.blocks:
block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
i += 1
if layer.downsample is not None:
layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1]))
assert i == depth
for m in [self.norm_head, self.head]:
m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))
for k, p in self.named_parameters():
p.param_name = k
def _check_lr_scale(m):
for p in m.parameters():
assert hasattr(p, "lr_scale"), p.param_name
self.apply(_check_lr_scale)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {"attention_biases"}
def forward_features(self, x):
# x: (N, C, H, W)
x = self.patch_embed(x)
x = self.layers[0](x)
start_i = 1
for i in range(start_i, len(self.layers)):
layer = self.layers[i]
x = layer(x)
B, _, C = x.size()
x = x.view(B, 64, 64, C)
x = x.permute(0, 3, 1, 2)
x = self.neck(x)
return x
def forward(self, x):
x = self.forward_features(x)
# x = self.norm_head(x)
# x = self.head(x)
return x

View File

@@ -0,0 +1,240 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import Tensor, nn
import math
from typing import Tuple, Type
from .common import MLPBlock
class TwoWayTransformer(nn.Module):
def __init__(
self,
depth: int,
embedding_dim: int,
num_heads: int,
mlp_dim: int,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
) -> None:
"""
A transformer decoder that attends to an input image using
queries whose positional embedding is supplied.
Args:
depth (int): number of layers in the transformer
embedding_dim (int): the channel dimension for the input embeddings
num_heads (int): the number of heads for multihead attention. Must
divide embedding_dim
mlp_dim (int): the channel dimension internal to the MLP block
activation (nn.Module): the activation to use in the MLP block
"""
super().__init__()
self.depth = depth
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.mlp_dim = mlp_dim
self.layers = nn.ModuleList()
for i in range(depth):
self.layers.append(
TwoWayAttentionBlock(
embedding_dim=embedding_dim,
num_heads=num_heads,
mlp_dim=mlp_dim,
activation=activation,
attention_downsample_rate=attention_downsample_rate,
skip_first_layer_pe=(i == 0),
)
)
self.final_attn_token_to_image = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.norm_final_attn = nn.LayerNorm(embedding_dim)
def forward(
self,
image_embedding: Tensor,
image_pe: Tensor,
point_embedding: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
image_embedding (torch.Tensor): image to attend to. Should be shape
B x embedding_dim x h x w for any h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must
have the same shape as image_embedding.
point_embedding (torch.Tensor): the embedding to add to the query points.
Must have shape B x N_points x embedding_dim for any N_points.
Returns:
torch.Tensor: the processed point_embedding
torch.Tensor: the processed image_embedding
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
bs, c, h, w = image_embedding.shape
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
image_pe = image_pe.flatten(2).permute(0, 2, 1)
# Prepare queries
queries = point_embedding
keys = image_embedding
# Apply transformer blocks and final layernorm
for layer in self.layers:
queries, keys = layer(
queries=queries,
keys=keys,
query_pe=point_embedding,
key_pe=image_pe,
)
# Apply the final attenion layer from the points to the image
q = queries + point_embedding
k = keys + image_pe
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm_final_attn(queries)
return queries, keys
class TwoWayAttentionBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
num_heads: int,
mlp_dim: int = 2048,
activation: Type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
skip_first_layer_pe: bool = False,
) -> None:
"""
A transformer block with four layers: (1) self-attention of sparse
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
block on sparse inputs, and (4) cross attention of dense inputs to sparse
inputs.
Arguments:
embedding_dim (int): the channel dimension of the embeddings
num_heads (int): the number of heads in the attention layers
mlp_dim (int): the hidden dimension of the mlp block
activation (nn.Module): the activation of the mlp block
skip_first_layer_pe (bool): skip the PE on the first layer
"""
super().__init__()
self.self_attn = Attention(embedding_dim, num_heads)
self.norm1 = nn.LayerNorm(embedding_dim)
self.cross_attn_token_to_image = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.norm2 = nn.LayerNorm(embedding_dim)
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
self.norm3 = nn.LayerNorm(embedding_dim)
self.norm4 = nn.LayerNorm(embedding_dim)
self.cross_attn_image_to_token = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.skip_first_layer_pe = skip_first_layer_pe
def forward(
self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
) -> Tuple[Tensor, Tensor]:
# Self attention block
if self.skip_first_layer_pe:
queries = self.self_attn(q=queries, k=queries, v=queries)
else:
q = queries + query_pe
attn_out = self.self_attn(q=q, k=q, v=queries)
queries = queries + attn_out
queries = self.norm1(queries)
# Cross attention block, tokens attending to image embedding
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm2(queries)
# MLP block
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.norm3(queries)
# Cross attention block, image embedding attending to tokens
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
keys = keys + attn_out
keys = self.norm4(keys)
return queries, keys
class Attention(nn.Module):
"""
An attention layer that allows for downscaling the size of the embedding
after projection to queries, keys, and values.
"""
def __init__(
self,
embedding_dim: int,
num_heads: int,
downsample_rate: int = 1,
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads
assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
b, n, c = x.shape
x = x.reshape(b, n, num_heads, c // num_heads)
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
def _recombine_heads(self, x: Tensor) -> Tensor:
b, n_heads, n_tokens, c_per_head = x.shape
x = x.transpose(1, 2)
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
# Input projections
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
# Separate into heads
q = self._separate_heads(q, self.num_heads)
k = self._separate_heads(k, self.num_heads)
v = self._separate_heads(v, self.num_heads)
# Attention
_, _, _, c_per_head = q.shape
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
attn = attn / math.sqrt(c_per_head)
attn = torch.softmax(attn, dim=-1)
# Get output
out = attn @ v
out = self._recombine_heads(out)
out = self.out_proj(out)
return out

View File

@@ -0,0 +1,285 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from .modeling import Sam
from typing import Optional, Tuple
class SamPredictor:
def __init__(
self,
sam_model: Sam,
) -> None:
"""
Uses SAM to calculate the image embedding for an image, and then
allow repeated, efficient mask prediction given prompts.
Arguments:
sam_model (Sam): The model to use for mask prediction.
"""
super().__init__()
self.model = sam_model
from .utils.transforms import ResizeLongestSide
self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
self.reset_image()
def set_image(
self,
image: np.ndarray,
image_format: str = "RGB",
) -> None:
"""
Calculates the image embeddings for the provided image, allowing
masks to be predicted with the 'predict' method.
Arguments:
image (np.ndarray): The image for calculating masks. Expects an
image in HWC uint8 format, with pixel values in [0, 255].
image_format (str): The color format of the image, in ['RGB', 'BGR'].
"""
assert image_format in [
"RGB",
"BGR",
], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
if image_format != self.model.image_format:
image = image[..., ::-1]
# Transform the image to the form expected by the model
input_image = self.transform.apply_image(image)
input_image_torch = torch.as_tensor(input_image, device=self.device)
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[
None, :, :, :
]
self.set_torch_image(input_image_torch, image.shape[:2])
@torch.no_grad()
def set_torch_image(
self,
transformed_image: torch.Tensor,
original_image_size: Tuple[int, ...],
) -> None:
"""
Calculates the image embeddings for the provided image, allowing
masks to be predicted with the 'predict' method. Expects the input
image to be already transformed to the format expected by the model.
Arguments:
transformed_image (torch.Tensor): The input image, with shape
1x3xHxW, which has been transformed with ResizeLongestSide.
original_image_size (tuple(int, int)): The size of the image
before transformation, in (H, W) format.
"""
assert (
len(transformed_image.shape) == 4
and transformed_image.shape[1] == 3
and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
self.reset_image()
self.original_size = original_image_size
self.input_size = tuple(transformed_image.shape[-2:])
input_image = self.model.preprocess(transformed_image)
self.features = self.model.image_encoder(input_image)
self.is_image_set = True
def predict(
self,
point_coords: Optional[np.ndarray] = None,
point_labels: Optional[np.ndarray] = None,
box: Optional[np.ndarray] = None,
mask_input: Optional[np.ndarray] = None,
multimask_output: bool = True,
return_logits: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Predict masks for the given input prompts, using the currently set image.
Arguments:
point_coords (np.ndarray or None): A Nx2 array of point prompts to the
model. Each point is in (X,Y) in pixels.
point_labels (np.ndarray or None): A length N array of labels for the
point prompts. 1 indicates a foreground point and 0 indicates a
background point.
box (np.ndarray or None): A length 4 array given a box prompt to the
model, in XYXY format.
mask_input (np.ndarray): A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form 1xHxW, where
for SAM, H=W=256.
multimask_output (bool): If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will often
produce better masks than a single prediction. If only a single
mask is needed, the model's predicted quality score can be used
to select the best mask. For non-ambiguous prompts, such as multiple
input prompts, multimask_output=False can give better results.
return_logits (bool): If true, returns un-thresholded masks logits
instead of a binary mask.
Returns:
(np.ndarray): The output masks in CxHxW format, where C is the
number of masks, and (H, W) is the original image size.
(np.ndarray): An array of length C containing the model's
predictions for the quality of each mask.
(np.ndarray): An array of shape CxHxW, where C is the number
of masks and H=W=256. These low resolution logits can be passed to
a subsequent iteration as mask input.
"""
if not self.is_image_set:
raise RuntimeError(
"An image must be set with .set_image(...) before mask prediction."
)
# Transform input prompts
coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
if point_coords is not None:
assert (
point_labels is not None
), "point_labels must be supplied if point_coords is supplied."
point_coords = self.transform.apply_coords(point_coords, self.original_size)
coords_torch = torch.as_tensor(
point_coords, dtype=torch.float, device=self.device
)
labels_torch = torch.as_tensor(
point_labels, dtype=torch.int, device=self.device
)
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
if box is not None:
box = self.transform.apply_boxes(box, self.original_size)
box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
box_torch = box_torch[None, :]
if mask_input is not None:
mask_input_torch = torch.as_tensor(
mask_input, dtype=torch.float, device=self.device
)
mask_input_torch = mask_input_torch[None, :, :, :]
masks, iou_predictions, low_res_masks = self.predict_torch(
coords_torch,
labels_torch,
box_torch,
mask_input_torch,
multimask_output,
return_logits=return_logits,
)
masks = masks[0].detach().cpu().numpy()
iou_predictions = iou_predictions[0].detach().cpu().numpy()
low_res_masks = low_res_masks[0].detach().cpu().numpy()
return masks, iou_predictions, low_res_masks
@torch.no_grad()
def predict_torch(
self,
point_coords: Optional[torch.Tensor],
point_labels: Optional[torch.Tensor],
boxes: Optional[torch.Tensor] = None,
mask_input: Optional[torch.Tensor] = None,
multimask_output: bool = True,
return_logits: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Predict masks for the given input prompts, using the currently set image.
Input prompts are batched torch tensors and are expected to already be
transformed to the input frame using ResizeLongestSide.
Arguments:
point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
model. Each point is in (X,Y) in pixels.
point_labels (torch.Tensor or None): A BxN array of labels for the
point prompts. 1 indicates a foreground point and 0 indicates a
background point.
box (np.ndarray or None): A Bx4 array given a box prompt to the
model, in XYXY format.
mask_input (np.ndarray): A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form Bx1xHxW, where
for SAM, H=W=256. Masks returned by a previous iteration of the
predict method do not need further transformation.
multimask_output (bool): If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will often
produce better masks than a single prediction. If only a single
mask is needed, the model's predicted quality score can be used
to select the best mask. For non-ambiguous prompts, such as multiple
input prompts, multimask_output=False can give better results.
return_logits (bool): If true, returns un-thresholded masks logits
instead of a binary mask.
Returns:
(torch.Tensor): The output masks in BxCxHxW format, where C is the
number of masks, and (H, W) is the original image size.
(torch.Tensor): An array of shape BxC containing the model's
predictions for the quality of each mask.
(torch.Tensor): An array of shape BxCxHxW, where C is the number
of masks and H=W=256. These low res logits can be passed to
a subsequent iteration as mask input.
"""
if not self.is_image_set:
raise RuntimeError(
"An image must be set with .set_image(...) before mask prediction."
)
if point_coords is not None:
points = (point_coords, point_labels)
else:
points = None
# Embed prompts
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
points=points,
boxes=boxes,
masks=mask_input,
)
# Predict masks
low_res_masks, iou_predictions = self.model.mask_decoder(
image_embeddings=self.features,
image_pe=self.model.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
# Upscale the masks to the original image resolution
masks = self.model.postprocess_masks(
low_res_masks, self.input_size, self.original_size
)
if not return_logits:
masks = masks > self.model.mask_threshold
return masks, iou_predictions, low_res_masks
def get_image_embedding(self) -> torch.Tensor:
"""
Returns the image embeddings for the currently set image, with
shape 1xCxHxW, where C is the embedding dimension and (H,W) are
the embedding spatial dimension of SAM (typically C=256, H=W=64).
"""
if not self.is_image_set:
raise RuntimeError(
"An image must be set with .set_image(...) to generate an embedding."
)
assert (
self.features is not None
), "Features must exist if an image has been set."
return self.features
@property
def device(self) -> torch.device:
return self.model.device
def reset_image(self) -> None:
"""Resets the currently set image."""
self.is_image_set = False
self.features = None
self.orig_h = None
self.orig_w = None
self.input_h = None
self.input_w = None

View File

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,112 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from torch.nn import functional as F
from torchvision.transforms.functional import resize, to_pil_image # type: ignore
from copy import deepcopy
from typing import Tuple
class ResizeLongestSide:
"""
Resizes images to longest side 'target_length', as well as provides
methods for resizing coordinates and boxes. Provides methods for
transforming both numpy array and batched torch tensors.
"""
def __init__(self, target_length: int) -> None:
self.target_length = target_length
def apply_image(self, image: np.ndarray) -> np.ndarray:
"""
Expects a numpy array with shape HxWxC in uint8 format.
"""
target_size = self.get_preprocess_shape(
image.shape[0], image.shape[1], self.target_length
)
return np.array(resize(to_pil_image(image), target_size))
def apply_coords(
self, coords: np.ndarray, original_size: Tuple[int, ...]
) -> np.ndarray:
"""
Expects a numpy array of length 2 in the final dimension. Requires the
original image size in (H, W) format.
"""
old_h, old_w = original_size
new_h, new_w = self.get_preprocess_shape(
original_size[0], original_size[1], self.target_length
)
coords = deepcopy(coords).astype(float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
def apply_boxes(
self, boxes: np.ndarray, original_size: Tuple[int, ...]
) -> np.ndarray:
"""
Expects a numpy array shape Bx4. Requires the original image size
in (H, W) format.
"""
boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
return boxes.reshape(-1, 4)
def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
"""
Expects batched images with shape BxCxHxW and float format. This
transformation may not exactly match apply_image. apply_image is
the transformation expected by the model.
"""
# Expects an image in BCHW format. May not exactly match apply_image.
target_size = self.get_preprocess_shape(
image.shape[0], image.shape[1], self.target_length
)
return F.interpolate(
image, target_size, mode="bilinear", align_corners=False, antialias=True
)
def apply_coords_torch(
self, coords: torch.Tensor, original_size: Tuple[int, ...]
) -> torch.Tensor:
"""
Expects a torch tensor with length 2 in the last dimension. Requires the
original image size in (H, W) format.
"""
old_h, old_w = original_size
new_h, new_w = self.get_preprocess_shape(
original_size[0], original_size[1], self.target_length
)
coords = deepcopy(coords).to(torch.float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
def apply_boxes_torch(
self, boxes: torch.Tensor, original_size: Tuple[int, ...]
) -> torch.Tensor:
"""
Expects a torch tensor with shape Bx4. Requires the original image
size in (H, W) format.
"""
boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
return boxes.reshape(-1, 4)
@staticmethod
def get_preprocess_shape(
oldh: int, oldw: int, long_side_length: int
) -> Tuple[int, int]:
"""
Compute the output size given input size and target long side length.
"""
scale = long_side_length * 1.0 / max(oldh, oldw)
newh, neww = oldh * scale, oldw * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)
return (newh, neww)