fix CPUTextEncoder

This commit is contained in:
Qing
2022-12-04 22:04:36 +08:00
parent 1100e5badb
commit 4a739aaae7

View File

@@ -39,13 +39,14 @@ from lama_cleaner.schema import Config, SDSampler
class CPUTextEncoderWrapper: class CPUTextEncoderWrapper:
def __init__(self, text_encoder, torch_dtype): def __init__(self, text_encoder, torch_dtype):
self.config = text_encoder.config
self.text_encoder = text_encoder.to(torch.device('cpu'), non_blocking=True) self.text_encoder = text_encoder.to(torch.device('cpu'), non_blocking=True)
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True) self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
self.torch_dtype = torch_dtype self.torch_dtype = torch_dtype
def __call__(self, x): def __call__(self, x, **kwargs):
input_device = x.device input_device = x.device
return [self.text_encoder(x.to(self.text_encoder.device))[0].to(input_device).to(self.torch_dtype)] return [self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0].to(input_device).to(self.torch_dtype)]
class SD(InpaintModel): class SD(InpaintModel):