clean code: get_torch_dtype; mps use float16 by default
@@ -1,4 +1,3 @@
import copy
import gc
import math
import random
@@ -994,3 +993,12 @@ def handle_from_pretrained_exceptions(func, **kwargs):
         raise e
     except Exception as e:
         raise e
+
+
+def get_torch_dtype(device, no_half: bool):
+    device = str(device)
+    use_fp16 = not no_half
+    use_gpu = device == "cuda"
+    if device in ["cuda", "mps"] and use_fp16:
+        return use_gpu, torch.float16
+    return use_gpu, torch.float32
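
For context, a minimal sketch of how the new helper might be used at a call site. The function body is taken from the diff above; the surrounding device-detection code is an illustrative assumption, not part of this commit:

import torch

def get_torch_dtype(device, no_half: bool):
    # Returns (use_gpu, dtype): float16 on CUDA or MPS unless half
    # precision is disabled via no_half; float32 otherwise. Note that
    # use_gpu is True only for CUDA, not for MPS.
    device = str(device)
    use_fp16 = not no_half
    use_gpu = device == "cuda"
    if device in ["cuda", "mps"] and use_fp16:
        return use_gpu, torch.float16
    return use_gpu, torch.float32

# Hypothetical call site: pick a dtype before loading model weights.
device = "mps" if torch.backends.mps.is_available() else "cpu"
use_gpu, torch_dtype = get_torch_dtype(device, no_half=False)
print(use_gpu, torch_dtype)  # on Apple Silicon: False torch.float16

This matches the commit title: on MPS the helper now defaults to float16, the same as CUDA, while any other device (or passing no_half=True) falls back to float32.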