mat support float16

This commit is contained in:
Qing
2023-03-26 11:45:39 +08:00
parent eb304ba696
commit 1433d21b9f
2 changed files with 118 additions and 4 deletions

View File

@@ -1880,7 +1880,10 @@ class MAT(InpaintModel):
seed = 240 # pick up a random number
set_seed(seed)
self.torch_dtype = torch.float16
fp16 = not kwargs.get("no_half", False)
use_gpu = "cuda" in str(device) and torch.cuda.is_available()
self.torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
G = Generator(
z_dim=512,
c_dim=0,