update
This commit is contained in:
@@ -1,10 +1,18 @@
|
||||
#!/usr/bin/env python3
|
||||
import json
|
||||
import os
|
||||
import hashlib
|
||||
import traceback
|
||||
|
||||
import typer
|
||||
from typer import Option
|
||||
|
||||
from lama_cleaner.download import cli_download_model, scan_models
|
||||
from lama_cleaner.runtime import setup_model_dir, dump_environment_info, check_device
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
import hashlib
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
import imghdr
|
||||
import io
|
||||
@@ -20,12 +28,7 @@ import torch
|
||||
from PIL import Image
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.const import (
|
||||
SD15_MODELS,
|
||||
SD_CONTROLNET_CHOICES,
|
||||
SDXL_CONTROLNET_CHOICES,
|
||||
SD2_CONTROLNET_CHOICES,
|
||||
)
|
||||
from lama_cleaner.const import *
|
||||
from lama_cleaner.file_manager import FileManager
|
||||
from lama_cleaner.model.utils import torch_gc
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
@@ -39,6 +42,8 @@ from lama_cleaner.plugins import (
|
||||
)
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
|
||||
|
||||
try:
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||
@@ -103,23 +108,34 @@ logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
|
||||
|
||||
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
|
||||
app.config["JSON_AS_ASCII"] = False
|
||||
CORS(app, expose_headers=["Content-Disposition", "X-seed"])
|
||||
CORS(app, expose_headers=["Content-Disposition", "X-seed", "X-Height", "X-Width"])
|
||||
|
||||
sio_logger = logging.getLogger("sio-logger")
|
||||
sio_logger.setLevel(logging.ERROR)
|
||||
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading")
|
||||
|
||||
model: ModelManager = None
|
||||
thumb: FileManager = None
|
||||
output_dir: str = None
|
||||
device = None
|
||||
input_image_path: str = None
|
||||
is_disable_model_switch: bool = False
|
||||
enable_file_manager: bool = False
|
||||
enable_auto_saving: bool = False
|
||||
is_desktop: bool = False
|
||||
image_quality: int = 95
|
||||
plugins = {}
|
||||
|
||||
@dataclass
|
||||
class GlobalConfig:
|
||||
model_manager: ModelManager = None
|
||||
file_manager: FileManager = None
|
||||
output_dir: Path = None
|
||||
input_image_path: Path = None
|
||||
disable_model_switch: bool = False
|
||||
is_desktop: bool = False
|
||||
image_quality: int = 95
|
||||
plugins = {}
|
||||
|
||||
@property
|
||||
def enable_auto_saving(self) -> bool:
|
||||
return self.output_dir is not None
|
||||
|
||||
@property
|
||||
def enable_file_manager(self) -> bool:
|
||||
return self.file_manager is not None
|
||||
|
||||
|
||||
global_config = GlobalConfig()
|
||||
|
||||
|
||||
def get_image_ext(img_bytes):
|
||||
@@ -135,7 +151,7 @@ def diffuser_callback(i, t, latents):
|
||||
|
||||
@app.route("/save_image", methods=["POST"])
|
||||
def save_image():
|
||||
if output_dir is None:
|
||||
if global_config.output_dir is None:
|
||||
return "--output-dir is None", 500
|
||||
|
||||
input = request.files
|
||||
@@ -143,7 +159,7 @@ def save_image():
|
||||
origin_image_bytes = input["image"].read() # RGB
|
||||
ext = get_image_ext(origin_image_bytes)
|
||||
image, alpha_channel, exif_infos = load_img(origin_image_bytes, return_exif=True)
|
||||
save_path = os.path.join(output_dir, filename)
|
||||
save_path = str(global_config.output_dir / filename)
|
||||
|
||||
if alpha_channel is not None:
|
||||
if alpha_channel.shape[:2] != image.shape[:2]:
|
||||
@@ -157,7 +173,7 @@ def save_image():
|
||||
img_bytes = pil_to_bytes(
|
||||
pil_image,
|
||||
ext,
|
||||
quality=image_quality,
|
||||
quality=global_config.image_quality,
|
||||
exif_infos=exif_infos,
|
||||
)
|
||||
with open(save_path, "wb") as fw:
|
||||
@@ -169,9 +185,11 @@ def save_image():
|
||||
@app.route("/medias/<tab>")
|
||||
def medias(tab):
|
||||
if tab == "image":
|
||||
response = make_response(jsonify(thumb.media_names), 200)
|
||||
response = make_response(jsonify(global_config.file_manager.media_names), 200)
|
||||
else:
|
||||
response = make_response(jsonify(thumb.output_media_names), 200)
|
||||
response = make_response(
|
||||
jsonify(global_config.file_manager.output_media_names), 200
|
||||
)
|
||||
# response.last_modified = thumb.modified_time[tab]
|
||||
# response.cache_control.no_cache = True
|
||||
# response.cache_control.max_age = 0
|
||||
@@ -182,8 +200,8 @@ def medias(tab):
|
||||
@app.route("/media/<tab>/<filename>")
|
||||
def media_file(tab, filename):
|
||||
if tab == "image":
|
||||
return send_from_directory(thumb.root_directory, filename)
|
||||
return send_from_directory(thumb.output_dir, filename)
|
||||
return send_from_directory(global_config.file_manager.root_directory, filename)
|
||||
return send_from_directory(global_config.file_manager.output_dir, filename)
|
||||
|
||||
|
||||
@app.route("/media_thumbnail/<tab>/<filename>")
|
||||
@@ -198,10 +216,10 @@ def media_thumbnail_file(tab, filename):
|
||||
if height:
|
||||
height = int(float(height))
|
||||
|
||||
directory = thumb.root_directory
|
||||
directory = global_config.file_manager.root_directory
|
||||
if tab == "output":
|
||||
directory = thumb.output_dir
|
||||
thumb_filename, (width, height) = thumb.get_thumbnail(
|
||||
directory = global_config.file_manager.output_dir
|
||||
thumb_filename, (width, height) = global_config.file_manager.get_thumbnail(
|
||||
directory, filename, width, height
|
||||
)
|
||||
thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}"
|
||||
@@ -257,13 +275,11 @@ def process():
|
||||
croper_y=form["croperY"],
|
||||
croper_height=form["croperHeight"],
|
||||
croper_width=form["croperWidth"],
|
||||
|
||||
use_extender=form["useExtender"],
|
||||
extender_x=form["extenderX"],
|
||||
extender_y=form["extenderY"],
|
||||
extender_height=form["extenderHeight"],
|
||||
extender_width=form["extenderWidth"],
|
||||
|
||||
sd_scale=form["sdScale"],
|
||||
sd_mask_blur=form["sdMaskBlur"],
|
||||
sd_strength=form["sdStrength"],
|
||||
@@ -294,7 +310,7 @@ def process():
|
||||
|
||||
start = time.time()
|
||||
try:
|
||||
res_np_img = model(image, mask, config)
|
||||
res_np_img = global_config.model_manager(image, mask, config)
|
||||
except RuntimeError as e:
|
||||
if "CUDA out of memory. " in str(e):
|
||||
# NOTE: the string may change?
|
||||
@@ -322,7 +338,7 @@ def process():
|
||||
pil_to_bytes(
|
||||
Image.fromarray(res_np_img),
|
||||
ext,
|
||||
quality=image_quality,
|
||||
quality=global_config.image_quality,
|
||||
exif_infos=exif_infos,
|
||||
)
|
||||
)
|
||||
@@ -345,7 +361,7 @@ def run_plugin():
|
||||
form = request.form
|
||||
files = request.files
|
||||
name = form["name"]
|
||||
if name not in plugins:
|
||||
if name not in global_config.plugins:
|
||||
return "Plugin not found", 500
|
||||
|
||||
origin_image_bytes = files["image"].read() # RGB
|
||||
@@ -359,7 +375,7 @@ def run_plugin():
|
||||
if name == InteractiveSeg.name:
|
||||
img_md5 = hashlib.md5(origin_image_bytes).hexdigest()
|
||||
form["img_md5"] = img_md5
|
||||
bgr_res = plugins[name](rgb_np_img, files, form)
|
||||
bgr_res = global_config.plugins[name](rgb_np_img, files, form)
|
||||
except RuntimeError as e:
|
||||
torch.cuda.empty_cache()
|
||||
if "CUDA out of memory. " in str(e):
|
||||
@@ -401,7 +417,7 @@ def run_plugin():
|
||||
pil_to_bytes(
|
||||
Image.fromarray(rgb_res),
|
||||
ext,
|
||||
quality=image_quality,
|
||||
quality=global_config.image_quality,
|
||||
exif_infos=exif_infos,
|
||||
)
|
||||
),
|
||||
@@ -414,41 +430,40 @@ def run_plugin():
|
||||
@app.route("/server_config", methods=["GET"])
|
||||
def get_server_config():
|
||||
return {
|
||||
"plugins": list(plugins.keys()),
|
||||
"enableFileManager": enable_file_manager,
|
||||
"enableAutoSaving": enable_auto_saving,
|
||||
"enableControlnet": model.sd_controlnet,
|
||||
"controlnetMethod": model.sd_controlnet_method,
|
||||
"disableModelSwitch": is_disable_model_switch,
|
||||
"plugins": list(global_config.plugins.keys()),
|
||||
"enableFileManager": global_config.enable_file_manager,
|
||||
"enableAutoSaving": global_config.enable_auto_saving,
|
||||
"enableControlnet": global_config.model_manager.sd_controlnet,
|
||||
"controlnetMethod": global_config.model_manager.sd_controlnet_method,
|
||||
"disableModelSwitch": global_config.disable_model_switch,
|
||||
"isDesktop": global_config.is_desktop,
|
||||
}, 200
|
||||
|
||||
|
||||
@app.route("/models", methods=["GET"])
|
||||
def get_models():
|
||||
return [it.model_dump() for it in model.scan_models()]
|
||||
return [it.model_dump() for it in global_config.model_manager.scan_models()]
|
||||
|
||||
|
||||
@app.route("/model")
|
||||
def current_model():
|
||||
return model.available_models[model.name].model_dump(), 200
|
||||
|
||||
|
||||
@app.route("/is_desktop")
|
||||
def get_is_desktop():
|
||||
return str(is_desktop), 200
|
||||
return (
|
||||
global_config.model_manager.current_model,
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@app.route("/model", methods=["POST"])
|
||||
def switch_model():
|
||||
if is_disable_model_switch:
|
||||
if global_config.disable_model_switch:
|
||||
return "Switch model is disabled", 400
|
||||
|
||||
new_name = request.form.get("name")
|
||||
if new_name == model.name:
|
||||
if new_name == global_config.model_manager.name:
|
||||
return "Same model", 200
|
||||
|
||||
try:
|
||||
model.switch(new_name)
|
||||
global_config.model_manager.switch(new_name)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
error_message = f"{type(e).__name__} - {str(e)}"
|
||||
@@ -464,160 +479,230 @@ def index():
|
||||
|
||||
@app.route("/inputimage")
|
||||
def get_cli_input_image():
|
||||
if input_image_path:
|
||||
with open(input_image_path, "rb") as f:
|
||||
if global_config.input_image_path:
|
||||
with open(global_config.input_image_path, "rb") as f:
|
||||
image_in_bytes = f.read()
|
||||
return send_file(
|
||||
input_image_path,
|
||||
global_config.input_image_path,
|
||||
as_attachment=True,
|
||||
download_name=Path(input_image_path).name,
|
||||
download_name=Path(global_config.input_image_path).name,
|
||||
mimetype=f"image/{get_image_ext(image_in_bytes)}",
|
||||
)
|
||||
else:
|
||||
return "No Input Image"
|
||||
|
||||
|
||||
def build_plugins(args):
|
||||
global plugins
|
||||
if args.enable_interactive_seg:
|
||||
def build_plugins(
|
||||
enable_interactive_seg: bool,
|
||||
interactive_seg_model: InteractiveSegModel,
|
||||
interactive_seg_device: Device,
|
||||
enable_remove_bg: bool,
|
||||
enable_anime_seg: bool,
|
||||
enable_realesrgan: bool,
|
||||
realesrgan_device: Device,
|
||||
realesrgan_model: str,
|
||||
enable_gfpgan: bool,
|
||||
gfpgan_device: Device,
|
||||
enable_restoreformer: bool,
|
||||
restoreformer_device: Device,
|
||||
no_half: bool,
|
||||
):
|
||||
if enable_interactive_seg:
|
||||
logger.info(f"Initialize {InteractiveSeg.name} plugin")
|
||||
plugins[InteractiveSeg.name] = InteractiveSeg(
|
||||
args.interactive_seg_model, args.interactive_seg_device
|
||||
global_config.plugins[InteractiveSeg.name] = InteractiveSeg(
|
||||
interactive_seg_model, interactive_seg_device
|
||||
)
|
||||
|
||||
if args.enable_remove_bg:
|
||||
if enable_remove_bg:
|
||||
logger.info(f"Initialize {RemoveBG.name} plugin")
|
||||
plugins[RemoveBG.name] = RemoveBG()
|
||||
global_config.plugins[RemoveBG.name] = RemoveBG()
|
||||
|
||||
if args.enable_anime_seg:
|
||||
if enable_anime_seg:
|
||||
logger.info(f"Initialize {AnimeSeg.name} plugin")
|
||||
plugins[AnimeSeg.name] = AnimeSeg()
|
||||
global_config.plugins[AnimeSeg.name] = AnimeSeg()
|
||||
|
||||
if args.enable_realesrgan:
|
||||
if enable_realesrgan:
|
||||
logger.info(
|
||||
f"Initialize {RealESRGANUpscaler.name} plugin: {args.realesrgan_model}, {args.realesrgan_device}"
|
||||
f"Initialize {RealESRGANUpscaler.name} plugin: {realesrgan_model}, {realesrgan_device}"
|
||||
)
|
||||
plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
|
||||
args.realesrgan_model,
|
||||
args.realesrgan_device,
|
||||
no_half=args.realesrgan_no_half,
|
||||
global_config.plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
|
||||
realesrgan_model,
|
||||
realesrgan_device,
|
||||
no_half=no_half,
|
||||
)
|
||||
|
||||
if args.enable_gfpgan:
|
||||
if enable_gfpgan:
|
||||
logger.info(f"Initialize {GFPGANPlugin.name} plugin")
|
||||
if args.enable_realesrgan:
|
||||
if enable_realesrgan:
|
||||
logger.info("Use realesrgan as GFPGAN background upscaler")
|
||||
else:
|
||||
logger.info(
|
||||
f"GFPGAN no background upscaler, use --enable-realesrgan to enable it"
|
||||
)
|
||||
plugins[GFPGANPlugin.name] = GFPGANPlugin(
|
||||
args.gfpgan_device, upscaler=plugins.get(RealESRGANUpscaler.name, None)
|
||||
global_config.plugins[GFPGANPlugin.name] = GFPGANPlugin(
|
||||
gfpgan_device,
|
||||
upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None),
|
||||
)
|
||||
|
||||
if args.enable_restoreformer:
|
||||
if enable_restoreformer:
|
||||
logger.info(f"Initialize {RestoreFormerPlugin.name} plugin")
|
||||
plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
|
||||
args.restoreformer_device,
|
||||
upscaler=plugins.get(RealESRGANUpscaler.name, None),
|
||||
global_config.plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
|
||||
restoreformer_device,
|
||||
upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None),
|
||||
)
|
||||
|
||||
def main(args):
|
||||
global model
|
||||
global device
|
||||
global input_image_path
|
||||
global is_disable_model_switch
|
||||
global enable_file_manager
|
||||
global is_desktop
|
||||
global thumb
|
||||
global output_dir
|
||||
global image_quality
|
||||
global enable_auto_saving
|
||||
|
||||
build_plugins(args)
|
||||
@typer_app.command(help="Install all plugins dependencies")
|
||||
def install_plugins_packages():
|
||||
from lama_cleaner.installer import install_plugins_package
|
||||
|
||||
image_quality = args.quality
|
||||
output_dir = args.output_dir
|
||||
install_plugins_package()
|
||||
|
||||
|
||||
@typer_app.command(help="Download SD/SDXL normal/inpainting model from HuggingFace")
|
||||
def download(
|
||||
model: str = Option(
|
||||
..., help="Model id on HuggingFace e.g: runwayml/stable-diffusion-inpainting"
|
||||
),
|
||||
model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False),
|
||||
):
|
||||
cli_download_model(model, model_dir)
|
||||
|
||||
|
||||
@typer_app.command(help="List downloaded models")
|
||||
def list_model(
|
||||
model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False),
|
||||
):
|
||||
setup_model_dir(model_dir)
|
||||
scanned_models = scan_models()
|
||||
for it in scanned_models:
|
||||
print(it.name)
|
||||
|
||||
|
||||
@typer_app.command(help="Start lama cleaner server")
|
||||
def start(
|
||||
host: str = Option("127.0.0.1"),
|
||||
port: int = Option(8080),
|
||||
model: str = Option(
|
||||
DEFAULT_MODEL,
|
||||
help=f"Available models: [{', '.join(AVAILABLE_MODELS)}]. "
|
||||
f"You can use download command to download other SD/SDXL normal/inpainting models on huggingface",
|
||||
),
|
||||
model_dir: Path = Option(
|
||||
DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, dir_okay=True, file_okay=False
|
||||
),
|
||||
no_half: bool = Option(False, help=NO_HALF_HELP),
|
||||
cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP),
|
||||
disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP),
|
||||
cpu_textencoder: bool = Option(False, help=CPU_TEXTENCODER_HELP),
|
||||
local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
|
||||
device: Device = Option(Device.cpu),
|
||||
gui: bool = Option(False, help=GUI_HELP),
|
||||
disable_model_switch: bool = Option(False),
|
||||
input: Path = Option(None, help=INPUT_HELP),
|
||||
output_dir: Path = Option(
|
||||
None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
|
||||
),
|
||||
quality: int = Option(95, help=QUALITY_HELP),
|
||||
enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP),
|
||||
interactive_seg_model: InteractiveSegModel = Option(
|
||||
InteractiveSegModel.vit_b, help=INTERACTIVE_SEG_MODEL_HELP
|
||||
),
|
||||
interactive_seg_device: Device = Option(Device.cpu),
|
||||
enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
|
||||
enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
|
||||
enable_realesrgan: bool = Option(False),
|
||||
realesrgan_device: Device = Option(Device.cpu),
|
||||
realesrgan_model: str = Option(RealESRGANModel.realesr_general_x4v3),
|
||||
enable_gfpgan: bool = Option(False),
|
||||
gfpgan_device: Device = Option(Device.cpu),
|
||||
enable_restoreformer: bool = Option(False),
|
||||
restoreformer_device: Device = Option(Device.cpu),
|
||||
):
|
||||
global global_config
|
||||
dump_environment_info()
|
||||
|
||||
if input:
|
||||
if not input.exists():
|
||||
logger.error(f"invalid --input: {input} not exists")
|
||||
exit()
|
||||
if input.is_dir():
|
||||
logger.info(f"Initialize file manager")
|
||||
file_manager = FileManager(app)
|
||||
app.config["THUMBNAIL_MEDIA_ROOT"] = input
|
||||
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
|
||||
output_dir, "lama_cleaner_thumbnails"
|
||||
)
|
||||
file_manager.output_dir = output_dir
|
||||
else:
|
||||
global_config.input_image_path = input
|
||||
|
||||
device = check_device(device)
|
||||
setup_model_dir(model_dir)
|
||||
|
||||
if local_files_only:
|
||||
os.environ["TRANSFORMERS_OFFLINE"] = "1"
|
||||
os.environ["HF_HUB_OFFLINE"] = "1"
|
||||
|
||||
if model not in AVAILABLE_MODELS:
|
||||
scanned_models = scan_models()
|
||||
if model not in [it.name for it in scanned_models]:
|
||||
logger.error(
|
||||
f"invalid --model: {model} not exists. Available models: {AVAILABLE_MODELS} or {[it.name for it in scanned_models]}"
|
||||
)
|
||||
exit()
|
||||
|
||||
global_config.image_quality = quality
|
||||
global_config.disable_model_switch = disable_model_switch
|
||||
global_config.is_desktop = gui
|
||||
build_plugins(
|
||||
enable_interactive_seg,
|
||||
interactive_seg_model,
|
||||
interactive_seg_device,
|
||||
enable_remove_bg,
|
||||
enable_anime_seg,
|
||||
enable_realesrgan,
|
||||
realesrgan_device,
|
||||
realesrgan_model,
|
||||
enable_gfpgan,
|
||||
gfpgan_device,
|
||||
enable_restoreformer,
|
||||
restoreformer_device,
|
||||
no_half,
|
||||
)
|
||||
if output_dir:
|
||||
output_dir = os.path.abspath(output_dir)
|
||||
logger.info(f"Output dir: {output_dir}")
|
||||
enable_auto_saving = True
|
||||
output_dir = output_dir.expanduser().absolute()
|
||||
logger.info(f"Image will auto save to output dir: {output_dir}")
|
||||
global_config.output_dir = output_dir
|
||||
|
||||
device = torch.device(args.device)
|
||||
is_disable_model_switch = args.disable_model_switch
|
||||
is_desktop = args.gui
|
||||
if is_disable_model_switch:
|
||||
logger.info(
|
||||
f"Start with --disable-model-switch, model switch on frontend is disable"
|
||||
)
|
||||
|
||||
if args.input and os.path.isdir(args.input):
|
||||
logger.info(f"Initialize file manager")
|
||||
thumb = FileManager(app)
|
||||
enable_file_manager = True
|
||||
app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
|
||||
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
|
||||
output_dir, "lama_cleaner_thumbnails"
|
||||
)
|
||||
thumb.output_dir = Path(output_dir)
|
||||
# thumb.start()
|
||||
# try:
|
||||
# while True:
|
||||
# time.sleep(1)
|
||||
# finally:
|
||||
# thumb.image_dir_observer.stop()
|
||||
# thumb.image_dir_observer.join()
|
||||
# thumb.output_dir_observer.stop()
|
||||
# thumb.output_dir_observer.join()
|
||||
|
||||
else:
|
||||
input_image_path = args.input
|
||||
|
||||
# 为了兼容性
|
||||
model_name_map = {
|
||||
"sd1.5": "runwayml/stable-diffusion-inpainting",
|
||||
"anything4": "Sanster/anything-4.0-inpainting",
|
||||
"realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting",
|
||||
"sd2": "stabilityai/stable-diffusion-2-inpainting",
|
||||
"sdxl": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
||||
"kandinsky2.2": "kandinsky-community/kandinsky-2-2-decoder-inpaint",
|
||||
"paint_by_example": "Fantasy-Studio/Paint-by-Example",
|
||||
"instruct_pix2pix": "timbrooks/instruct-pix2pix",
|
||||
}
|
||||
|
||||
model = ModelManager(
|
||||
name=model_name_map.get(args.model, args.model),
|
||||
sd_controlnet=args.sd_controlnet,
|
||||
sd_controlnet_method=args.sd_controlnet_method,
|
||||
device=device,
|
||||
no_half=args.no_half,
|
||||
hf_access_token=args.hf_access_token,
|
||||
disable_nsfw=args.sd_disable_nsfw or args.disable_nsfw,
|
||||
sd_cpu_textencoder=args.sd_cpu_textencoder,
|
||||
cpu_offload=args.cpu_offload,
|
||||
enable_xformers=args.sd_enable_xformers or args.enable_xformers,
|
||||
global_config.model_manager = ModelManager(
|
||||
name=model,
|
||||
device=torch.device(device),
|
||||
no_half=no_half,
|
||||
disable_nsfw=disable_nsfw_checker,
|
||||
sd_cpu_textencoder=cpu_textencoder,
|
||||
cpu_offload=cpu_offload,
|
||||
callback=diffuser_callback,
|
||||
)
|
||||
|
||||
if args.gui:
|
||||
app_width, app_height = args.gui_size
|
||||
if gui:
|
||||
from flaskwebgui import FlaskUI
|
||||
|
||||
ui = FlaskUI(
|
||||
app,
|
||||
socketio=socketio,
|
||||
width=app_width,
|
||||
height=app_height,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
close_server_on_exit=not args.no_gui_auto_close,
|
||||
width=1200,
|
||||
height=800,
|
||||
host=host,
|
||||
port=port,
|
||||
close_server_on_exit=True,
|
||||
idle_interval=60,
|
||||
)
|
||||
ui.run()
|
||||
else:
|
||||
socketio.run(
|
||||
app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
debug=args.debug,
|
||||
host=host,
|
||||
port=port,
|
||||
allow_unsafe_werkzeug=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user