pass pipe_components when turn control on/off

This commit is contained in:
Qing
2024-01-09 22:58:21 +08:00
parent db4a6f4547
commit 05a15b2e1f
6 changed files with 26 additions and 9 deletions

View File

@@ -1,10 +1,12 @@
import torch
from transformers import PreTrainedModel
from ..utils import torch_gc
class CPUTextEncoderWrapper(torch.nn.Module):
class CPUTextEncoderWrapper(PreTrainedModel):
def __init__(self, text_encoder, torch_dtype):
super().__init__()
super().__init__(text_encoder.config)
self.config = text_encoder.config
self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)