diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py index 3593010..ee44206 100644 --- a/lama_cleaner/helper.py +++ b/lama_cleaner/helper.py @@ -23,9 +23,7 @@ def md5sum(filename): def switch_mps_device(model_name, device): - if model_name not in MPS_SUPPORT_MODELS and ( - device == "mps" or device == torch.device("mps") - ): + if model_name not in MPS_SUPPORT_MODELS and str(device) == "mps": logger.info(f"{model_name} not support mps, switch to cpu") return torch.device("cpu") return device