add enable_low_mem
@@ -1002,3 +1002,18 @@ def get_torch_dtype(device, no_half: bool):
     if device in ["cuda", "mps"] and use_fp16:
         return use_gpu, torch.float16
     return use_gpu, torch.float32
+
+
+def enable_low_mem(pipe, enable: bool):
+    if torch.backends.mps.is_available():
+        # https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
+        # CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers.
+        if enable:
+            pipe.enable_attention_slicing("max")
+        else:
+            # https://huggingface.co/docs/diffusers/optimization/mps
+            # Devices with less than 64GB of memory are recommended to use enable_attention_slicing
+            pipe.enable_attention_slicing()
+
+    if enable:
+        pipe.vae.enable_tiling()
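For context, a minimal sketch of how the new helper might be called with a diffusers pipeline. This is not part of the commit: the pipeline class, model id, and device are illustrative assumptions, and enable_low_mem is assumed to be importable from the patched module.

    # Usage sketch (illustrative assumptions, not from this commit).
    import torch
    from diffusers import StableDiffusionInpaintPipeline

    # Hypothetical model id and device; substitute whatever the app actually loads.
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        torch_dtype=torch.float16,
    ).to("mps")

    # enable=True trades speed for memory: attention is computed in the
    # smallest possible slices ("max") and the VAE decodes large images in tiles.
    enable_low_mem(pipe, enable=True)

On MPS the helper always turns on some attention slicing (per the linked diffusers docs for sub-64GB machines), while "max" slicing and VAE tiling are reserved for the explicit low-memory mode.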