auto switch mps device to cpu device

This commit is contained in:
Qing
2023-02-11 13:30:09 +08:00
parent f9b5dcbfd7
commit 8f8bcfe0f4
15 changed files with 52 additions and 19 deletions

View File

@@ -1,6 +1,7 @@
import torch
import gc
from lama_cleaner.helper import switch_mps_device
from lama_cleaner.model.fcf import FcF
from lama_cleaner.model.lama import LaMa
from lama_cleaner.model.ldm import LDM
@@ -13,8 +14,19 @@ from lama_cleaner.model.zits import ZITS
from lama_cleaner.model.opencv2 import OpenCV2
from lama_cleaner.schema import Config
models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5": SD15, "cv2": OpenCV2, "manga": Manga,
"sd2": SD2, "paint_by_example": PaintByExample, "instruct_pix2pix": InstructPix2Pix}
models = {
"lama": LaMa,
"ldm": LDM,
"zits": ZITS,
"mat": MAT,
"fcf": FcF,
"sd1.5": SD15,
"cv2": OpenCV2,
"manga": Manga,
"sd2": SD2,
"paint_by_example": PaintByExample,
"instruct_pix2pix": InstructPix2Pix,
}
class ModelManager:
@@ -44,13 +56,15 @@ class ModelManager:
if new_name == self.name:
return
try:
if (torch.cuda.memory_allocated() > 0):
if torch.cuda.memory_allocated() > 0:
# Clear current loaded model from memory
torch.cuda.empty_cache()
del self.model
gc.collect()
self.model = self.init_model(new_name, self.device, **self.kwargs)
self.model = self.init_model(
new_name, switch_mps_device(new_name, self.device), **self.kwargs
)
self.name = new_name
except NotImplementedError as e:
raise e