lots update

This commit is contained in:
Qing
2023-12-27 22:00:07 +08:00
parent 0ba6c121e0
commit f0b852725f
33 changed files with 4085 additions and 1000 deletions

View File

@@ -5,23 +5,23 @@ from typing import List
from loguru import logger
from pathlib import Path
from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION, DEFAULT_MODEL_DIR
from lama_cleaner.runtime import setup_model_dir
from lama_cleaner.schema import (
ModelInfo,
ModelType,
DIFFUSERS_SD_INPAINT_CLASS_NAME,
DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
from lama_cleaner.const import (
DEFAULT_MODEL_DIR,
DIFFUSERS_SD_CLASS_NAME,
DIFFUSERS_SD_INPAINT_CLASS_NAME,
DIFFUSERS_SDXL_CLASS_NAME,
DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
)
from lama_cleaner.model.utils import handle_from_pretrained_exceptions
from lama_cleaner.model_info import ModelInfo, ModelType
from lama_cleaner.runtime import setup_model_dir
def cli_download_model(model: str, model_dir: Path):
setup_model_dir(model_dir)
from lama_cleaner.model import models
if model in models:
if model in models and models[model].is_erase_model:
logger.info(f"Downloading {model}...")
models[model].download()
logger.info(f"Done.")
@@ -29,9 +29,10 @@ def cli_download_model(model: str, model_dir: Path):
logger.info(f"Downloading model from Huggingface: {model}")
from diffusers import DiffusionPipeline
downloaded_path = DiffusionPipeline.download(
downloaded_path = handle_from_pretrained_exceptions(
DiffusionPipeline.download,
pretrained_model_name=model,
variant="fp16" if model in DIFFUSERS_MODEL_FP16_REVERSION else "main",
variant="fp16",
resume_download=True,
)
logger.info(f"Done. Downloaded to {downloaded_path}")
@@ -43,21 +44,33 @@ def folder_name_to_show_name(name: str) -> str:
def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
cache_dir = Path(cache_dir)
stable_diffusion_dir = cache_dir / "stable_diffusion"
stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
# logger.info(f"Scanning single file sd/sdxl models in {cache_dir}")
res = []
for it in cache_dir.glob(f"*.*"):
for it in stable_diffusion_dir.glob(f"*.*"):
if it.suffix not in [".safetensors", ".ckpt"]:
continue
if "inpaint" in str(it).lower():
if "sdxl" in str(it).lower():
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
else:
model_type = ModelType.DIFFUSERS_SD_INPAINT
model_type = ModelType.DIFFUSERS_SD_INPAINT
else:
if "sdxl" in str(it).lower():
model_type = ModelType.DIFFUSERS_SDXL
else:
model_type = ModelType.DIFFUSERS_SD
model_type = ModelType.DIFFUSERS_SD
res.append(
ModelInfo(
name=it.name,
path=str(it.absolute()),
model_type=model_type,
is_single_file_diffusers=True,
)
)
for it in stable_diffusion_xl_dir.glob(f"*.*"):
if it.suffix not in [".safetensors", ".ckpt"]:
continue
if "inpaint" in str(it).lower():
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
else:
model_type = ModelType.DIFFUSERS_SDXL
res.append(
ModelInfo(
name=it.name,
@@ -104,8 +117,9 @@ def scan_models() -> List[ModelInfo]:
name = folder_name_to_show_name(it.parent.parent.parent.name)
if name in diffusers_model_names:
continue
if _class_name == DIFFUSERS_SD_CLASS_NAME:
if "PowerPaint" in name:
model_type = ModelType.DIFFUSERS_OTHER
elif _class_name == DIFFUSERS_SD_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SD
elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SD_INPAINT