add --no-half arg

This commit is contained in:
Qing
2023-01-03 21:30:33 +08:00
parent 6cfc7c30f1
commit 59ee89bd34
5 changed files with 6 additions and 3 deletions

View File

@@ -15,8 +15,9 @@ class PaintByExample(InpaintModel):
min_size = 512
def init_model(self, device: torch.device, **kwargs):
fp16 = not kwargs['no_half']
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu else torch.float32
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
self.model = DiffusionPipeline.from_pretrained(
"Fantasy-Studio/Paint-by-Example",
torch_dtype=torch_dtype,