This commit is contained in:
Qing
2023-11-16 11:08:34 +08:00
parent 8f942e27c4
commit 0cfec489b7
7 changed files with 63 additions and 28 deletions

View File

@@ -12,8 +12,9 @@ from lama_cleaner.model.utils import torch_gc, get_scheduler
from lama_cleaner.schema import Config
class CPUTextEncoderWrapper:
class CPUTextEncoderWrapper(torch.nn.Module):
def __init__(self, text_encoder, torch_dtype):
super().__init__()
self.config = text_encoder.config
self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)