update
This commit is contained in:
@@ -6,6 +6,7 @@ from loguru import logger
|
||||
from pathlib import Path
|
||||
|
||||
from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION, DEFAULT_MODEL_DIR
|
||||
from lama_cleaner.runtime import setup_model_dir
|
||||
from lama_cleaner.schema import (
|
||||
ModelInfo,
|
||||
ModelType,
|
||||
@@ -16,16 +17,8 @@ from lama_cleaner.schema import (
|
||||
)
|
||||
|
||||
|
||||
def cli_download_model(model: str, model_dir: str):
|
||||
if os.path.isfile(model_dir):
|
||||
raise ValueError(f"invalid --model-dir: {model_dir} is a file")
|
||||
|
||||
if not os.path.exists(model_dir):
|
||||
logger.info(f"Create model cache directory: {model_dir}")
|
||||
Path(model_dir).mkdir(exist_ok=True, parents=True)
|
||||
|
||||
os.environ["XDG_CACHE_HOME"] = model_dir
|
||||
|
||||
def cli_download_model(model: str, model_dir: Path):
|
||||
setup_model_dir(model_dir)
|
||||
from lama_cleaner.model import models
|
||||
|
||||
if model in models:
|
||||
@@ -38,7 +31,7 @@ def cli_download_model(model: str, model_dir: str):
|
||||
|
||||
downloaded_path = DiffusionPipeline.download(
|
||||
pretrained_model_name=model,
|
||||
revision="fp16" if model in DIFFUSERS_MODEL_FP16_REVERSION else "main",
|
||||
variant="fp16" if model in DIFFUSERS_MODEL_FP16_REVERSION else "main",
|
||||
resume_download=True,
|
||||
)
|
||||
logger.info(f"Done. Downloaded to {downloaded_path}")
|
||||
@@ -101,7 +94,7 @@ def scan_inpaint_models() -> List[ModelInfo]:
|
||||
from lama_cleaner.model import models
|
||||
|
||||
for name, m in models.items():
|
||||
if m.is_erase_model:
|
||||
if m.is_erase_model and m.is_downloaded():
|
||||
res.append(
|
||||
ModelInfo(
|
||||
name=name,
|
||||
|
||||
Reference in New Issue
Block a user