clean code: get_torch_dtype; mps use float16 by default

Qing
2024-01-08 23:53:20 +08:00
parent 6f4ce66793
commit a71c3fbe1b
8 changed files with 23 additions and 28 deletions


@@ -1,4 +1,3 @@
-import copy
 import gc
 import math
 import random
@@ -994,3 +993,12 @@ def handle_from_pretrained_exceptions(func, **kwargs):
         raise e
     except Exception as e:
         raise e
+
+
+def get_torch_dtype(device, no_half: bool):
+    device = str(device)
+    use_fp16 = not no_half
+    use_gpu = device == "cuda"
+    if device in ["cuda", "mps"] and use_fp16:
+        return use_gpu, torch.float16
+    return use_gpu, torch.float32
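
For context, a minimal self-contained sketch of what the new helper resolves per device. The function body is copied verbatim from the diff above so the snippet runs standalone; the device strings and the loop are illustrative, not part of the commit.

```python
import torch


def get_torch_dtype(device, no_half: bool):
    # Copied from the diff above so the example runs standalone.
    device = str(device)
    use_fp16 = not no_half
    use_gpu = device == "cuda"
    if device in ["cuda", "mps"] and use_fp16:
        return use_gpu, torch.float16
    return use_gpu, torch.float32


# After this commit, cuda and mps both default to float16; cpu stays
# float32, and no_half=True forces float32 on every device.
for dev in ["cuda", "mps", "cpu"]:
    for no_half in (False, True):
        use_gpu, dtype = get_torch_dtype(dev, no_half)
        print(f"{dev:>4} no_half={no_half!s:<5} -> use_gpu={use_gpu}, dtype={dtype}")
```

Note that `use_gpu` is still tied to cuda only: mps gets the float16 default but is not reported as a GPU device by this helper.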