Qing
2024-04-12 13:09:37 +08:00
parent f71e9cfb26
commit 35f12d5b9b
4 changed files with 13 additions and 9 deletions

View File

@@ -25,11 +25,11 @@ def cli_download_model(model: str):
if model in models and models[model].is_erase_model:
logger.info(f"Downloading {model}...")
models[model].download()
logger.info(f"Done.")
logger.info("Done.")
elif model == ANYTEXT_NAME:
logger.info(f"Downloading {model}...")
models[model].download()
logger.info(f"Done.")
logger.info("Done.")
else:
logger.info(f"Downloading model from Huggingface: {model}")
from diffusers import DiffusionPipeline
@@ -60,7 +60,7 @@ def get_sd_model_type(model_abs_path: str) -> ModelType:
model_abs_path,
load_safety_checker=False,
num_in_channels=9,
config_files=get_config_files(),
original_config_file=get_config_files()['v1']
)
model_type = ModelType.DIFFUSERS_SD_INPAINT
except ValueError as e:
@@ -84,7 +84,7 @@ def get_sdxl_model_type(model_abs_path: str) -> ModelType:
model_abs_path,
load_safety_checker=False,
num_in_channels=9,
config_files=get_config_files(),
original_config_file=get_config_files()['xl'],
)
if model.unet.config.in_channels == 9:
# https://github.com/huggingface/diffusers/issues/6610
@@ -113,7 +113,7 @@ def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
pass
res = []
for it in stable_diffusion_dir.glob(f"*.*"):
for it in stable_diffusion_dir.glob("*.*"):
if it.suffix not in [".safetensors", ".ckpt"]:
continue
model_abs_path = str(it.absolute())
@@ -144,7 +144,7 @@ def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
except:
pass
for it in stable_diffusion_xl_dir.glob(f"*.*"):
for it in stable_diffusion_xl_dir.glob("*.*"):
if it.suffix not in [".safetensors", ".ckpt"]:
continue
model_abs_path = str(it.absolute())