auto switch mps device to cpu device
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import gc
|
||||
|
||||
from lama_cleaner.helper import switch_mps_device
|
||||
from lama_cleaner.model.fcf import FcF
|
||||
from lama_cleaner.model.lama import LaMa
|
||||
from lama_cleaner.model.ldm import LDM
|
||||
@@ -13,8 +14,19 @@ from lama_cleaner.model.zits import ZITS
|
||||
from lama_cleaner.model.opencv2 import OpenCV2
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5": SD15, "cv2": OpenCV2, "manga": Manga,
|
||||
"sd2": SD2, "paint_by_example": PaintByExample, "instruct_pix2pix": InstructPix2Pix}
|
||||
models = {
|
||||
"lama": LaMa,
|
||||
"ldm": LDM,
|
||||
"zits": ZITS,
|
||||
"mat": MAT,
|
||||
"fcf": FcF,
|
||||
"sd1.5": SD15,
|
||||
"cv2": OpenCV2,
|
||||
"manga": Manga,
|
||||
"sd2": SD2,
|
||||
"paint_by_example": PaintByExample,
|
||||
"instruct_pix2pix": InstructPix2Pix,
|
||||
}
|
||||
|
||||
|
||||
class ModelManager:
|
||||
@@ -44,13 +56,15 @@ class ModelManager:
|
||||
if new_name == self.name:
|
||||
return
|
||||
try:
|
||||
if (torch.cuda.memory_allocated() > 0):
|
||||
if torch.cuda.memory_allocated() > 0:
|
||||
# Clear current loaded model from memory
|
||||
torch.cuda.empty_cache()
|
||||
del self.model
|
||||
gc.collect()
|
||||
|
||||
self.model = self.init_model(new_name, self.device, **self.kwargs)
|
||||
self.model = self.init_model(
|
||||
new_name, switch_mps_device(new_name, self.device), **self.kwargs
|
||||
)
|
||||
self.name = new_name
|
||||
except NotImplementedError as e:
|
||||
raise e
|
||||
|
||||
Reference in New Issue
Block a user