make brushnet work

This commit is contained in:
Qing
2024-04-12 11:07:41 +08:00
parent 35f12d5b9b
commit 0a262fa811
14 changed files with 3408 additions and 56 deletions

View File

@@ -1,3 +1,4 @@
import glob
import json
import os
from functools import lru_cache
@@ -92,7 +93,7 @@ def get_sdxl_model_type(model_abs_path: str) -> ModelType:
else:
model_type = ModelType.DIFFUSERS_SDXL
except ValueError as e:
if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e):
if "but got torch.Size([320, 4, 3, 3])" in str(e):
model_type = ModelType.DIFFUSERS_SDXL
else:
raise e
@@ -192,7 +193,9 @@ def scan_diffusers_models() -> List[ModelInfo]:
cache_dir = Path(HF_HUB_CACHE)
# logger.info(f"Scanning diffusers models in {cache_dir}")
diffusers_model_names = []
for it in cache_dir.glob("**/*/model_index.json"):
model_index_files = glob.glob(os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True)
for it in model_index_files:
it = Path(it)
with open(it, "r", encoding="utf-8") as f:
try:
data = json.load(f)
@@ -238,7 +241,9 @@ def _scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
cache_dir = Path(cache_dir)
available_models = []
diffusers_model_names = []
for it in cache_dir.glob("**/*/model_index.json"):
model_index_files = glob.glob(os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True)
for it in model_index_files:
it = Path(it)
with open(it, "r", encoding="utf-8") as f:
try:
data = json.load(f)