pass pipe_components when turn control on/off
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
import torch
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ..utils import torch_gc
|
||||
|
||||
|
||||
class CPUTextEncoderWrapper(torch.nn.Module):
|
||||
class CPUTextEncoderWrapper(PreTrainedModel):
|
||||
def __init__(self, text_encoder, torch_dtype):
|
||||
super().__init__()
|
||||
super().__init__(text_encoder.config)
|
||||
self.config = text_encoder.config
|
||||
self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
|
||||
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
|
||||
|
||||
Reference in New Issue
Block a user