diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py index d983f29..fced353 100644 --- a/lama_cleaner/model/sd.py +++ b/lama_cleaner/model/sd.py @@ -39,13 +39,14 @@ from lama_cleaner.schema import Config, SDSampler class CPUTextEncoderWrapper: def __init__(self, text_encoder, torch_dtype): + self.config = text_encoder.config self.text_encoder = text_encoder.to(torch.device('cpu'), non_blocking=True) self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True) self.torch_dtype = torch_dtype - def __call__(self, x): + def __call__(self, x, **kwargs): input_device = x.device - return [self.text_encoder(x.to(self.text_encoder.device))[0].to(input_device).to(self.torch_dtype)] + return [self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0].to(input_device).to(self.torch_dtype)] class SD(InpaintModel):