get samplers from backend

This commit is contained in:
Qing
2024-01-02 14:34:36 +08:00
parent a2fd5bb3ea
commit f38be37f8c
14 changed files with 141 additions and 101 deletions

View File

@@ -5,7 +5,7 @@ from typing import Dict, List, Union
import torch
from diffusers.utils import is_safetensors_available
from huggingface_hub.constants import HF_HUB_CACHE
if is_safetensors_available():
import safetensors.torch
@@ -16,7 +16,6 @@ from diffusers import DiffusionPipeline, __version__
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
ONNX_WEIGHTS_NAME,
WEIGHTS_NAME,
)
@@ -96,7 +95,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
"""
# Default kwargs from DiffusionPipeline
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
cache_dir = kwargs.pop("cache_dir", HF_HUB_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
@@ -246,7 +245,6 @@ class CheckpointMergerPipeline(DiffusionPipeline):
print(f"Skipping {attr}: not present in 2nd or 3d model")
continue
try:
module = getattr(final_pipe, attr)
if isinstance(
@@ -267,7 +265,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
else torch.load(checkpoint_path_1, map_location="cpu")
)
if attr in ['vae', 'text_encoder']:
if attr in ["vae", "text_encoder"]:
print(f"Direct use theta1 {attr}: {checkpoint_path_1}")
update_theta_0(theta_1)
del theta_1
@@ -348,7 +346,7 @@ pipe = CheckpointMergerPipeline.from_pretrained("runwayml/stable-diffusion-inpai
merged_pipe = pipe.merge(
[
"runwayml/stable-diffusion-inpainting",
#"SG161222/Realistic_Vision_V1.4",
# "SG161222/Realistic_Vision_V1.4",
"dreamlike-art/dreamlike-diffusion-1.0",
"runwayml/stable-diffusion-v1-5",
],
@@ -358,4 +356,6 @@ merged_pipe = pipe.merge(
)
merged_pipe = merged_pipe.to(torch.float16)
merged_pipe.save_pretrained("dreamlike-diffusion-1.0-inpainting", safe_serialization=True)
merged_pipe.save_pretrained(
"dreamlike-diffusion-1.0-inpainting", safe_serialization=True
)