mat support float16
This commit is contained in:
@@ -1880,7 +1880,10 @@ class MAT(InpaintModel):
|
||||
seed = 240 # pick up a random number
|
||||
set_seed(seed)
|
||||
|
||||
self.torch_dtype = torch.float16
|
||||
fp16 = not kwargs.get("no_half", False)
|
||||
use_gpu = "cuda" in str(device) and torch.cuda.is_available()
|
||||
self.torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
||||
|
||||
G = Generator(
|
||||
z_dim=512,
|
||||
c_dim=0,
|
||||
|
||||
Reference in New Issue
Block a user