create cli.py
This commit is contained in:
@@ -1,13 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
import json
|
||||
import os
|
||||
|
||||
import typer
|
||||
from typer import Option
|
||||
|
||||
from lama_cleaner.download import cli_download_model, scan_models
|
||||
from lama_cleaner.runtime import setup_model_dir, dump_environment_info, check_device
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
import hashlib
|
||||
import traceback
|
||||
@@ -35,14 +28,11 @@ from lama_cleaner.model_manager import ModelManager
|
||||
from lama_cleaner.plugins import (
|
||||
InteractiveSeg,
|
||||
RemoveBG,
|
||||
RealESRGANUpscaler,
|
||||
GFPGANPlugin,
|
||||
RestoreFormerPlugin,
|
||||
AnimeSeg,
|
||||
build_plugins,
|
||||
)
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
|
||||
|
||||
try:
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
@@ -492,7 +482,20 @@ def get_cli_input_image():
|
||||
return "No Input Image"
|
||||
|
||||
|
||||
def build_plugins(
|
||||
def start(
|
||||
host: str,
|
||||
port: int,
|
||||
model: str,
|
||||
no_half: bool,
|
||||
cpu_offload: bool,
|
||||
disable_nsfw_checker,
|
||||
cpu_textencoder: bool,
|
||||
device: Device,
|
||||
gui: bool,
|
||||
disable_model_switch: bool,
|
||||
input: Path,
|
||||
output_dir: Path,
|
||||
quality: int,
|
||||
enable_interactive_seg: bool,
|
||||
interactive_seg_model: InteractiveSegModel,
|
||||
interactive_seg_device: Device,
|
||||
@@ -505,123 +508,7 @@ def build_plugins(
|
||||
gfpgan_device: Device,
|
||||
enable_restoreformer: bool,
|
||||
restoreformer_device: Device,
|
||||
no_half: bool,
|
||||
):
|
||||
if enable_interactive_seg:
|
||||
logger.info(f"Initialize {InteractiveSeg.name} plugin")
|
||||
global_config.plugins[InteractiveSeg.name] = InteractiveSeg(
|
||||
interactive_seg_model, interactive_seg_device
|
||||
)
|
||||
|
||||
if enable_remove_bg:
|
||||
logger.info(f"Initialize {RemoveBG.name} plugin")
|
||||
global_config.plugins[RemoveBG.name] = RemoveBG()
|
||||
|
||||
if enable_anime_seg:
|
||||
logger.info(f"Initialize {AnimeSeg.name} plugin")
|
||||
global_config.plugins[AnimeSeg.name] = AnimeSeg()
|
||||
|
||||
if enable_realesrgan:
|
||||
logger.info(
|
||||
f"Initialize {RealESRGANUpscaler.name} plugin: {realesrgan_model}, {realesrgan_device}"
|
||||
)
|
||||
global_config.plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
|
||||
realesrgan_model,
|
||||
realesrgan_device,
|
||||
no_half=no_half,
|
||||
)
|
||||
|
||||
if enable_gfpgan:
|
||||
logger.info(f"Initialize {GFPGANPlugin.name} plugin")
|
||||
if enable_realesrgan:
|
||||
logger.info("Use realesrgan as GFPGAN background upscaler")
|
||||
else:
|
||||
logger.info(
|
||||
f"GFPGAN no background upscaler, use --enable-realesrgan to enable it"
|
||||
)
|
||||
global_config.plugins[GFPGANPlugin.name] = GFPGANPlugin(
|
||||
gfpgan_device,
|
||||
upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None),
|
||||
)
|
||||
|
||||
if enable_restoreformer:
|
||||
logger.info(f"Initialize {RestoreFormerPlugin.name} plugin")
|
||||
global_config.plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
|
||||
restoreformer_device,
|
||||
upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None),
|
||||
)
|
||||
|
||||
|
||||
@typer_app.command(help="Install all plugins dependencies")
|
||||
def install_plugins_packages():
|
||||
from lama_cleaner.installer import install_plugins_package
|
||||
|
||||
install_plugins_package()
|
||||
|
||||
|
||||
@typer_app.command(help="Download SD/SDXL normal/inpainting model from HuggingFace")
|
||||
def download(
|
||||
model: str = Option(
|
||||
..., help="Model id on HuggingFace e.g: runwayml/stable-diffusion-inpainting"
|
||||
),
|
||||
model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False),
|
||||
):
|
||||
cli_download_model(model, model_dir)
|
||||
|
||||
|
||||
@typer_app.command(help="List downloaded models")
|
||||
def list_model(
|
||||
model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False),
|
||||
):
|
||||
setup_model_dir(model_dir)
|
||||
scanned_models = scan_models()
|
||||
for it in scanned_models:
|
||||
print(it.name)
|
||||
|
||||
|
||||
@typer_app.command(help="Start lama cleaner server")
|
||||
def start(
|
||||
host: str = Option("127.0.0.1"),
|
||||
port: int = Option(8080),
|
||||
model: str = Option(
|
||||
DEFAULT_MODEL,
|
||||
help=f"Available erase models: [{', '.join(AVAILABLE_MODELS)}]. "
|
||||
f"You can use download command to download other SD/SDXL normal/inpainting models on huggingface",
|
||||
),
|
||||
model_dir: Path = Option(
|
||||
DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, dir_okay=True, file_okay=False
|
||||
),
|
||||
no_half: bool = Option(False, help=NO_HALF_HELP),
|
||||
cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP),
|
||||
disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP),
|
||||
cpu_textencoder: bool = Option(False, help=CPU_TEXTENCODER_HELP),
|
||||
local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
|
||||
device: Device = Option(Device.cpu),
|
||||
gui: bool = Option(False, help=GUI_HELP),
|
||||
disable_model_switch: bool = Option(False),
|
||||
input: Path = Option(None, help=INPUT_HELP),
|
||||
output_dir: Path = Option(
|
||||
None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
|
||||
),
|
||||
quality: int = Option(95, help=QUALITY_HELP),
|
||||
enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP),
|
||||
interactive_seg_model: InteractiveSegModel = Option(
|
||||
InteractiveSegModel.vit_b, help=INTERACTIVE_SEG_MODEL_HELP
|
||||
),
|
||||
interactive_seg_device: Device = Option(Device.cpu),
|
||||
enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
|
||||
enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
|
||||
enable_realesrgan: bool = Option(False),
|
||||
realesrgan_device: Device = Option(Device.cpu),
|
||||
realesrgan_model: str = Option(RealESRGANModel.realesr_general_x4v3),
|
||||
enable_gfpgan: bool = Option(False),
|
||||
gfpgan_device: Device = Option(Device.cpu),
|
||||
enable_restoreformer: bool = Option(False),
|
||||
restoreformer_device: Device = Option(Device.cpu),
|
||||
):
|
||||
global global_config
|
||||
dump_environment_info()
|
||||
|
||||
if input:
|
||||
if not input.exists():
|
||||
logger.error(f"invalid --input: {input} not exists")
|
||||
@@ -637,24 +524,11 @@ def start(
|
||||
else:
|
||||
global_config.input_image_path = input
|
||||
|
||||
device = check_device(device)
|
||||
setup_model_dir(model_dir)
|
||||
|
||||
if local_files_only:
|
||||
os.environ["TRANSFORMERS_OFFLINE"] = "1"
|
||||
os.environ["HF_HUB_OFFLINE"] = "1"
|
||||
|
||||
scanned_models = scan_models()
|
||||
if model not in [it.name for it in scanned_models]:
|
||||
logger.error(
|
||||
f"invalid model: {model} not exists. Available models: {[it.name for it in scanned_models]}"
|
||||
)
|
||||
exit()
|
||||
|
||||
global_config.image_quality = quality
|
||||
global_config.disable_model_switch = disable_model_switch
|
||||
global_config.is_desktop = gui
|
||||
build_plugins(
|
||||
global_config,
|
||||
enable_interactive_seg,
|
||||
interactive_seg_model,
|
||||
interactive_seg_device,
|
||||
|
||||
Reference in New Issue
Block a user