add enable_low_mem

This commit is contained in:
Qing
2024-01-08 23:54:20 +08:00
parent a71c3fbe1b
commit a49c3f86d3
7 changed files with 25 additions and 4 deletions

View File

@@ -1002,3 +1002,18 @@ def get_torch_dtype(device, no_half: bool):
if device in ["cuda", "mps"] and use_fp16:
return use_gpu, torch.float16
return use_gpu, torch.float32
def enable_low_mem(pipe, enable: bool):
    """Reduce peak memory usage of a diffusers pipeline.

    On Apple Silicon (MPS) attention slicing is always switched on:
    with slice size ``"max"`` when *enable* is requested, or the
    library default otherwise.  VAE tiling is turned on whenever
    *enable* is requested, independent of the backend.
    """
    if torch.backends.mps.is_available():
        # https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
        # CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers.
        # https://huggingface.co/docs/diffusers/optimization/mps
        # Devices with less than 64GB of memory are recommended to use enable_attention_slicing
        slicing_args = ("max",) if enable else ()
        pipe.enable_attention_slicing(*slicing_args)

    if enable:
        pipe.vae.enable_tiling()