add mask tab
This commit is contained in:
@@ -19,7 +19,6 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
import uvicorn
|
||||
from PIL import Image
|
||||
from fastapi import APIRouter, FastAPI, Request, UploadFile
|
||||
@@ -127,7 +126,7 @@ def api_middleware(app: FastAPI):
|
||||
"allow_headers": ["*"],
|
||||
"allow_origins": ["*"],
|
||||
"allow_credentials": True,
|
||||
"expose_headers": ["X-Seed"]
|
||||
"expose_headers": ["X-Seed"],
|
||||
}
|
||||
app.add_middleware(CORSMiddleware, **cors_options)
|
||||
|
||||
@@ -159,7 +158,8 @@ class Api:
|
||||
|
||||
# fmt: off
|
||||
self.add_api_route("/api/v1/gen-info", self.api_geninfo, methods=["POST"], response_model=GenInfoResponse)
|
||||
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"], response_model=ServerConfigResponse)
|
||||
self.add_api_route("/api/v1/server-config", self.api_server_config, methods=["GET"],
|
||||
response_model=ServerConfigResponse)
|
||||
self.add_api_route("/api/v1/model", self.api_current_model, methods=["GET"], response_model=ModelInfo)
|
||||
self.add_api_route("/api/v1/model", self.api_switch_model, methods=["POST"], response_model=ModelInfo)
|
||||
self.add_api_route("/api/v1/inputimage", self.api_input_image, methods=["GET"])
|
||||
@@ -361,6 +361,7 @@ class Api:
|
||||
return FileManager(
|
||||
app=self.app,
|
||||
input_dir=self.config.input,
|
||||
mask_dir=self.config.mask_dir,
|
||||
output_dir=self.config.output_dir,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import webbrowser
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
from fastapi import FastAPI
|
||||
@@ -120,6 +120,9 @@ def start(
|
||||
local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
|
||||
device: Device = Option(Device.cpu),
|
||||
input: Optional[Path] = Option(None, help=INPUT_HELP),
|
||||
mask_dir: Optional[Path] = Option(
|
||||
None, help=MODEL_DIR_HELP, dir_okay=True, file_okay=False
|
||||
),
|
||||
output_dir: Optional[Path] = Option(
|
||||
None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
|
||||
),
|
||||
@@ -145,8 +148,11 @@ def start(
|
||||
if input and not input.exists():
|
||||
logger.error(f"invalid --input: {input} not exists")
|
||||
exit(-1)
|
||||
if mask_dir and not mask_dir.exists():
|
||||
logger.error(f"invalid --mask-dir: {mask_dir} not exists")
|
||||
exit(-1)
|
||||
if input and input.is_dir() and not output_dir:
|
||||
logger.error(f"invalid --output-dir: must be set when --input is a directory")
|
||||
logger.error("invalid --output-dir: --output-dir must be set when --input is a directory")
|
||||
exit(-1)
|
||||
if output_dir:
|
||||
output_dir = output_dir.expanduser().absolute()
|
||||
@@ -154,6 +160,8 @@ def start(
|
||||
if not output_dir.exists():
|
||||
logger.info(f"Create output directory {output_dir}")
|
||||
output_dir.mkdir(parents=True)
|
||||
if mask_dir:
|
||||
mask_dir = mask_dir.expanduser().absolute()
|
||||
|
||||
model_dir = model_dir.expanduser().absolute()
|
||||
|
||||
@@ -192,6 +200,7 @@ def start(
|
||||
cpu_textencoder=cpu_textencoder if device == Device.cuda else False,
|
||||
device=device,
|
||||
input=input,
|
||||
mask_dir=mask_dir,
|
||||
output_dir=output_dir,
|
||||
quality=quality,
|
||||
enable_interactive_seg=enable_interactive_seg,
|
||||
|
||||
@@ -63,7 +63,7 @@ SD_CONTROLNET_CHOICES: List[str] = [
|
||||
|
||||
SD_BRUSHNET_CHOICES: List[str] = [
|
||||
"Sanster/brushnet_random_mask",
|
||||
"Sanster/brushnet_segmentation_mask"
|
||||
"Sanster/brushnet_segmentation_mask",
|
||||
]
|
||||
|
||||
SD2_CONTROLNET_CHOICES = [
|
||||
@@ -99,6 +99,10 @@ OUTPUT_DIR_HELP = """
|
||||
Result images will be saved to output directory automatically.
|
||||
"""
|
||||
|
||||
MASK_DIR_HELP = """
|
||||
You can view masks in FileManager
|
||||
"""
|
||||
|
||||
INPUT_HELP = """
|
||||
If input is image, it will be loaded by default.
|
||||
If input is directory, you can browse and select image in file manager.
|
||||
|
||||
@@ -4,7 +4,7 @@ from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from PIL import Image, ImageOps, PngImagePlugin
|
||||
from fastapi import FastAPI, UploadFile, HTTPException
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from starlette.responses import FileResponse
|
||||
|
||||
from ..schema import MediasResponse, MediaTab
|
||||
@@ -16,9 +16,10 @@ from .utils import aspect_to_string, generate_filename, glob_img
|
||||
|
||||
|
||||
class FileManager:
|
||||
def __init__(self, app: FastAPI, input_dir: Path, output_dir: Path):
|
||||
def __init__(self, app: FastAPI, input_dir: Path, mask_dir: Path, output_dir: Path):
|
||||
self.app = app
|
||||
self.input_dir: Path = input_dir
|
||||
self.mask_dir: Path = mask_dir
|
||||
self.output_dir: Path = output_dir
|
||||
|
||||
self.image_dir_filenames = []
|
||||
@@ -63,6 +64,8 @@ class FileManager:
|
||||
return self.input_dir
|
||||
elif tab == "output":
|
||||
return self.output_dir
|
||||
elif tab == "mask":
|
||||
return self.mask_dir
|
||||
else:
|
||||
raise HTTPException(status_code=422, detail=f"tab not found: {tab}")
|
||||
|
||||
|
||||
@@ -244,6 +244,7 @@ class ApiConfig(BaseModel):
|
||||
cpu_textencoder: bool
|
||||
device: Device
|
||||
input: Optional[Path]
|
||||
mask_dir: Optional[Path]
|
||||
output_dir: Optional[Path]
|
||||
quality: int
|
||||
enable_interactive_seg: bool
|
||||
@@ -436,7 +437,7 @@ class RunPluginRequest(BaseModel):
|
||||
scale: float = Field(2.0, description="Scale for upscaling")
|
||||
|
||||
|
||||
MediaTab = Literal["input", "output"]
|
||||
MediaTab = Literal["input", "output", "mask"]
|
||||
|
||||
|
||||
class MediasResponse(BaseModel):
|
||||
|
||||
@@ -3,10 +3,11 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import mimetypes
|
||||
|
||||
# fix for windows mimetypes registry entries being borked
|
||||
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
||||
mimetypes.add_type('application/javascript', '.js')
|
||||
mimetypes.add_type('text/css', '.css')
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
|
||||
from iopaint.schema import (
|
||||
Device,
|
||||
@@ -78,40 +79,43 @@ def load_config(p: Path) -> WebConfig:
|
||||
|
||||
|
||||
def save_config(
|
||||
host,
|
||||
port,
|
||||
model,
|
||||
model_dir,
|
||||
no_half,
|
||||
low_mem,
|
||||
cpu_offload,
|
||||
disable_nsfw_checker,
|
||||
local_files_only,
|
||||
cpu_textencoder,
|
||||
device,
|
||||
input,
|
||||
output_dir,
|
||||
quality,
|
||||
enable_interactive_seg,
|
||||
interactive_seg_model,
|
||||
interactive_seg_device,
|
||||
enable_remove_bg,
|
||||
remove_bg_model,
|
||||
enable_anime_seg,
|
||||
enable_realesrgan,
|
||||
realesrgan_device,
|
||||
realesrgan_model,
|
||||
enable_gfpgan,
|
||||
gfpgan_device,
|
||||
enable_restoreformer,
|
||||
restoreformer_device,
|
||||
inbrowser,
|
||||
host,
|
||||
port,
|
||||
model,
|
||||
model_dir,
|
||||
no_half,
|
||||
low_mem,
|
||||
cpu_offload,
|
||||
disable_nsfw_checker,
|
||||
local_files_only,
|
||||
cpu_textencoder,
|
||||
device,
|
||||
input,
|
||||
mask_dir,
|
||||
output_dir,
|
||||
quality,
|
||||
enable_interactive_seg,
|
||||
interactive_seg_model,
|
||||
interactive_seg_device,
|
||||
enable_remove_bg,
|
||||
remove_bg_model,
|
||||
enable_anime_seg,
|
||||
enable_realesrgan,
|
||||
realesrgan_device,
|
||||
realesrgan_model,
|
||||
enable_gfpgan,
|
||||
gfpgan_device,
|
||||
enable_restoreformer,
|
||||
restoreformer_device,
|
||||
inbrowser,
|
||||
):
|
||||
config = WebConfig(**locals())
|
||||
if str(config.input) == ".":
|
||||
config.input = None
|
||||
if str(config.output_dir) == ".":
|
||||
config.output_dir = None
|
||||
if str(config.mask_dir) == ".":
|
||||
config.mask_dir = None
|
||||
config.model = config.model.strip()
|
||||
print(config.model_dump_json(indent=4))
|
||||
if config.input and not os.path.exists(config.input):
|
||||
@@ -166,7 +170,7 @@ def main(config_file: Path):
|
||||
model = gr.Textbox(
|
||||
init_config.model,
|
||||
label="Current Model. Model will be automatically downloaded. "
|
||||
"You can select a model in Recommended Models or Downloaded Models or manually enter the SD/SDXL model ID from HuggingFace, for example, runwayml/stable-diffusion-inpainting.",
|
||||
"You can select a model in Recommended Models or Downloaded Models or manually enter the SD/SDXL model ID from HuggingFace, for example, runwayml/stable-diffusion-inpainting.",
|
||||
)
|
||||
|
||||
device = gr.Radio(
|
||||
@@ -207,6 +211,10 @@ def main(config_file: Path):
|
||||
init_config.output_dir,
|
||||
label=f"Output directory. {OUTPUT_DIR_HELP}",
|
||||
)
|
||||
mask_dir = gr.Textbox(
|
||||
init_config.mask_dir,
|
||||
label=f"Mask directory. {MASK_DIR_HELP}",
|
||||
)
|
||||
|
||||
with gr.Tab("Plugins"):
|
||||
with gr.Row():
|
||||
@@ -288,6 +296,7 @@ def main(config_file: Path):
|
||||
cpu_textencoder,
|
||||
device,
|
||||
input,
|
||||
mask_dir,
|
||||
output_dir,
|
||||
quality,
|
||||
enable_interactive_seg,
|
||||
|
||||
Reference in New Issue
Block a user