Qing
2023-12-24 15:32:27 +08:00
parent 0e5e16ba20
commit 371db2d771
31 changed files with 441 additions and 439 deletions


@@ -1,10 +1,18 @@
#!/usr/bin/env python3
import json
import os
+import hashlib
+import traceback
+import typer
+from typer import Option
+from lama_cleaner.download import cli_download_model, scan_models
+from lama_cleaner.runtime import setup_model_dir, dump_environment_info, check_device
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
-import hashlib
-import traceback
+from dataclasses import dataclass
import imghdr
import io
@@ -20,12 +28,7 @@ import torch
from PIL import Image
from loguru import logger
-from lama_cleaner.const import (
-    SD15_MODELS,
-    SD_CONTROLNET_CHOICES,
-    SDXL_CONTROLNET_CHOICES,
-    SD2_CONTROLNET_CHOICES,
-)
+from lama_cleaner.const import *
from lama_cleaner.file_manager import FileManager
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model_manager import ModelManager
@@ -39,6 +42,8 @@ from lama_cleaner.plugins import (
)
from lama_cleaner.schema import Config
+typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
try:
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
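
The module now exposes a Typer application object instead of an argparse parser. A minimal standalone sketch of the same pattern (hypothetical command, not part of this commit):

    import typer

    typer_app = typer.Typer(
        pretty_exceptions_show_locals=False,  # don't dump local variables in tracebacks
        add_completion=False,  # hide --install-completion/--show-completion options
    )

    @typer_app.command()
    def hello(name: str = typer.Option("world")):
        print(f"Hello {name}")

    if __name__ == "__main__":
        typer_app()
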
@@ -103,23 +108,34 @@ logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
app.config["JSON_AS_ASCII"] = False
-CORS(app, expose_headers=["Content-Disposition", "X-seed"])
+CORS(app, expose_headers=["Content-Disposition", "X-seed", "X-Height", "X-Width"])
sio_logger = logging.getLogger("sio-logger")
sio_logger.setLevel(logging.ERROR)
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading")
-model: ModelManager = None
-thumb: FileManager = None
-output_dir: str = None
-device = None
-input_image_path: str = None
-is_disable_model_switch: bool = False
-enable_file_manager: bool = False
-enable_auto_saving: bool = False
-is_desktop: bool = False
-image_quality: int = 95
-plugins = {}
+@dataclass
+class GlobalConfig:
+    model_manager: ModelManager = None
+    file_manager: FileManager = None
+    output_dir: Path = None
+    input_image_path: Path = None
+    disable_model_switch: bool = False
+    is_desktop: bool = False
+    image_quality: int = 95
+    plugins = {}
+
+    @property
+    def enable_auto_saving(self) -> bool:
+        return self.output_dir is not None
+
+    @property
+    def enable_file_manager(self) -> bool:
+        return self.file_manager is not None
+
+global_config = GlobalConfig()
def get_image_ext(img_bytes):
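
The scattered module-level globals are folded into one GlobalConfig object, and the derived flags (enable_auto_saving, enable_file_manager) become properties computed from state instead of separately maintained booleans, so they can no longer drift out of sync. A minimal sketch of the pattern, with hypothetical names:

    from dataclasses import dataclass
    from pathlib import Path
    from typing import Optional

    @dataclass
    class AppState:
        output_dir: Optional[Path] = None

        @property
        def enable_auto_saving(self) -> bool:
            # recomputed on access; always consistent with output_dir
            return self.output_dir is not None

    state = AppState()
    assert not state.enable_auto_saving
    state.output_dir = Path("./results")
    assert state.enable_auto_saving

One detail worth noting: `plugins = {}` carries no type annotation, so dataclasses treats it as a plain class attribute rather than a field, which is harmless here because only the single global_config instance exists.
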
@@ -135,7 +151,7 @@ def diffuser_callback(i, t, latents):
@app.route("/save_image", methods=["POST"])
def save_image():
-    if output_dir is None:
+    if global_config.output_dir is None:
        return "--output-dir is None", 500
    input = request.files
@@ -143,7 +159,7 @@ def save_image():
    origin_image_bytes = input["image"].read()  # RGB
    ext = get_image_ext(origin_image_bytes)
    image, alpha_channel, exif_infos = load_img(origin_image_bytes, return_exif=True)
-    save_path = os.path.join(output_dir, filename)
+    save_path = str(global_config.output_dir / filename)
    if alpha_channel is not None:
        if alpha_channel.shape[:2] != image.shape[:2]:
@@ -157,7 +173,7 @@ def save_image():
    img_bytes = pil_to_bytes(
        pil_image,
        ext,
-        quality=image_quality,
+        quality=global_config.image_quality,
        exif_infos=exif_infos,
    )
    with open(save_path, "wb") as fw:
@@ -169,9 +185,11 @@ def save_image():
@app.route("/medias/<tab>")
def medias(tab):
if tab == "image":
response = make_response(jsonify(thumb.media_names), 200)
response = make_response(jsonify(global_config.file_manager.media_names), 200)
else:
response = make_response(jsonify(thumb.output_media_names), 200)
response = make_response(
jsonify(global_config.file_manager.output_media_names), 200
)
# response.last_modified = thumb.modified_time[tab]
# response.cache_control.no_cache = True
# response.cache_control.max_age = 0
@@ -182,8 +200,8 @@ def medias(tab):
@app.route("/media/<tab>/<filename>")
def media_file(tab, filename):
if tab == "image":
return send_from_directory(thumb.root_directory, filename)
return send_from_directory(thumb.output_dir, filename)
return send_from_directory(global_config.file_manager.root_directory, filename)
return send_from_directory(global_config.file_manager.output_dir, filename)
@app.route("/media_thumbnail/<tab>/<filename>")
@@ -198,10 +216,10 @@ def media_thumbnail_file(tab, filename):
    if height:
        height = int(float(height))
-    directory = thumb.root_directory
+    directory = global_config.file_manager.root_directory
    if tab == "output":
-        directory = thumb.output_dir
-    thumb_filename, (width, height) = thumb.get_thumbnail(
+        directory = global_config.file_manager.output_dir
+    thumb_filename, (width, height) = global_config.file_manager.get_thumbnail(
        directory, filename, width, height
    )
    thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}"
@@ -257,13 +275,11 @@ def process():
        croper_y=form["croperY"],
        croper_height=form["croperHeight"],
        croper_width=form["croperWidth"],
+        use_extender=form["useExtender"],
+        extender_x=form["extenderX"],
+        extender_y=form["extenderY"],
+        extender_height=form["extenderHeight"],
+        extender_width=form["extenderWidth"],
        sd_scale=form["sdScale"],
        sd_mask_blur=form["sdMaskBlur"],
        sd_strength=form["sdStrength"],
@@ -294,7 +310,7 @@ def process():
    start = time.time()
    try:
-        res_np_img = model(image, mask, config)
+        res_np_img = global_config.model_manager(image, mask, config)
    except RuntimeError as e:
        if "CUDA out of memory. " in str(e):
            # NOTE: the string may change?
@@ -322,7 +338,7 @@ def process():
                pil_to_bytes(
                    Image.fromarray(res_np_img),
                    ext,
-                    quality=image_quality,
+                    quality=global_config.image_quality,
                    exif_infos=exif_infos,
                )
            )
@@ -345,7 +361,7 @@ def run_plugin():
    form = request.form
    files = request.files
    name = form["name"]
-    if name not in plugins:
+    if name not in global_config.plugins:
        return "Plugin not found", 500
    origin_image_bytes = files["image"].read()  # RGB
@@ -359,7 +375,7 @@ def run_plugin():
        if name == InteractiveSeg.name:
            img_md5 = hashlib.md5(origin_image_bytes).hexdigest()
            form["img_md5"] = img_md5
-        bgr_res = plugins[name](rgb_np_img, files, form)
+        bgr_res = global_config.plugins[name](rgb_np_img, files, form)
    except RuntimeError as e:
        torch.cuda.empty_cache()
        if "CUDA out of memory. " in str(e):
@@ -401,7 +417,7 @@ def run_plugin():
                pil_to_bytes(
                    Image.fromarray(rgb_res),
                    ext,
-                    quality=image_quality,
+                    quality=global_config.image_quality,
                    exif_infos=exif_infos,
                )
            ),
@@ -414,41 +430,40 @@ def run_plugin():
@app.route("/server_config", methods=["GET"])
def get_server_config():
    return {
-        "plugins": list(plugins.keys()),
-        "enableFileManager": enable_file_manager,
-        "enableAutoSaving": enable_auto_saving,
-        "enableControlnet": model.sd_controlnet,
-        "controlnetMethod": model.sd_controlnet_method,
-        "disableModelSwitch": is_disable_model_switch,
+        "plugins": list(global_config.plugins.keys()),
+        "enableFileManager": global_config.enable_file_manager,
+        "enableAutoSaving": global_config.enable_auto_saving,
+        "enableControlnet": global_config.model_manager.sd_controlnet,
+        "controlnetMethod": global_config.model_manager.sd_controlnet_method,
+        "disableModelSwitch": global_config.disable_model_switch,
+        "isDesktop": global_config.is_desktop,
    }, 200
@app.route("/models", methods=["GET"])
def get_models():
-    return [it.model_dump() for it in model.scan_models()]
+    return [it.model_dump() for it in global_config.model_manager.scan_models()]
@app.route("/model")
def current_model():
-    return model.available_models[model.name].model_dump(), 200
-@app.route("/is_desktop")
-def get_is_desktop():
-    return str(is_desktop), 200
+    return (
+        global_config.model_manager.current_model,
+        200,
+    )
@app.route("/model", methods=["POST"])
def switch_model():
-    if is_disable_model_switch:
+    if global_config.disable_model_switch:
        return "Switch model is disabled", 400
    new_name = request.form.get("name")
-    if new_name == model.name:
+    if new_name == global_config.model_manager.name:
        return "Same model", 200
    try:
-        model.switch(new_name)
+        global_config.model_manager.switch(new_name)
    except Exception as e:
        traceback.print_exc()
        error_message = f"{type(e).__name__} - {str(e)}"
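
With the old /is_desktop route folded into /server_config, a client can read everything from one payload. A hypothetical client-side check (keys as in the diff; host and port are assumptions):

    import requests

    cfg = requests.get("http://127.0.0.1:8080/server_config").json()
    print("plugins:", cfg["plugins"])
    if cfg["enableAutoSaving"]:
        print("results are saved server-side to --output-dir")
    if cfg["isDesktop"]:
        print("running in desktop (GUI) mode")
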
@@ -464,160 +479,230 @@ def index():
@app.route("/inputimage")
def get_cli_input_image():
-    if input_image_path:
-        with open(input_image_path, "rb") as f:
+    if global_config.input_image_path:
+        with open(global_config.input_image_path, "rb") as f:
            image_in_bytes = f.read()
        return send_file(
-            input_image_path,
+            global_config.input_image_path,
            as_attachment=True,
-            download_name=Path(input_image_path).name,
+            download_name=Path(global_config.input_image_path).name,
            mimetype=f"image/{get_image_ext(image_in_bytes)}",
        )
    else:
        return "No Input Image"
-def build_plugins(args):
-    global plugins
-    if args.enable_interactive_seg:
+def build_plugins(
+    enable_interactive_seg: bool,
+    interactive_seg_model: InteractiveSegModel,
+    interactive_seg_device: Device,
+    enable_remove_bg: bool,
+    enable_anime_seg: bool,
+    enable_realesrgan: bool,
+    realesrgan_device: Device,
+    realesrgan_model: str,
+    enable_gfpgan: bool,
+    gfpgan_device: Device,
+    enable_restoreformer: bool,
+    restoreformer_device: Device,
+    no_half: bool,
+):
+    if enable_interactive_seg:
logger.info(f"Initialize {InteractiveSeg.name} plugin")
plugins[InteractiveSeg.name] = InteractiveSeg(
args.interactive_seg_model, args.interactive_seg_device
global_config.plugins[InteractiveSeg.name] = InteractiveSeg(
interactive_seg_model, interactive_seg_device
)
-    if args.enable_remove_bg:
+    if enable_remove_bg:
        logger.info(f"Initialize {RemoveBG.name} plugin")
-        plugins[RemoveBG.name] = RemoveBG()
+        global_config.plugins[RemoveBG.name] = RemoveBG()
-    if args.enable_anime_seg:
+    if enable_anime_seg:
        logger.info(f"Initialize {AnimeSeg.name} plugin")
-        plugins[AnimeSeg.name] = AnimeSeg()
+        global_config.plugins[AnimeSeg.name] = AnimeSeg()
-    if args.enable_realesrgan:
+    if enable_realesrgan:
        logger.info(
-            f"Initialize {RealESRGANUpscaler.name} plugin: {args.realesrgan_model}, {args.realesrgan_device}"
+            f"Initialize {RealESRGANUpscaler.name} plugin: {realesrgan_model}, {realesrgan_device}"
        )
-        plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
-            args.realesrgan_model,
-            args.realesrgan_device,
-            no_half=args.realesrgan_no_half,
+        global_config.plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
+            realesrgan_model,
+            realesrgan_device,
+            no_half=no_half,
        )
-    if args.enable_gfpgan:
+    if enable_gfpgan:
        logger.info(f"Initialize {GFPGANPlugin.name} plugin")
-        if args.enable_realesrgan:
+        if enable_realesrgan:
            logger.info("Use realesrgan as GFPGAN background upscaler")
        else:
            logger.info(
                f"GFPGAN no background upscaler, use --enable-realesrgan to enable it"
            )
-        plugins[GFPGANPlugin.name] = GFPGANPlugin(
-            args.gfpgan_device, upscaler=plugins.get(RealESRGANUpscaler.name, None)
+        global_config.plugins[GFPGANPlugin.name] = GFPGANPlugin(
+            gfpgan_device,
+            upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None),
        )
-    if args.enable_restoreformer:
+    if enable_restoreformer:
        logger.info(f"Initialize {RestoreFormerPlugin.name} plugin")
-        plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
-            args.restoreformer_device,
-            upscaler=plugins.get(RealESRGANUpscaler.name, None),
+        global_config.plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
+            restoreformer_device,
+            upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None),
        )
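
build_plugins no longer receives the argparse namespace; every dependency is an explicit keyword parameter, so call sites are self-documenting and the `start` command below can pass its options straight through. A hypothetical call of the new signature (the enums are assumed to arrive via the `from lama_cleaner.const import *` above):

    build_plugins(
        enable_interactive_seg=True,
        interactive_seg_model=InteractiveSegModel.vit_b,
        interactive_seg_device=Device.cpu,
        enable_remove_bg=False,
        enable_anime_seg=False,
        enable_realesrgan=False,
        realesrgan_device=Device.cpu,
        realesrgan_model=RealESRGANModel.realesr_general_x4v3,
        enable_gfpgan=False,
        gfpgan_device=Device.cpu,
        enable_restoreformer=False,
        restoreformer_device=Device.cpu,
        no_half=False,
    )

Note that a single no_half flag is now shared by all plugins, where the old code read args.realesrgan_no_half.
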
-def main(args):
-    global model
-    global device
-    global input_image_path
-    global is_disable_model_switch
-    global enable_file_manager
-    global is_desktop
-    global thumb
-    global output_dir
-    global image_quality
-    global enable_auto_saving
-    build_plugins(args)
+@typer_app.command(help="Install all plugins dependencies")
+def install_plugins_packages():
+    from lama_cleaner.installer import install_plugins_package
-    image_quality = args.quality
-    output_dir = args.output_dir
+    install_plugins_package()
@typer_app.command(help="Download SD/SDXL normal/inpainting model from HuggingFace")
def download(
model: str = Option(
..., help="Model id on HuggingFace e.g: runwayml/stable-diffusion-inpainting"
),
model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False),
):
cli_download_model(model, model_dir)
@typer_app.command(help="List downloaded models")
def list_model(
model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False),
):
setup_model_dir(model_dir)
scanned_models = scan_models()
for it in scanned_models:
print(it.name)
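
Because these are ordinary Typer commands, they can be exercised in-process with Typer's test runner. A sketch (the module path is an assumption; Typer renders the function name list_model as the subcommand list-model):

    from typer.testing import CliRunner
    from lama_cleaner.server import typer_app  # assumed module path

    result = CliRunner().invoke(typer_app, ["list-model"])
    print(result.output)
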
@typer_app.command(help="Start lama cleaner server")
def start(
host: str = Option("127.0.0.1"),
port: int = Option(8080),
model: str = Option(
DEFAULT_MODEL,
help=f"Available models: [{', '.join(AVAILABLE_MODELS)}]. "
f"You can use download command to download other SD/SDXL normal/inpainting models on huggingface",
),
model_dir: Path = Option(
DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, dir_okay=True, file_okay=False
),
no_half: bool = Option(False, help=NO_HALF_HELP),
cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP),
disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP),
cpu_textencoder: bool = Option(False, help=CPU_TEXTENCODER_HELP),
local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
device: Device = Option(Device.cpu),
gui: bool = Option(False, help=GUI_HELP),
disable_model_switch: bool = Option(False),
input: Path = Option(None, help=INPUT_HELP),
output_dir: Path = Option(
None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
),
quality: int = Option(95, help=QUALITY_HELP),
enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP),
interactive_seg_model: InteractiveSegModel = Option(
InteractiveSegModel.vit_b, help=INTERACTIVE_SEG_MODEL_HELP
),
interactive_seg_device: Device = Option(Device.cpu),
enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
enable_realesrgan: bool = Option(False),
realesrgan_device: Device = Option(Device.cpu),
realesrgan_model: str = Option(RealESRGANModel.realesr_general_x4v3),
enable_gfpgan: bool = Option(False),
gfpgan_device: Device = Option(Device.cpu),
enable_restoreformer: bool = Option(False),
restoreformer_device: Device = Option(Device.cpu),
):
global global_config
dump_environment_info()
+    if input:
+        if not input.exists():
+            logger.error(f"invalid --input: {input} not exists")
+            exit()
+        if input.is_dir():
+            logger.info(f"Initialize file manager")
+            file_manager = FileManager(app)
+            app.config["THUMBNAIL_MEDIA_ROOT"] = input
+            app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
+                output_dir, "lama_cleaner_thumbnails"
+            )
+            file_manager.output_dir = output_dir
+        else:
+            global_config.input_image_path = input
+    device = check_device(device)
+    setup_model_dir(model_dir)
+    if local_files_only:
+        os.environ["TRANSFORMERS_OFFLINE"] = "1"
+        os.environ["HF_HUB_OFFLINE"] = "1"
+    if model not in AVAILABLE_MODELS:
+        scanned_models = scan_models()
+        if model not in [it.name for it in scanned_models]:
+            logger.error(
+                f"invalid --model: {model} not exists. Available models: {AVAILABLE_MODELS} or {[it.name for it in scanned_models]}"
+            )
+            exit()
+    global_config.image_quality = quality
+    global_config.disable_model_switch = disable_model_switch
+    global_config.is_desktop = gui
+    build_plugins(
+        enable_interactive_seg,
+        interactive_seg_model,
+        interactive_seg_device,
+        enable_remove_bg,
+        enable_anime_seg,
+        enable_realesrgan,
+        realesrgan_device,
+        realesrgan_model,
+        enable_gfpgan,
+        gfpgan_device,
+        enable_restoreformer,
+        restoreformer_device,
+        no_half,
+    )
    if output_dir:
-        output_dir = os.path.abspath(output_dir)
-        logger.info(f"Output dir: {output_dir}")
-        enable_auto_saving = True
+        output_dir = output_dir.expanduser().absolute()
+        logger.info(f"Image will auto save to output dir: {output_dir}")
+        global_config.output_dir = output_dir
-    device = torch.device(args.device)
-    is_disable_model_switch = args.disable_model_switch
-    is_desktop = args.gui
-    if is_disable_model_switch:
-        logger.info(
-            f"Start with --disable-model-switch, model switch on frontend is disable"
-        )
-    if args.input and os.path.isdir(args.input):
-        logger.info(f"Initialize file manager")
-        thumb = FileManager(app)
-        enable_file_manager = True
-        app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
-        app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
-            output_dir, "lama_cleaner_thumbnails"
-        )
-        thumb.output_dir = Path(output_dir)
-        # thumb.start()
-        # try:
-        #     while True:
-        #         time.sleep(1)
-        # finally:
-        #     thumb.image_dir_observer.stop()
-        #     thumb.image_dir_observer.join()
-        #     thumb.output_dir_observer.stop()
-        #     thumb.output_dir_observer.join()
-    else:
-        input_image_path = args.input
-    # For backward compatibility
-    model_name_map = {
-        "sd1.5": "runwayml/stable-diffusion-inpainting",
-        "anything4": "Sanster/anything-4.0-inpainting",
-        "realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting",
-        "sd2": "stabilityai/stable-diffusion-2-inpainting",
-        "sdxl": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
-        "kandinsky2.2": "kandinsky-community/kandinsky-2-2-decoder-inpaint",
-        "paint_by_example": "Fantasy-Studio/Paint-by-Example",
-        "instruct_pix2pix": "timbrooks/instruct-pix2pix",
-    }
-    model = ModelManager(
-        name=model_name_map.get(args.model, args.model),
-        sd_controlnet=args.sd_controlnet,
-        sd_controlnet_method=args.sd_controlnet_method,
-        device=device,
-        no_half=args.no_half,
-        hf_access_token=args.hf_access_token,
-        disable_nsfw=args.sd_disable_nsfw or args.disable_nsfw,
-        sd_cpu_textencoder=args.sd_cpu_textencoder,
-        cpu_offload=args.cpu_offload,
-        enable_xformers=args.sd_enable_xformers or args.enable_xformers,
+    global_config.model_manager = ModelManager(
+        name=model,
+        device=torch.device(device),
+        no_half=no_half,
+        disable_nsfw=disable_nsfw_checker,
+        sd_cpu_textencoder=cpu_textencoder,
+        cpu_offload=cpu_offload,
        callback=diffuser_callback,
    )
-    if args.gui:
-        app_width, app_height = args.gui_size
+    if gui:
        from flaskwebgui import FlaskUI
        ui = FlaskUI(
            app,
            socketio=socketio,
-            width=app_width,
-            height=app_height,
-            host=args.host,
-            port=args.port,
-            close_server_on_exit=not args.no_gui_auto_close,
+            width=1200,
+            height=800,
+            host=host,
+            port=port,
+            close_server_on_exit=True,
+            idle_interval=60,
        )
        ui.run()
    else:
        socketio.run(
            app,
-            host=args.host,
-            port=args.port,
-            debug=args.debug,
+            host=host,
+            port=port,
            allow_unsafe_werkzeug=True,
        )
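
Note: with `debug=args.debug` gone, `allow_unsafe_werkzeug=True` is what recent Flask-SocketIO releases require before they will serve through the bundled Werkzeug development server outside debug mode, so the non-GUI path keeps working under the new CLI.
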