This commit is contained in:
@@ -23,9 +23,7 @@ def md5sum(filename):
|
||||
|
||||
|
||||
def switch_mps_device(model_name, device):
|
||||
if model_name not in MPS_SUPPORT_MODELS and (
|
||||
device == "mps" or device == torch.device("mps")
|
||||
):
|
||||
if model_name not in MPS_SUPPORT_MODELS and str(device) == "mps":
|
||||
logger.info(f"{model_name} not support mps, switch to cpu")
|
||||
return torch.device("cpu")
|
||||
return device
|
||||
|
||||
Reference in New Issue
Block a user