get samplers from backend
This commit is contained in:
@@ -5,7 +5,7 @@ from typing import Dict, List, Union
|
||||
import torch
|
||||
|
||||
from diffusers.utils import is_safetensors_available
|
||||
|
||||
from huggingface_hub.constants import HF_HUB_CACHE
|
||||
|
||||
if is_safetensors_available():
|
||||
import safetensors.torch
|
||||
@@ -16,7 +16,6 @@ from diffusers import DiffusionPipeline, __version__
|
||||
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||
from diffusers.utils import (
|
||||
CONFIG_NAME,
|
||||
DIFFUSERS_CACHE,
|
||||
ONNX_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
)
|
||||
@@ -96,7 +95,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
|
||||
"""
|
||||
# Default kwargs from DiffusionPipeline
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
cache_dir = kwargs.pop("cache_dir", HF_HUB_CACHE)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
@@ -246,7 +245,6 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
print(f"Skipping {attr}: not present in 2nd or 3d model")
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
module = getattr(final_pipe, attr)
|
||||
if isinstance(
|
||||
@@ -267,7 +265,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
else torch.load(checkpoint_path_1, map_location="cpu")
|
||||
)
|
||||
|
||||
if attr in ['vae', 'text_encoder']:
|
||||
if attr in ["vae", "text_encoder"]:
|
||||
print(f"Direct use theta1 {attr}: {checkpoint_path_1}")
|
||||
update_theta_0(theta_1)
|
||||
del theta_1
|
||||
@@ -348,7 +346,7 @@ pipe = CheckpointMergerPipeline.from_pretrained("runwayml/stable-diffusion-inpai
|
||||
merged_pipe = pipe.merge(
|
||||
[
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
#"SG161222/Realistic_Vision_V1.4",
|
||||
# "SG161222/Realistic_Vision_V1.4",
|
||||
"dreamlike-art/dreamlike-diffusion-1.0",
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
],
|
||||
@@ -358,4 +356,6 @@ merged_pipe = pipe.merge(
|
||||
)
|
||||
|
||||
merged_pipe = merged_pipe.to(torch.float16)
|
||||
merged_pipe.save_pretrained("dreamlike-diffusion-1.0-inpainting", safe_serialization=True)
|
||||
merged_pipe.save_pretrained(
|
||||
"dreamlike-diffusion-1.0-inpainting", safe_serialization=True
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user