auto switch mps device to cpu device
This commit is contained in:
@@ -6,11 +6,12 @@ import torch
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo
|
||||
from lama_cleaner.helper import boxes_from_mask, resize_max_size, pad_img_to_modulo, switch_mps_device
|
||||
from lama_cleaner.schema import Config, HDStrategy
|
||||
|
||||
|
||||
class InpaintModel:
|
||||
name = "base"
|
||||
min_size: Optional[int] = None
|
||||
pad_mod = 8
|
||||
pad_to_square = False
|
||||
@@ -21,6 +22,7 @@ class InpaintModel:
|
||||
Args:
|
||||
device:
|
||||
"""
|
||||
device = switch_mps_device(self.name, device)
|
||||
self.device = device
|
||||
self.init_model(device, **kwargs)
|
||||
|
||||
|
||||
@@ -1131,6 +1131,7 @@ FCF_MODEL_URL = os.environ.get(
|
||||
|
||||
|
||||
class FcF(InpaintModel):
|
||||
name = "fcf"
|
||||
min_size = 512
|
||||
pad_mod = 512
|
||||
pad_to_square = True
|
||||
|
||||
@@ -9,6 +9,7 @@ from lama_cleaner.schema import Config
|
||||
|
||||
|
||||
class InstructPix2Pix(DiffusionInpaintModel):
|
||||
name = "instruct_pix2pix"
|
||||
pad_mod = 8
|
||||
min_size = 512
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
|
||||
from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
|
||||
from lama_cleaner.model.base import InpaintModel
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
@@ -16,6 +16,7 @@ LAMA_MODEL_URL = os.environ.get(
|
||||
|
||||
|
||||
class LaMa(InpaintModel):
|
||||
name = "lama"
|
||||
pad_mod = 8
|
||||
|
||||
def init_model(self, device, **kwargs):
|
||||
|
||||
@@ -225,6 +225,7 @@ class LatentDiffusion(DDPM):
|
||||
|
||||
|
||||
class LDM(InpaintModel):
|
||||
name = "ldm"
|
||||
pad_mod = 32
|
||||
|
||||
def __init__(self, device, fp16: bool = True, **kwargs):
|
||||
|
||||
@@ -76,6 +76,7 @@ MANGA_LINE_MODEL_URL = os.environ.get(
|
||||
|
||||
|
||||
class Manga(InpaintModel):
|
||||
name = "manga"
|
||||
pad_mod = 16
|
||||
|
||||
def init_model(self, device, **kwargs):
|
||||
|
||||
@@ -1401,6 +1401,7 @@ MAT_MODEL_URL = os.environ.get(
|
||||
|
||||
|
||||
class MAT(InpaintModel):
|
||||
name = "mat"
|
||||
min_size = 512
|
||||
pad_mod = 512
|
||||
pad_to_square = True
|
||||
|
||||
@@ -2,12 +2,11 @@ import cv2
|
||||
from lama_cleaner.model.base import InpaintModel
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
flag_map = {
|
||||
"INPAINT_NS": cv2.INPAINT_NS,
|
||||
"INPAINT_TELEA": cv2.INPAINT_TELEA
|
||||
}
|
||||
flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA}
|
||||
|
||||
|
||||
class OpenCV2(InpaintModel):
|
||||
name = "cv2"
|
||||
pad_mod = 1
|
||||
|
||||
@staticmethod
|
||||
@@ -20,5 +19,10 @@ class OpenCV2(InpaintModel):
|
||||
mask: [H, W, 1]
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
cur_res = cv2.inpaint(image[:,:,::-1], mask, inpaintRadius=config.cv2_radius, flags=flag_map[config.cv2_flag])
|
||||
cur_res = cv2.inpaint(
|
||||
image[:, :, ::-1],
|
||||
mask,
|
||||
inpaintRadius=config.cv2_radius,
|
||||
flags=flag_map[config.cv2_flag],
|
||||
)
|
||||
return cur_res
|
||||
|
||||
@@ -11,6 +11,7 @@ from lama_cleaner.schema import Config
|
||||
|
||||
|
||||
class PaintByExample(DiffusionInpaintModel):
|
||||
name = "paint_by_example"
|
||||
pad_mod = 8
|
||||
min_size = 512
|
||||
|
||||
|
||||
@@ -160,8 +160,10 @@ class SD(DiffusionInpaintModel):
|
||||
|
||||
|
||||
class SD15(SD):
|
||||
name = "sd1.5"
|
||||
model_id_or_path = "runwayml/stable-diffusion-inpainting"
|
||||
|
||||
|
||||
class SD2(SD):
|
||||
name = "sd2"
|
||||
model_id_or_path = "stabilityai/stable-diffusion-2-inpainting"
|
||||
|
||||
@@ -203,6 +203,7 @@ def to_device(data, device):
|
||||
|
||||
|
||||
class ZITS(InpaintModel):
|
||||
name = "zits"
|
||||
min_size = 256
|
||||
pad_mod = 32
|
||||
pad_to_square = True
|
||||
|
||||
Reference in New Issue
Block a user