fix torch_dtype & cpu_text_encoder

Qing
2024-02-10 10:59:06 +08:00
parent 180b7d6c70
commit 9aa5a7e0ba
4 changed files with 13 additions and 4 deletions


@@ -8,6 +8,7 @@ class CPUTextEncoderWrapper(PreTrainedModel):
     def __init__(self, text_encoder, torch_dtype):
         super().__init__(text_encoder.config)
         self.config = text_encoder.config
+        self._device = text_encoder.device
         # cpu not support float16
         self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
         self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
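The first hunk records the encoder's original device before the weights are moved to CPU and cast to float32, since half-precision kernels are missing on CPU in many PyTorch builds. A minimal sketch of that motivation (illustrative only; the layer and shapes are made up):

import torch

# On PyTorch builds without CPU float16 kernels, running an fp16 layer on CPU
# raises a RuntimeError, which is why the wrapper casts the weights to float32.
layer = torch.nn.Linear(8, 8).half()
x = torch.randn(2, 8, dtype=torch.float16)
try:
    layer(x)  # may raise e.g. RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
except RuntimeError as err:
    print(err)
print(layer.float()(x.float()).dtype)  # torch.float32 works on CPU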
@@ -30,3 +31,11 @@ class CPUTextEncoderWrapper(PreTrainedModel):
     @property
     def dtype(self):
         return self.torch_dtype
+
+    @property
+    def device(self) -> torch.device:
+        """
+        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+        device).
+        """
+        return self._device
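The second hunk adds a device property that returns the device captured in __init__. Without the override, the device property inherited from PreTrainedModel would report the device of the parameters, i.e. CPU after the move, while dtype keeps reporting the requested torch_dtype rather than the float32 the CPU copy actually uses. A usage sketch, assuming the wrapper also stores torch_dtype (as its dtype property implies) and that a CUDA device and the openai/clip-vit-base-patch32 checkpoint are available; the checkpoint name is only an illustration:

import torch
from transformers import CLIPTextModel

# Hypothetical example: CPUTextEncoderWrapper is the class from the diff above.
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to("cuda")
wrapper = CPUTextEncoderWrapper(text_encoder, torch_dtype=torch.float16)

print(wrapper.device)  # cuda:0 -> the original device, via the new property
print(wrapper.dtype)   # torch.float16 -> the requested dtype, not float32
print(next(wrapper.text_encoder.parameters()).device)  # cpu, where the weights actually live
print(next(wrapper.text_encoder.parameters()).dtype)   # torch.float32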