update
This commit is contained in:
@@ -10,11 +10,8 @@ from lama_cleaner.parse_args import parse_args
|
||||
|
||||
|
||||
def entry_point():
|
||||
args = parse_args()
|
||||
if args is None:
|
||||
return
|
||||
# To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers
|
||||
# https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18
|
||||
from lama_cleaner.server import main
|
||||
from lama_cleaner.server import typer_app
|
||||
|
||||
main(args)
|
||||
typer_app()
|
||||
|
||||
@@ -103,6 +103,5 @@ if __name__ == "__main__":
|
||||
device=device,
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=True,
|
||||
hf_access_token="123"
|
||||
)
|
||||
benchmark(model, args.times, args.empty_cache)
|
||||
|
||||
@@ -21,19 +21,17 @@ AVAILABLE_MODELS = [
|
||||
"zits",
|
||||
"mat",
|
||||
"fcf",
|
||||
"manga",
|
||||
"cv2",
|
||||
"sd1.5",
|
||||
"sdxl",
|
||||
"anything4",
|
||||
"realisticVision1.4",
|
||||
"cv2",
|
||||
"manga",
|
||||
"sd2",
|
||||
"sdxl",
|
||||
"paint_by_example",
|
||||
"instruct_pix2pix",
|
||||
"kandinsky2.2",
|
||||
"sdxl",
|
||||
]
|
||||
SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"]
|
||||
DIFFUSERS_MODEL_FP16_REVERSION = [
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"Sanster/anything-4.0-inpainting",
|
||||
@@ -46,26 +44,22 @@ AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
|
||||
DEFAULT_DEVICE = "cuda"
|
||||
|
||||
NO_HALF_HELP = """
|
||||
Using full precision model.
|
||||
If your generate result is always black or green, use this argument. (sd/paint_by_exmaple)
|
||||
Using full precision(fp32) model.
|
||||
If your diffusion model generate result is always black or green, use this argument.
|
||||
"""
|
||||
|
||||
CPU_OFFLOAD_HELP = """
|
||||
Offloads all models to CPU, significantly reducing vRAM usage. (sd/paint_by_example)
|
||||
Offloads diffusion model's weight to CPU RAM, significantly reducing vRAM usage.
|
||||
"""
|
||||
|
||||
DISABLE_NSFW_HELP = """
|
||||
Disable NSFW checker. (sd/paint_by_example)
|
||||
Disable NSFW checker for diffusion model.
|
||||
"""
|
||||
|
||||
SD_CPU_TEXTENCODER_HELP = """
|
||||
Run Stable Diffusion text encoder model on CPU to save GPU memory.
|
||||
CPU_TEXTENCODER_HELP = """
|
||||
Run diffusion models text encoder on CPU to reduce vRAM usage.
|
||||
"""
|
||||
|
||||
SD_CONTROLNET_HELP = """
|
||||
Run Stable Diffusion normal or inpainting model with ControlNet.
|
||||
"""
|
||||
DEFAULT_SD_CONTROLNET_METHOD = "lllyasviel/control_v11p_sd15_canny"
|
||||
SD_CONTROLNET_CHOICES = [
|
||||
"lllyasviel/control_v11p_sd15_canny",
|
||||
# "lllyasviel/control_v11p_sd15_seg",
|
||||
@@ -74,46 +68,36 @@ SD_CONTROLNET_CHOICES = [
|
||||
"lllyasviel/control_v11f1p_sd15_depth",
|
||||
]
|
||||
|
||||
DEFAULT_SD2_CONTROLNET_METHOD = "thibaud/controlnet-sd21-canny-diffusers"
|
||||
SD2_CONTROLNET_CHOICES = [
|
||||
"thibaud/controlnet-sd21-canny-diffusers",
|
||||
"thibaud/controlnet-sd21-depth-diffusers",
|
||||
"thibaud/controlnet-sd21-openpose-diffusers",
|
||||
]
|
||||
|
||||
DEFAULT_SDXL_CONTROLNET_METHOD = "diffusers/controlnet-canny-sdxl-1.0"
|
||||
SDXL_CONTROLNET_CHOICES = [
|
||||
"thibaud/controlnet-openpose-sdxl-1.0",
|
||||
"destitech/controlnet-inpaint-dreamer-sdxl"
|
||||
"destitech/controlnet-inpaint-dreamer-sdxl",
|
||||
"diffusers/controlnet-canny-sdxl-1.0",
|
||||
"diffusers/controlnet-canny-sdxl-1.0-mid",
|
||||
"diffusers/controlnet-canny-sdxl-1.0-small"
|
||||
"diffusers/controlnet-canny-sdxl-1.0-small",
|
||||
"diffusers/controlnet-depth-sdxl-1.0",
|
||||
"diffusers/controlnet-depth-sdxl-1.0-mid",
|
||||
"diffusers/controlnet-depth-sdxl-1.0-small",
|
||||
]
|
||||
|
||||
SD_LOCAL_MODEL_HELP = """
|
||||
Load Stable Diffusion 1.5 model(ckpt/safetensors) from local path.
|
||||
"""
|
||||
|
||||
LOCAL_FILES_ONLY_HELP = """
|
||||
Use local files only, not connect to Hugging Face server. (sd/paint_by_example)
|
||||
"""
|
||||
|
||||
ENABLE_XFORMERS_HELP = """
|
||||
Enable xFormers optimizations. Requires xformers package has been installed. See: https://github.com/facebookresearch/xformers (sd/paint_by_example)
|
||||
When loading diffusion models, using local files only, not connect to HuggingFace server.
|
||||
"""
|
||||
|
||||
DEFAULT_MODEL_DIR = os.getenv(
|
||||
"XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")
|
||||
)
|
||||
MODEL_DIR_HELP = """
|
||||
Model download directory (by setting XDG_CACHE_HOME environment variable), by default model downloaded to ~/.cache
|
||||
MODEL_DIR_HELP = f"""
|
||||
Model download directory (by setting XDG_CACHE_HOME environment variable), by default model download to {DEFAULT_MODEL_DIR}
|
||||
"""
|
||||
|
||||
OUTPUT_DIR_HELP = """
|
||||
Result images will be saved to output directory automatically without confirmation.
|
||||
Result images will be saved to output directory automatically.
|
||||
"""
|
||||
|
||||
INPUT_HELP = """
|
||||
@@ -125,37 +109,45 @@ GUI_HELP = """
|
||||
Launch Lama Cleaner as desktop app
|
||||
"""
|
||||
|
||||
NO_GUI_AUTO_CLOSE_HELP = """
|
||||
Prevent backend auto close after the GUI window closed.
|
||||
"""
|
||||
|
||||
QUALITY_HELP = """
|
||||
Quality of image encoding, 0-100. Default is 95, higher quality will generate larger file size.
|
||||
"""
|
||||
|
||||
|
||||
class RealESRGANModelName(str, Enum):
|
||||
class Choices(str, Enum):
|
||||
@classmethod
|
||||
def values(cls):
|
||||
return [member.value for member in cls]
|
||||
|
||||
|
||||
class RealESRGANModel(Choices):
|
||||
realesr_general_x4v3 = "realesr-general-x4v3"
|
||||
RealESRGAN_x4plus = "RealESRGAN_x4plus"
|
||||
RealESRGAN_x4plus_anime_6B = "RealESRGAN_x4plus_anime_6B"
|
||||
|
||||
|
||||
RealESRGANModelNameList = [e.value for e in RealESRGANModelName]
|
||||
class Device(Choices):
|
||||
cpu = "cpu"
|
||||
cuda = "cuda"
|
||||
mps = "mps"
|
||||
|
||||
|
||||
class InteractiveSegModel(Choices):
|
||||
vit_b = "vit_b"
|
||||
vit_l = "vit_l"
|
||||
vit_h = "vit_h"
|
||||
mobile_sam = "mobile_sam"
|
||||
|
||||
|
||||
INTERACTIVE_SEG_HELP = "Enable interactive segmentation using Segment Anything."
|
||||
INTERACTIVE_SEG_MODEL_HELP = "Model size: vit_b < vit_l < vit_h. Bigger model size means better segmentation but slower speed."
|
||||
AVAILABLE_INTERACTIVE_SEG_MODELS = ["vit_b", "vit_l", "vit_h", "vit_t"]
|
||||
AVAILABLE_INTERACTIVE_SEG_DEVICES = ["cuda", "cpu", "mps"]
|
||||
REMOVE_BG_HELP = "Enable remove background. Always run on CPU"
|
||||
ANIMESEG_HELP = "Enable anime segmentation. Always run on CPU"
|
||||
REALESRGAN_HELP = "Enable realesrgan super resolution"
|
||||
REALESRGAN_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
|
||||
GFPGAN_HELP = (
|
||||
"Enable GFPGAN face restore. To enhance background, use with --enable-realesrgan"
|
||||
)
|
||||
GFPGAN_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
|
||||
RESTOREFORMER_HELP = "Enable RestoreFormer face restore. To enhance background, use with --enable-realesrgan"
|
||||
RESTOREFORMER_AVAILABLE_DEVICES = ["cpu", "cuda", "mps"]
|
||||
GIF_HELP = "Enable GIF plugin. Make GIF to compare original and cleaned image"
|
||||
|
||||
|
||||
@@ -164,8 +156,6 @@ class Config(BaseModel):
|
||||
port: int = 8080
|
||||
model: str = DEFAULT_MODEL
|
||||
sd_local_model_path: str = None
|
||||
sd_controlnet: bool = False
|
||||
sd_controlnet_method: str = DEFAULT_SD_CONTROLNET_METHOD
|
||||
device: str = DEFAULT_DEVICE
|
||||
gui: bool = False
|
||||
no_gui_auto_close: bool = False
|
||||
@@ -173,7 +163,6 @@ class Config(BaseModel):
|
||||
cpu_offload: bool = False
|
||||
disable_nsfw: bool = False
|
||||
sd_cpu_textencoder: bool = False
|
||||
enable_xformers: bool = False
|
||||
local_files_only: bool = False
|
||||
model_dir: str = DEFAULT_MODEL_DIR
|
||||
input: str = None
|
||||
@@ -186,7 +175,7 @@ class Config(BaseModel):
|
||||
enable_anime_seg: bool = False
|
||||
enable_realesrgan: bool = False
|
||||
realesrgan_device: str = "cpu"
|
||||
realesrgan_model: str = RealESRGANModelName.realesr_general_x4v3.value
|
||||
realesrgan_model: str = RealESRGANModel.realesr_general_x4v3.value
|
||||
realesrgan_no_half: bool = False
|
||||
enable_gfpgan: bool = False
|
||||
gfpgan_device: str = "cpu"
|
||||
|
||||
@@ -6,6 +6,7 @@ from loguru import logger
|
||||
from pathlib import Path
|
||||
|
||||
from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION, DEFAULT_MODEL_DIR
|
||||
from lama_cleaner.runtime import setup_model_dir
|
||||
from lama_cleaner.schema import (
|
||||
ModelInfo,
|
||||
ModelType,
|
||||
@@ -16,16 +17,8 @@ from lama_cleaner.schema import (
|
||||
)
|
||||
|
||||
|
||||
def cli_download_model(model: str, model_dir: str):
|
||||
if os.path.isfile(model_dir):
|
||||
raise ValueError(f"invalid --model-dir: {model_dir} is a file")
|
||||
|
||||
if not os.path.exists(model_dir):
|
||||
logger.info(f"Create model cache directory: {model_dir}")
|
||||
Path(model_dir).mkdir(exist_ok=True, parents=True)
|
||||
|
||||
os.environ["XDG_CACHE_HOME"] = model_dir
|
||||
|
||||
def cli_download_model(model: str, model_dir: Path):
|
||||
setup_model_dir(model_dir)
|
||||
from lama_cleaner.model import models
|
||||
|
||||
if model in models:
|
||||
@@ -38,7 +31,7 @@ def cli_download_model(model: str, model_dir: str):
|
||||
|
||||
downloaded_path = DiffusionPipeline.download(
|
||||
pretrained_model_name=model,
|
||||
revision="fp16" if model in DIFFUSERS_MODEL_FP16_REVERSION else "main",
|
||||
variant="fp16" if model in DIFFUSERS_MODEL_FP16_REVERSION else "main",
|
||||
resume_download=True,
|
||||
)
|
||||
logger.info(f"Done. Downloaded to {downloaded_path}")
|
||||
@@ -101,7 +94,7 @@ def scan_inpaint_models() -> List[ModelInfo]:
|
||||
from lama_cleaner.model import models
|
||||
|
||||
for name, m in models.items():
|
||||
if m.is_erase_model:
|
||||
if m.is_erase_model and m.is_downloaded():
|
||||
res.append(
|
||||
ModelInfo(
|
||||
name=name,
|
||||
|
||||
@@ -41,7 +41,7 @@ class InpaintModel:
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def is_downloaded() -> bool:
|
||||
...
|
||||
return False
|
||||
|
||||
@abc.abstractmethod
|
||||
def forward(self, image, mask, config: Config):
|
||||
@@ -67,6 +67,8 @@ class InpaintModel:
|
||||
|
||||
logger.info(f"final forward pad size: {pad_image.shape}")
|
||||
|
||||
image, mask = self.forward_pre_process(image, mask, config)
|
||||
|
||||
result = self.forward(pad_image, pad_mask, config)
|
||||
result = result[0:origin_height, 0:origin_width, :]
|
||||
|
||||
@@ -77,6 +79,9 @@ class InpaintModel:
|
||||
result = result * (mask / 255) + image[:, :, ::-1] * (1 - (mask / 255))
|
||||
return result
|
||||
|
||||
def forward_pre_process(self, image, mask, config):
|
||||
return image, mask
|
||||
|
||||
def forward_post_process(self, result, image, mask, config):
|
||||
return result, image, mask
|
||||
|
||||
@@ -400,6 +405,13 @@ class DiffusionInpaintModel(InpaintModel):
|
||||
scheduler = get_scheduler(sd_sampler, scheduler_config)
|
||||
self.model.scheduler = scheduler
|
||||
|
||||
def forward_pre_process(self, image, mask, config):
|
||||
if config.sd_mask_blur != 0:
|
||||
k = 2 * config.sd_mask_blur + 1
|
||||
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
|
||||
|
||||
return image, mask
|
||||
|
||||
def forward_post_process(self, result, image, mask, config):
|
||||
if config.sd_match_histograms:
|
||||
result = self._match_histograms(result, image[:, :, ::-1], mask)
|
||||
|
||||
@@ -17,14 +17,6 @@ from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||
from lama_cleaner.model.utils import get_scheduler
|
||||
from lama_cleaner.schema import Config, ModelInfo, ModelType
|
||||
|
||||
# 为了兼容性
|
||||
controlnet_name_map = {
|
||||
"control_v11p_sd15_canny": "lllyasviel/control_v11p_sd15_canny",
|
||||
"control_v11p_sd15_openpose": "lllyasviel/control_v11p_sd15_openpose",
|
||||
"control_v11p_sd15_inpaint": "lllyasviel/control_v11p_sd15_inpaint",
|
||||
"control_v11f1p_sd15_depth": "lllyasviel/control_v11f1p_sd15_depth",
|
||||
}
|
||||
|
||||
|
||||
class ControlNet(DiffusionInpaintModel):
|
||||
name = "controlnet"
|
||||
@@ -49,9 +41,6 @@ class ControlNet(DiffusionInpaintModel):
|
||||
fp16 = not kwargs.get("no_half", False)
|
||||
model_info: ModelInfo = kwargs["model_info"]
|
||||
sd_controlnet_method = kwargs["sd_controlnet_method"]
|
||||
sd_controlnet_method = controlnet_name_map.get(
|
||||
sd_controlnet_method, sd_controlnet_method
|
||||
)
|
||||
|
||||
self.model_info = model_info
|
||||
self.sd_controlnet_method = sd_controlnet_method
|
||||
@@ -113,12 +102,6 @@ class ControlNet(DiffusionInpaintModel):
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
||||
self.model.enable_attention_slicing()
|
||||
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
|
||||
if kwargs.get("enable_xformers", False):
|
||||
self.model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||
logger.info("Enable sequential cpu offload")
|
||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||
@@ -162,10 +145,6 @@ class ControlNet(DiffusionInpaintModel):
|
||||
scheduler = get_scheduler(config.sd_sampler, scheduler_config)
|
||||
self.model.scheduler = scheduler
|
||||
|
||||
if config.sd_mask_blur != 0:
|
||||
k = 2 * config.sd_mask_blur + 1
|
||||
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
|
||||
|
||||
img_h, img_w = image.shape[:2]
|
||||
control_image = self._get_control_image(image, mask)
|
||||
mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
|
||||
@@ -190,8 +169,3 @@ class ControlNet(DiffusionInpaintModel):
|
||||
output = (output * 255).round().astype("uint8")
|
||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
# model will be downloaded when app start, and can't switch in frontend settings
|
||||
return True
|
||||
|
||||
@@ -31,30 +31,15 @@ class InstructPix2Pix(DiffusionInpaintModel):
|
||||
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
|
||||
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
||||
self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
||||
"timbrooks/instruct-pix2pix",
|
||||
revision="fp16" if use_gpu and fp16 else "main",
|
||||
torch_dtype=torch_dtype,
|
||||
**model_kwargs
|
||||
self.name, variant="fp16", torch_dtype=torch_dtype, **model_kwargs
|
||||
)
|
||||
|
||||
self.model.enable_attention_slicing()
|
||||
if kwargs.get("enable_xformers", False):
|
||||
self.model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||
logger.info("Enable sequential cpu offload")
|
||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||
else:
|
||||
self.model = self.model.to(device)
|
||||
|
||||
@staticmethod
|
||||
def download():
|
||||
from diffusers import StableDiffusionInstructPix2PixPipeline
|
||||
|
||||
StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
||||
"timbrooks/instruct-pix2pix", revision="fp16"
|
||||
)
|
||||
|
||||
def forward(self, image, mask, config: Config):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
@@ -76,8 +61,3 @@ class InstructPix2Pix(DiffusionInpaintModel):
|
||||
output = (output * 255).round().astype("uint8")
|
||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
# model will be downloaded when app start, and can't switch in frontend settings
|
||||
return True
|
||||
|
||||
@@ -24,7 +24,7 @@ class Kandinsky(DiffusionInpaintModel):
|
||||
}
|
||||
|
||||
self.model = AutoPipelineForInpainting.from_pretrained(
|
||||
self.model_name, **model_kwargs
|
||||
self.model_id_or_path, **model_kwargs
|
||||
).to(device)
|
||||
|
||||
self.callback = kwargs.pop("callback", None)
|
||||
@@ -40,9 +40,6 @@ class Kandinsky(DiffusionInpaintModel):
|
||||
self.model.scheduler = scheduler
|
||||
|
||||
generator = torch.manual_seed(config.sd_seed)
|
||||
if config.sd_mask_blur != 0:
|
||||
k = 2 * config.sd_mask_blur + 1
|
||||
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
|
||||
mask = mask.astype(np.float32) / 255
|
||||
img_h, img_w = image.shape[:2]
|
||||
|
||||
@@ -66,20 +63,7 @@ class Kandinsky(DiffusionInpaintModel):
|
||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
# model will be downloaded when app start, and can't switch in frontend settings
|
||||
return True
|
||||
|
||||
|
||||
class Kandinsky22(Kandinsky):
|
||||
name = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
|
||||
model_name = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
|
||||
|
||||
@staticmethod
|
||||
def download():
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
|
||||
AutoPipelineForInpainting.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-decoder-inpaint"
|
||||
)
|
||||
name = "kandinsky2.2"
|
||||
model_id_or_path = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
|
||||
|
||||
@@ -31,10 +31,6 @@ class PaintByExample(DiffusionInpaintModel):
|
||||
"Fantasy-Studio/Paint-by-Example", torch_dtype=torch_dtype, **model_kwargs
|
||||
)
|
||||
|
||||
self.model.enable_attention_slicing()
|
||||
if kwargs.get("enable_xformers", False):
|
||||
self.model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# TODO: gpu_id
|
||||
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||
self.model.image_encoder = self.model.image_encoder.to(device)
|
||||
@@ -68,8 +64,3 @@ class PaintByExample(DiffusionInpaintModel):
|
||||
output = (output * 255).round().astype("uint8")
|
||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
# model will be downloaded when app start, and can't switch in frontend settings
|
||||
return True
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import os
|
||||
|
||||
import PIL.Image
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
@@ -49,23 +46,12 @@ class SD(DiffusionInpaintModel):
|
||||
self.model = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
self.model_id_or_path,
|
||||
revision="fp16"
|
||||
if (
|
||||
self.model_id_or_path in DIFFUSERS_MODEL_FP16_REVERSION
|
||||
and use_gpu
|
||||
and fp16
|
||||
)
|
||||
if self.model_id_or_path in DIFFUSERS_MODEL_FP16_REVERSION
|
||||
else "main",
|
||||
torch_dtype=torch_dtype,
|
||||
use_auth_token=kwargs["hf_access_token"],
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
||||
self.model.enable_attention_slicing()
|
||||
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
|
||||
if kwargs.get("enable_xformers", False):
|
||||
self.model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||
# TODO: gpu_id
|
||||
logger.info("Enable sequential cpu offload")
|
||||
@@ -88,10 +74,6 @@ class SD(DiffusionInpaintModel):
|
||||
"""
|
||||
self.set_scheduler(config)
|
||||
|
||||
if config.sd_mask_blur != 0:
|
||||
k = 2 * config.sd_mask_blur + 1
|
||||
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
|
||||
|
||||
img_h, img_w = image.shape[:2]
|
||||
|
||||
output = self.model(
|
||||
@@ -114,17 +96,6 @@ class SD(DiffusionInpaintModel):
|
||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
# model will be downloaded when app start, and can't switch in frontend settings
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def download(cls):
|
||||
from diffusers import StableDiffusionInpaintPipeline
|
||||
|
||||
StableDiffusionInpaintPipeline.from_pretrained(cls.model_id_or_path)
|
||||
|
||||
|
||||
class SD15(SD):
|
||||
name = "sd1.5"
|
||||
|
||||
@@ -45,16 +45,9 @@ class SDXL(DiffusionInpaintModel):
|
||||
self.model_id_or_path,
|
||||
revision="main",
|
||||
torch_dtype=torch_dtype,
|
||||
use_auth_token=kwargs["hf_access_token"],
|
||||
vae=vae,
|
||||
)
|
||||
|
||||
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
||||
self.model.enable_attention_slicing()
|
||||
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
|
||||
if kwargs.get("enable_xformers", False):
|
||||
self.model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||
logger.info("Enable sequential cpu offload")
|
||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||
@@ -65,14 +58,6 @@ class SDXL(DiffusionInpaintModel):
|
||||
|
||||
self.callback = kwargs.pop("callback", None)
|
||||
|
||||
@staticmethod
|
||||
def download():
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
|
||||
AutoPipelineForInpainting.from_pretrained(
|
||||
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
|
||||
)
|
||||
|
||||
def forward(self, image, mask, config: Config):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
@@ -81,10 +66,6 @@ class SDXL(DiffusionInpaintModel):
|
||||
"""
|
||||
self.set_scheduler(config)
|
||||
|
||||
if config.sd_mask_blur != 0:
|
||||
k = 2 * config.sd_mask_blur + 1
|
||||
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
|
||||
|
||||
img_h, img_w = image.shape[:2]
|
||||
|
||||
output = self.model(
|
||||
@@ -106,8 +87,3 @@ class SDXL(DiffusionInpaintModel):
|
||||
output = (output * 255).round().astype("uint8")
|
||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def is_downloaded() -> bool:
|
||||
# model will be downloaded when app start, and can't switch in frontend settings
|
||||
return True
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import List, Dict
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.const import DEFAULT_SD_CONTROLNET_METHOD
|
||||
from lama_cleaner.download import scan_models
|
||||
from lama_cleaner.helper import switch_mps_device
|
||||
from lama_cleaner.model import models, ControlNet, SD, SDXL
|
||||
@@ -19,16 +18,25 @@ class ModelManager:
|
||||
self.available_models: Dict[str, ModelInfo] = {}
|
||||
self.scan_models()
|
||||
|
||||
self.sd_controlnet = kwargs.get("sd_controlnet", False)
|
||||
self.sd_controlnet_method = kwargs.get(
|
||||
"sd_controlnet_method", DEFAULT_SD_CONTROLNET_METHOD
|
||||
)
|
||||
self.sd_controlnet = False
|
||||
self.sd_controlnet_method = ""
|
||||
self.model = self.init_model(name, device, **kwargs)
|
||||
|
||||
def init_model(self, name: str, device, **kwargs):
|
||||
def _map_old_name(self, name: str) -> str:
|
||||
for old_name, model_cls in models.items():
|
||||
if name == old_name and hasattr(model_cls, "model_id_or_path"):
|
||||
name = model_cls.model_id_or_path
|
||||
break
|
||||
return name
|
||||
|
||||
@property
|
||||
def current_model(self) -> Dict:
|
||||
name = self._map_old_name(self.name)
|
||||
return self.available_models[name].model_dump()
|
||||
|
||||
def init_model(self, name: str, device, **kwargs):
|
||||
name = self._map_old_name(name)
|
||||
logger.info(f"Loading model: {name}")
|
||||
if name not in self.available_models:
|
||||
raise NotImplementedError(f"Unsupported model: {name}")
|
||||
|
||||
@@ -86,6 +94,7 @@ class ModelManager:
|
||||
):
|
||||
self.sd_controlnet_method = self.available_models[new_name].controlnets[0]
|
||||
try:
|
||||
# TODO: enable/disable controlnet without reload model
|
||||
del self.model
|
||||
torch_gc()
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ def parse_args():
|
||||
parser.add_argument("--cpu-offload", action="store_true", help=CPU_OFFLOAD_HELP)
|
||||
parser.add_argument("--disable-nsfw", action="store_true", help=DISABLE_NSFW_HELP)
|
||||
parser.add_argument(
|
||||
"--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP
|
||||
"--sd-cpu-textencoder", action="store_true", help=CPU_TEXTENCODER_HELP
|
||||
)
|
||||
parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP)
|
||||
parser.add_argument(
|
||||
@@ -66,16 +66,10 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-xformers", action="store_true", help=ENABLE_XFORMERS_HELP
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", default=DEFAULT_DEVICE, type=str, choices=AVAILABLE_DEVICES
|
||||
)
|
||||
parser.add_argument("--gui", action="store_true", help=GUI_HELP)
|
||||
parser.add_argument(
|
||||
"--no-gui-auto-close", action="store_true", help=NO_GUI_AUTO_CLOSE_HELP
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gui-size",
|
||||
default=[1600, 1000],
|
||||
|
||||
@@ -22,7 +22,7 @@ SEGMENT_ANYTHING_MODELS = {
|
||||
"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
|
||||
"md5": "4b8939a88964f0f4ff5f5b2642c598a6",
|
||||
},
|
||||
"vit_t": {
|
||||
"mobile_sam": {
|
||||
"url": "https://github.com/Sanster/models/releases/download/MobileSAM/mobile_sam.pt",
|
||||
"md5": "f3c0d8cda613564d499310dab6c812cd",
|
||||
},
|
||||
|
||||
@@ -3,7 +3,7 @@ from enum import Enum
|
||||
import cv2
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.const import RealESRGANModelName
|
||||
from lama_cleaner.const import RealESRGANModel
|
||||
from lama_cleaner.helper import download_model
|
||||
from lama_cleaner.plugins.base_plugin import BasePlugin
|
||||
|
||||
@@ -18,7 +18,7 @@ class RealESRGANUpscaler(BasePlugin):
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
|
||||
REAL_ESRGAN_MODELS = {
|
||||
RealESRGANModelName.realesr_general_x4v3: {
|
||||
RealESRGANModel.realesr_general_x4v3: {
|
||||
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
||||
"scale": 4,
|
||||
"model": lambda: SRVGGNetCompact(
|
||||
@@ -31,7 +31,7 @@ class RealESRGANUpscaler(BasePlugin):
|
||||
),
|
||||
"model_md5": "91a7644643c884ee00737db24e478156",
|
||||
},
|
||||
RealESRGANModelName.RealESRGAN_x4plus: {
|
||||
RealESRGANModel.RealESRGAN_x4plus: {
|
||||
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
"scale": 4,
|
||||
"model": lambda: RRDBNet(
|
||||
@@ -44,7 +44,7 @@ class RealESRGANUpscaler(BasePlugin):
|
||||
),
|
||||
"model_md5": "99ec365d4afad750833258a1a24f44ca",
|
||||
},
|
||||
RealESRGANModelName.RealESRGAN_x4plus_anime_6B: {
|
||||
RealESRGANModel.RealESRGAN_x4plus_anime_6B: {
|
||||
"url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
"scale": 4,
|
||||
"model": lambda: RRDBNet(
|
||||
|
||||
@@ -109,7 +109,7 @@ sam_model_registry = {
|
||||
"vit_h": build_sam,
|
||||
"vit_l": build_sam_vit_l,
|
||||
"vit_b": build_sam_vit_b,
|
||||
"vit_t": build_sam_vit_t,
|
||||
"mobile_sam": build_sam_vit_t,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
# https://github.com/huggingface/huggingface_hub/blob/5a12851f54bf614be39614034ed3a9031922d297/src/huggingface_hub/utils/_runtime.py
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import packaging.version
|
||||
from loguru import logger
|
||||
from rich import print
|
||||
from typing import Dict, Any
|
||||
|
||||
from lama_cleaner.const import Device
|
||||
|
||||
_PY_VERSION: str = sys.version.split()[0].rstrip("+")
|
||||
|
||||
if packaging.version.Version(_PY_VERSION) < packaging.version.Version("3.8.0"):
|
||||
@@ -21,7 +27,6 @@ _CANDIDATES = [
|
||||
"diffusers",
|
||||
"transformers",
|
||||
"opencv-python",
|
||||
"xformers",
|
||||
"accelerate",
|
||||
"lama-cleaner",
|
||||
"rembg",
|
||||
@@ -38,7 +43,7 @@ for name in _CANDIDATES:
|
||||
|
||||
|
||||
def dump_environment_info() -> Dict[str, str]:
|
||||
"""Dump information about the machine to help debugging issues. """
|
||||
"""Dump information about the machine to help debugging issues."""
|
||||
|
||||
# Generic machine info
|
||||
info: Dict[str, Any] = {
|
||||
@@ -48,3 +53,34 @@ def dump_environment_info() -> Dict[str, str]:
|
||||
info.update(_package_versions)
|
||||
print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + "\n")
|
||||
return info
|
||||
|
||||
|
||||
def check_device(device: Device) -> Device:
|
||||
if device == Device.cuda:
|
||||
import platform
|
||||
|
||||
if platform.system() == "Darwin":
|
||||
logger.warning("MacOS does not support cuda, use cpu instead")
|
||||
return Device.cpu
|
||||
else:
|
||||
import torch
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
logger.warning("CUDA is not available, use cpu instead")
|
||||
return Device.cpu
|
||||
elif device == Device.mps:
|
||||
import torch
|
||||
|
||||
if not torch.backends.mps.is_available():
|
||||
logger.warning("mps is not available, use cpu instead")
|
||||
return Device.cpu
|
||||
return device
|
||||
|
||||
|
||||
def setup_model_dir(model_dir: Path):
|
||||
model_dir = model_dir.expanduser().absolute()
|
||||
os.environ["U2NET_HOME"] = str(model_dir)
|
||||
os.environ["XDG_CACHE_HOME"] = str(model_dir)
|
||||
if not model_dir.exists():
|
||||
logger.info(f"Create model directory: {model_dir}")
|
||||
model_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
@@ -1,10 +1,18 @@
|
||||
#!/usr/bin/env python3
|
||||
import json
|
||||
import os
|
||||
import hashlib
|
||||
import traceback
|
||||
|
||||
import typer
|
||||
from typer import Option
|
||||
|
||||
from lama_cleaner.download import cli_download_model, scan_models
|
||||
from lama_cleaner.runtime import setup_model_dir, dump_environment_info, check_device
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
import hashlib
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
import imghdr
|
||||
import io
|
||||
@@ -20,12 +28,7 @@ import torch
|
||||
from PIL import Image
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.const import (
|
||||
SD15_MODELS,
|
||||
SD_CONTROLNET_CHOICES,
|
||||
SDXL_CONTROLNET_CHOICES,
|
||||
SD2_CONTROLNET_CHOICES,
|
||||
)
|
||||
from lama_cleaner.const import *
|
||||
from lama_cleaner.file_manager import FileManager
|
||||
from lama_cleaner.model.utils import torch_gc
|
||||
from lama_cleaner.model_manager import ModelManager
|
||||
@@ -39,6 +42,8 @@ from lama_cleaner.plugins import (
|
||||
)
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
typer_app = typer.Typer(pretty_exceptions_show_locals=False, add_completion=False)
|
||||
|
||||
try:
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||
@@ -103,23 +108,34 @@ logging.getLogger("werkzeug").addFilter(NoFlaskwebgui())
|
||||
|
||||
app = Flask(__name__, static_folder=os.path.join(BUILD_DIR, "static"))
|
||||
app.config["JSON_AS_ASCII"] = False
|
||||
CORS(app, expose_headers=["Content-Disposition", "X-seed"])
|
||||
CORS(app, expose_headers=["Content-Disposition", "X-seed", "X-Height", "X-Width"])
|
||||
|
||||
sio_logger = logging.getLogger("sio-logger")
|
||||
sio_logger.setLevel(logging.ERROR)
|
||||
socketio = SocketIO(app, cors_allowed_origins="*", async_mode="threading")
|
||||
|
||||
model: ModelManager = None
|
||||
thumb: FileManager = None
|
||||
output_dir: str = None
|
||||
device = None
|
||||
input_image_path: str = None
|
||||
is_disable_model_switch: bool = False
|
||||
enable_file_manager: bool = False
|
||||
enable_auto_saving: bool = False
|
||||
is_desktop: bool = False
|
||||
image_quality: int = 95
|
||||
plugins = {}
|
||||
|
||||
@dataclass
|
||||
class GlobalConfig:
|
||||
model_manager: ModelManager = None
|
||||
file_manager: FileManager = None
|
||||
output_dir: Path = None
|
||||
input_image_path: Path = None
|
||||
disable_model_switch: bool = False
|
||||
is_desktop: bool = False
|
||||
image_quality: int = 95
|
||||
plugins = {}
|
||||
|
||||
@property
|
||||
def enable_auto_saving(self) -> bool:
|
||||
return self.output_dir is not None
|
||||
|
||||
@property
|
||||
def enable_file_manager(self) -> bool:
|
||||
return self.file_manager is not None
|
||||
|
||||
|
||||
global_config = GlobalConfig()
|
||||
|
||||
|
||||
def get_image_ext(img_bytes):
|
||||
@@ -135,7 +151,7 @@ def diffuser_callback(i, t, latents):
|
||||
|
||||
@app.route("/save_image", methods=["POST"])
|
||||
def save_image():
|
||||
if output_dir is None:
|
||||
if global_config.output_dir is None:
|
||||
return "--output-dir is None", 500
|
||||
|
||||
input = request.files
|
||||
@@ -143,7 +159,7 @@ def save_image():
|
||||
origin_image_bytes = input["image"].read() # RGB
|
||||
ext = get_image_ext(origin_image_bytes)
|
||||
image, alpha_channel, exif_infos = load_img(origin_image_bytes, return_exif=True)
|
||||
save_path = os.path.join(output_dir, filename)
|
||||
save_path = str(global_config.output_dir / filename)
|
||||
|
||||
if alpha_channel is not None:
|
||||
if alpha_channel.shape[:2] != image.shape[:2]:
|
||||
@@ -157,7 +173,7 @@ def save_image():
|
||||
img_bytes = pil_to_bytes(
|
||||
pil_image,
|
||||
ext,
|
||||
quality=image_quality,
|
||||
quality=global_config.image_quality,
|
||||
exif_infos=exif_infos,
|
||||
)
|
||||
with open(save_path, "wb") as fw:
|
||||
@@ -169,9 +185,11 @@ def save_image():
|
||||
@app.route("/medias/<tab>")
|
||||
def medias(tab):
|
||||
if tab == "image":
|
||||
response = make_response(jsonify(thumb.media_names), 200)
|
||||
response = make_response(jsonify(global_config.file_manager.media_names), 200)
|
||||
else:
|
||||
response = make_response(jsonify(thumb.output_media_names), 200)
|
||||
response = make_response(
|
||||
jsonify(global_config.file_manager.output_media_names), 200
|
||||
)
|
||||
# response.last_modified = thumb.modified_time[tab]
|
||||
# response.cache_control.no_cache = True
|
||||
# response.cache_control.max_age = 0
|
||||
@@ -182,8 +200,8 @@ def medias(tab):
|
||||
@app.route("/media/<tab>/<filename>")
|
||||
def media_file(tab, filename):
|
||||
if tab == "image":
|
||||
return send_from_directory(thumb.root_directory, filename)
|
||||
return send_from_directory(thumb.output_dir, filename)
|
||||
return send_from_directory(global_config.file_manager.root_directory, filename)
|
||||
return send_from_directory(global_config.file_manager.output_dir, filename)
|
||||
|
||||
|
||||
@app.route("/media_thumbnail/<tab>/<filename>")
|
||||
@@ -198,10 +216,10 @@ def media_thumbnail_file(tab, filename):
|
||||
if height:
|
||||
height = int(float(height))
|
||||
|
||||
directory = thumb.root_directory
|
||||
directory = global_config.file_manager.root_directory
|
||||
if tab == "output":
|
||||
directory = thumb.output_dir
|
||||
thumb_filename, (width, height) = thumb.get_thumbnail(
|
||||
directory = global_config.file_manager.output_dir
|
||||
thumb_filename, (width, height) = global_config.file_manager.get_thumbnail(
|
||||
directory, filename, width, height
|
||||
)
|
||||
thumb_filepath = f"{app.config['THUMBNAIL_MEDIA_THUMBNAIL_ROOT']}{thumb_filename}"
|
||||
@@ -257,13 +275,11 @@ def process():
|
||||
croper_y=form["croperY"],
|
||||
croper_height=form["croperHeight"],
|
||||
croper_width=form["croperWidth"],
|
||||
|
||||
use_extender=form["useExtender"],
|
||||
extender_x=form["extenderX"],
|
||||
extender_y=form["extenderY"],
|
||||
extender_height=form["extenderHeight"],
|
||||
extender_width=form["extenderWidth"],
|
||||
|
||||
sd_scale=form["sdScale"],
|
||||
sd_mask_blur=form["sdMaskBlur"],
|
||||
sd_strength=form["sdStrength"],
|
||||
@@ -294,7 +310,7 @@ def process():
|
||||
|
||||
start = time.time()
|
||||
try:
|
||||
res_np_img = model(image, mask, config)
|
||||
res_np_img = global_config.model_manager(image, mask, config)
|
||||
except RuntimeError as e:
|
||||
if "CUDA out of memory. " in str(e):
|
||||
# NOTE: the string may change?
|
||||
@@ -322,7 +338,7 @@ def process():
|
||||
pil_to_bytes(
|
||||
Image.fromarray(res_np_img),
|
||||
ext,
|
||||
quality=image_quality,
|
||||
quality=global_config.image_quality,
|
||||
exif_infos=exif_infos,
|
||||
)
|
||||
)
|
||||
@@ -345,7 +361,7 @@ def run_plugin():
|
||||
form = request.form
|
||||
files = request.files
|
||||
name = form["name"]
|
||||
if name not in plugins:
|
||||
if name not in global_config.plugins:
|
||||
return "Plugin not found", 500
|
||||
|
||||
origin_image_bytes = files["image"].read() # RGB
|
||||
@@ -359,7 +375,7 @@ def run_plugin():
|
||||
if name == InteractiveSeg.name:
|
||||
img_md5 = hashlib.md5(origin_image_bytes).hexdigest()
|
||||
form["img_md5"] = img_md5
|
||||
bgr_res = plugins[name](rgb_np_img, files, form)
|
||||
bgr_res = global_config.plugins[name](rgb_np_img, files, form)
|
||||
except RuntimeError as e:
|
||||
torch.cuda.empty_cache()
|
||||
if "CUDA out of memory. " in str(e):
|
||||
@@ -401,7 +417,7 @@ def run_plugin():
|
||||
pil_to_bytes(
|
||||
Image.fromarray(rgb_res),
|
||||
ext,
|
||||
quality=image_quality,
|
||||
quality=global_config.image_quality,
|
||||
exif_infos=exif_infos,
|
||||
)
|
||||
),
|
||||
@@ -414,41 +430,40 @@ def run_plugin():
|
||||
@app.route("/server_config", methods=["GET"])
|
||||
def get_server_config():
|
||||
return {
|
||||
"plugins": list(plugins.keys()),
|
||||
"enableFileManager": enable_file_manager,
|
||||
"enableAutoSaving": enable_auto_saving,
|
||||
"enableControlnet": model.sd_controlnet,
|
||||
"controlnetMethod": model.sd_controlnet_method,
|
||||
"disableModelSwitch": is_disable_model_switch,
|
||||
"plugins": list(global_config.plugins.keys()),
|
||||
"enableFileManager": global_config.enable_file_manager,
|
||||
"enableAutoSaving": global_config.enable_auto_saving,
|
||||
"enableControlnet": global_config.model_manager.sd_controlnet,
|
||||
"controlnetMethod": global_config.model_manager.sd_controlnet_method,
|
||||
"disableModelSwitch": global_config.disable_model_switch,
|
||||
"isDesktop": global_config.is_desktop,
|
||||
}, 200
|
||||
|
||||
|
||||
@app.route("/models", methods=["GET"])
|
||||
def get_models():
|
||||
return [it.model_dump() for it in model.scan_models()]
|
||||
return [it.model_dump() for it in global_config.model_manager.scan_models()]
|
||||
|
||||
|
||||
@app.route("/model")
|
||||
def current_model():
|
||||
return model.available_models[model.name].model_dump(), 200
|
||||
|
||||
|
||||
@app.route("/is_desktop")
|
||||
def get_is_desktop():
|
||||
return str(is_desktop), 200
|
||||
return (
|
||||
global_config.model_manager.current_model,
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@app.route("/model", methods=["POST"])
|
||||
def switch_model():
|
||||
if is_disable_model_switch:
|
||||
if global_config.disable_model_switch:
|
||||
return "Switch model is disabled", 400
|
||||
|
||||
new_name = request.form.get("name")
|
||||
if new_name == model.name:
|
||||
if new_name == global_config.model_manager.name:
|
||||
return "Same model", 200
|
||||
|
||||
try:
|
||||
model.switch(new_name)
|
||||
global_config.model_manager.switch(new_name)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
error_message = f"{type(e).__name__} - {str(e)}"
|
||||
@@ -464,160 +479,230 @@ def index():
|
||||
|
||||
@app.route("/inputimage")
|
||||
def get_cli_input_image():
|
||||
if input_image_path:
|
||||
with open(input_image_path, "rb") as f:
|
||||
if global_config.input_image_path:
|
||||
with open(global_config.input_image_path, "rb") as f:
|
||||
image_in_bytes = f.read()
|
||||
return send_file(
|
||||
input_image_path,
|
||||
global_config.input_image_path,
|
||||
as_attachment=True,
|
||||
download_name=Path(input_image_path).name,
|
||||
download_name=Path(global_config.input_image_path).name,
|
||||
mimetype=f"image/{get_image_ext(image_in_bytes)}",
|
||||
)
|
||||
else:
|
||||
return "No Input Image"
|
||||
|
||||
|
||||
def build_plugins(args):
|
||||
global plugins
|
||||
if args.enable_interactive_seg:
|
||||
def build_plugins(
|
||||
enable_interactive_seg: bool,
|
||||
interactive_seg_model: InteractiveSegModel,
|
||||
interactive_seg_device: Device,
|
||||
enable_remove_bg: bool,
|
||||
enable_anime_seg: bool,
|
||||
enable_realesrgan: bool,
|
||||
realesrgan_device: Device,
|
||||
realesrgan_model: str,
|
||||
enable_gfpgan: bool,
|
||||
gfpgan_device: Device,
|
||||
enable_restoreformer: bool,
|
||||
restoreformer_device: Device,
|
||||
no_half: bool,
|
||||
):
|
||||
if enable_interactive_seg:
|
||||
logger.info(f"Initialize {InteractiveSeg.name} plugin")
|
||||
plugins[InteractiveSeg.name] = InteractiveSeg(
|
||||
args.interactive_seg_model, args.interactive_seg_device
|
||||
global_config.plugins[InteractiveSeg.name] = InteractiveSeg(
|
||||
interactive_seg_model, interactive_seg_device
|
||||
)
|
||||
|
||||
if args.enable_remove_bg:
|
||||
if enable_remove_bg:
|
||||
logger.info(f"Initialize {RemoveBG.name} plugin")
|
||||
plugins[RemoveBG.name] = RemoveBG()
|
||||
global_config.plugins[RemoveBG.name] = RemoveBG()
|
||||
|
||||
if args.enable_anime_seg:
|
||||
if enable_anime_seg:
|
||||
logger.info(f"Initialize {AnimeSeg.name} plugin")
|
||||
plugins[AnimeSeg.name] = AnimeSeg()
|
||||
global_config.plugins[AnimeSeg.name] = AnimeSeg()
|
||||
|
||||
if args.enable_realesrgan:
|
||||
if enable_realesrgan:
|
||||
logger.info(
|
||||
f"Initialize {RealESRGANUpscaler.name} plugin: {args.realesrgan_model}, {args.realesrgan_device}"
|
||||
f"Initialize {RealESRGANUpscaler.name} plugin: {realesrgan_model}, {realesrgan_device}"
|
||||
)
|
||||
plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
|
||||
args.realesrgan_model,
|
||||
args.realesrgan_device,
|
||||
no_half=args.realesrgan_no_half,
|
||||
global_config.plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
|
||||
realesrgan_model,
|
||||
realesrgan_device,
|
||||
no_half=no_half,
|
||||
)
|
||||
|
||||
if args.enable_gfpgan:
|
||||
if enable_gfpgan:
|
||||
logger.info(f"Initialize {GFPGANPlugin.name} plugin")
|
||||
if args.enable_realesrgan:
|
||||
if enable_realesrgan:
|
||||
logger.info("Use realesrgan as GFPGAN background upscaler")
|
||||
else:
|
||||
logger.info(
|
||||
f"GFPGAN no background upscaler, use --enable-realesrgan to enable it"
|
||||
)
|
||||
plugins[GFPGANPlugin.name] = GFPGANPlugin(
|
||||
args.gfpgan_device, upscaler=plugins.get(RealESRGANUpscaler.name, None)
|
||||
global_config.plugins[GFPGANPlugin.name] = GFPGANPlugin(
|
||||
gfpgan_device,
|
||||
upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None),
|
||||
)
|
||||
|
||||
if args.enable_restoreformer:
|
||||
if enable_restoreformer:
|
||||
logger.info(f"Initialize {RestoreFormerPlugin.name} plugin")
|
||||
plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
|
||||
args.restoreformer_device,
|
||||
upscaler=plugins.get(RealESRGANUpscaler.name, None),
|
||||
global_config.plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
|
||||
restoreformer_device,
|
||||
upscaler=global_config.plugins.get(RealESRGANUpscaler.name, None),
|
||||
)
|
||||
|
||||
def main(args):
|
||||
global model
|
||||
global device
|
||||
global input_image_path
|
||||
global is_disable_model_switch
|
||||
global enable_file_manager
|
||||
global is_desktop
|
||||
global thumb
|
||||
global output_dir
|
||||
global image_quality
|
||||
global enable_auto_saving
|
||||
|
||||
build_plugins(args)
|
||||
@typer_app.command(help="Install all plugins dependencies")
|
||||
def install_plugins_packages():
|
||||
from lama_cleaner.installer import install_plugins_package
|
||||
|
||||
image_quality = args.quality
|
||||
output_dir = args.output_dir
|
||||
install_plugins_package()
|
||||
|
||||
|
||||
@typer_app.command(help="Download SD/SDXL normal/inpainting model from HuggingFace")
|
||||
def download(
|
||||
model: str = Option(
|
||||
..., help="Model id on HuggingFace e.g: runwayml/stable-diffusion-inpainting"
|
||||
),
|
||||
model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False),
|
||||
):
|
||||
cli_download_model(model, model_dir)
|
||||
|
||||
|
||||
@typer_app.command(help="List downloaded models")
|
||||
def list_model(
|
||||
model_dir: Path = Option(DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, file_okay=False),
|
||||
):
|
||||
setup_model_dir(model_dir)
|
||||
scanned_models = scan_models()
|
||||
for it in scanned_models:
|
||||
print(it.name)
|
||||
|
||||
|
||||
@typer_app.command(help="Start lama cleaner server")
|
||||
def start(
|
||||
host: str = Option("127.0.0.1"),
|
||||
port: int = Option(8080),
|
||||
model: str = Option(
|
||||
DEFAULT_MODEL,
|
||||
help=f"Available models: [{', '.join(AVAILABLE_MODELS)}]. "
|
||||
f"You can use download command to download other SD/SDXL normal/inpainting models on huggingface",
|
||||
),
|
||||
model_dir: Path = Option(
|
||||
DEFAULT_MODEL_DIR, help=MODEL_DIR_HELP, dir_okay=True, file_okay=False
|
||||
),
|
||||
no_half: bool = Option(False, help=NO_HALF_HELP),
|
||||
cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP),
|
||||
disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP),
|
||||
cpu_textencoder: bool = Option(False, help=CPU_TEXTENCODER_HELP),
|
||||
local_files_only: bool = Option(False, help=LOCAL_FILES_ONLY_HELP),
|
||||
device: Device = Option(Device.cpu),
|
||||
gui: bool = Option(False, help=GUI_HELP),
|
||||
disable_model_switch: bool = Option(False),
|
||||
input: Path = Option(None, help=INPUT_HELP),
|
||||
output_dir: Path = Option(
|
||||
None, help=OUTPUT_DIR_HELP, dir_okay=True, file_okay=False
|
||||
),
|
||||
quality: int = Option(95, help=QUALITY_HELP),
|
||||
enable_interactive_seg: bool = Option(False, help=INTERACTIVE_SEG_HELP),
|
||||
interactive_seg_model: InteractiveSegModel = Option(
|
||||
InteractiveSegModel.vit_b, help=INTERACTIVE_SEG_MODEL_HELP
|
||||
),
|
||||
interactive_seg_device: Device = Option(Device.cpu),
|
||||
enable_remove_bg: bool = Option(False, help=REMOVE_BG_HELP),
|
||||
enable_anime_seg: bool = Option(False, help=ANIMESEG_HELP),
|
||||
enable_realesrgan: bool = Option(False),
|
||||
realesrgan_device: Device = Option(Device.cpu),
|
||||
realesrgan_model: str = Option(RealESRGANModel.realesr_general_x4v3),
|
||||
enable_gfpgan: bool = Option(False),
|
||||
gfpgan_device: Device = Option(Device.cpu),
|
||||
enable_restoreformer: bool = Option(False),
|
||||
restoreformer_device: Device = Option(Device.cpu),
|
||||
):
|
||||
global global_config
|
||||
dump_environment_info()
|
||||
|
||||
if input:
|
||||
if not input.exists():
|
||||
logger.error(f"invalid --input: {input} not exists")
|
||||
exit()
|
||||
if input.is_dir():
|
||||
logger.info(f"Initialize file manager")
|
||||
file_manager = FileManager(app)
|
||||
app.config["THUMBNAIL_MEDIA_ROOT"] = input
|
||||
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
|
||||
output_dir, "lama_cleaner_thumbnails"
|
||||
)
|
||||
file_manager.output_dir = output_dir
|
||||
else:
|
||||
global_config.input_image_path = input
|
||||
|
||||
device = check_device(device)
|
||||
setup_model_dir(model_dir)
|
||||
|
||||
if local_files_only:
|
||||
os.environ["TRANSFORMERS_OFFLINE"] = "1"
|
||||
os.environ["HF_HUB_OFFLINE"] = "1"
|
||||
|
||||
if model not in AVAILABLE_MODELS:
|
||||
scanned_models = scan_models()
|
||||
if model not in [it.name for it in scanned_models]:
|
||||
logger.error(
|
||||
f"invalid --model: {model} not exists. Available models: {AVAILABLE_MODELS} or {[it.name for it in scanned_models]}"
|
||||
)
|
||||
exit()
|
||||
|
||||
global_config.image_quality = quality
|
||||
global_config.disable_model_switch = disable_model_switch
|
||||
global_config.is_desktop = gui
|
||||
build_plugins(
|
||||
enable_interactive_seg,
|
||||
interactive_seg_model,
|
||||
interactive_seg_device,
|
||||
enable_remove_bg,
|
||||
enable_anime_seg,
|
||||
enable_realesrgan,
|
||||
realesrgan_device,
|
||||
realesrgan_model,
|
||||
enable_gfpgan,
|
||||
gfpgan_device,
|
||||
enable_restoreformer,
|
||||
restoreformer_device,
|
||||
no_half,
|
||||
)
|
||||
if output_dir:
|
||||
output_dir = os.path.abspath(output_dir)
|
||||
logger.info(f"Output dir: {output_dir}")
|
||||
enable_auto_saving = True
|
||||
output_dir = output_dir.expanduser().absolute()
|
||||
logger.info(f"Image will auto save to output dir: {output_dir}")
|
||||
global_config.output_dir = output_dir
|
||||
|
||||
device = torch.device(args.device)
|
||||
is_disable_model_switch = args.disable_model_switch
|
||||
is_desktop = args.gui
|
||||
if is_disable_model_switch:
|
||||
logger.info(
|
||||
f"Start with --disable-model-switch, model switch on frontend is disable"
|
||||
)
|
||||
|
||||
if args.input and os.path.isdir(args.input):
|
||||
logger.info(f"Initialize file manager")
|
||||
thumb = FileManager(app)
|
||||
enable_file_manager = True
|
||||
app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
|
||||
app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
|
||||
output_dir, "lama_cleaner_thumbnails"
|
||||
)
|
||||
thumb.output_dir = Path(output_dir)
|
||||
# thumb.start()
|
||||
# try:
|
||||
# while True:
|
||||
# time.sleep(1)
|
||||
# finally:
|
||||
# thumb.image_dir_observer.stop()
|
||||
# thumb.image_dir_observer.join()
|
||||
# thumb.output_dir_observer.stop()
|
||||
# thumb.output_dir_observer.join()
|
||||
|
||||
else:
|
||||
input_image_path = args.input
|
||||
|
||||
# 为了兼容性
|
||||
model_name_map = {
|
||||
"sd1.5": "runwayml/stable-diffusion-inpainting",
|
||||
"anything4": "Sanster/anything-4.0-inpainting",
|
||||
"realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting",
|
||||
"sd2": "stabilityai/stable-diffusion-2-inpainting",
|
||||
"sdxl": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
||||
"kandinsky2.2": "kandinsky-community/kandinsky-2-2-decoder-inpaint",
|
||||
"paint_by_example": "Fantasy-Studio/Paint-by-Example",
|
||||
"instruct_pix2pix": "timbrooks/instruct-pix2pix",
|
||||
}
|
||||
|
||||
model = ModelManager(
|
||||
name=model_name_map.get(args.model, args.model),
|
||||
sd_controlnet=args.sd_controlnet,
|
||||
sd_controlnet_method=args.sd_controlnet_method,
|
||||
device=device,
|
||||
no_half=args.no_half,
|
||||
hf_access_token=args.hf_access_token,
|
||||
disable_nsfw=args.sd_disable_nsfw or args.disable_nsfw,
|
||||
sd_cpu_textencoder=args.sd_cpu_textencoder,
|
||||
cpu_offload=args.cpu_offload,
|
||||
enable_xformers=args.sd_enable_xformers or args.enable_xformers,
|
||||
global_config.model_manager = ModelManager(
|
||||
name=model,
|
||||
device=torch.device(device),
|
||||
no_half=no_half,
|
||||
disable_nsfw=disable_nsfw_checker,
|
||||
sd_cpu_textencoder=cpu_textencoder,
|
||||
cpu_offload=cpu_offload,
|
||||
callback=diffuser_callback,
|
||||
)
|
||||
|
||||
if args.gui:
|
||||
app_width, app_height = args.gui_size
|
||||
if gui:
|
||||
from flaskwebgui import FlaskUI
|
||||
|
||||
ui = FlaskUI(
|
||||
app,
|
||||
socketio=socketio,
|
||||
width=app_width,
|
||||
height=app_height,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
close_server_on_exit=not args.no_gui_auto_close,
|
||||
width=1200,
|
||||
height=800,
|
||||
host=host,
|
||||
port=port,
|
||||
close_server_on_exit=True,
|
||||
idle_interval=60,
|
||||
)
|
||||
ui.run()
|
||||
else:
|
||||
socketio.run(
|
||||
app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
debug=args.debug,
|
||||
host=host,
|
||||
port=port,
|
||||
allow_unsafe_werkzeug=True,
|
||||
)
|
||||
|
||||
@@ -39,7 +39,6 @@ def test_runway_sd_1_5(
|
||||
name=model_name,
|
||||
sd_controlnet=True,
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
disable_nsfw=disable_nsfw,
|
||||
sd_cpu_textencoder=cpu_textencoder,
|
||||
sd_controlnet_method=sd_controlnet_method,
|
||||
@@ -87,7 +86,6 @@ def test_local_file_path(sd_device, sampler):
|
||||
name=model_name,
|
||||
sd_controlnet=True,
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=True,
|
||||
@@ -125,7 +123,6 @@ def test_local_file_path_controlnet_native_inpainting(sd_device, sampler):
|
||||
name=model_name,
|
||||
sd_controlnet=True,
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=True,
|
||||
@@ -166,7 +163,6 @@ def test_controlnet_switch(sd_device, sampler):
|
||||
name=model_name,
|
||||
sd_controlnet=True,
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=True,
|
||||
|
||||
@@ -21,7 +21,6 @@ def test_instruct_pix2pix(disable_nsfw, cpu_offload):
|
||||
model = ModelManager(
|
||||
name=model_name,
|
||||
device=torch.device(device),
|
||||
hf_access_token="",
|
||||
disable_nsfw=disable_nsfw,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=cpu_offload,
|
||||
@@ -52,7 +51,6 @@ def test_instruct_pix2pix_snow(disable_nsfw, cpu_offload):
|
||||
model = ModelManager(
|
||||
name=model_name,
|
||||
device=torch.device(device),
|
||||
hf_access_token="",
|
||||
disable_nsfw=disable_nsfw,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=cpu_offload,
|
||||
|
||||
@@ -17,11 +17,9 @@ def test_load_model():
|
||||
name=m,
|
||||
device="cpu",
|
||||
no_half=False,
|
||||
hf_access_token="",
|
||||
disable_nsfw=False,
|
||||
sd_cpu_textencoder=True,
|
||||
cpu_offload=True,
|
||||
enable_xformers=False,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -16,11 +16,9 @@ def test_model_switch():
|
||||
sd_controlnet=True,
|
||||
sd_controlnet_method="lllyasviel/control_v11p_sd15_canny",
|
||||
device=torch.device("mps"),
|
||||
hf_access_token="",
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=True,
|
||||
cpu_offload=False,
|
||||
enable_xformers=False,
|
||||
callback=None,
|
||||
)
|
||||
|
||||
@@ -34,11 +32,9 @@ def test_controlnet_switch_onoff(caplog):
|
||||
sd_controlnet=True,
|
||||
sd_controlnet_method="lllyasviel/control_v11p_sd15_canny",
|
||||
device=torch.device("mps"),
|
||||
hf_access_token="",
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=True,
|
||||
cpu_offload=False,
|
||||
enable_xformers=False,
|
||||
callback=None,
|
||||
)
|
||||
|
||||
@@ -61,11 +57,9 @@ def test_controlnet_switch_method(caplog):
|
||||
sd_controlnet=True,
|
||||
sd_controlnet_method=old_method,
|
||||
device=torch.device("mps"),
|
||||
hf_access_token="",
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=True,
|
||||
cpu_offload=False,
|
||||
enable_xformers=False,
|
||||
callback=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -41,7 +41,6 @@ def test_outpainting(name, sd_device, rect):
|
||||
model = ModelManager(
|
||||
name=name,
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
callback=callback,
|
||||
@@ -86,7 +85,6 @@ def test_kandinsky_outpainting(name, sd_device, rect):
|
||||
model = ModelManager(
|
||||
name=name,
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
callback=callback,
|
||||
|
||||
@@ -38,7 +38,6 @@ def test_runway_sd_1_5_all_samplers(
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-inpainting",
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
@@ -69,7 +68,6 @@ def test_runway_sd_lcm_lora(sd_device, strategy, sampler):
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-inpainting",
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
@@ -102,7 +100,6 @@ def test_runway_sd_freeu(sd_device, strategy, sampler):
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-inpainting",
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
@@ -136,7 +133,6 @@ def test_runway_sd_sd_strength(sd_device, strategy, sampler):
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-inpainting",
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
@@ -165,7 +161,6 @@ def test_runway_norm_sd_model(sd_device, strategy, sampler):
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-v1-5",
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
@@ -192,7 +187,6 @@ def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-inpainting",
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=True,
|
||||
@@ -229,7 +223,6 @@ def test_local_file_path(sd_device, sampler, name):
|
||||
model = ModelManager(
|
||||
name=name,
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=False,
|
||||
|
||||
@@ -29,7 +29,6 @@ def test_sdxl(sd_device, strategy, sampler):
|
||||
model = ModelManager(
|
||||
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
callback=callback,
|
||||
@@ -70,7 +69,6 @@ def test_sdxl_lcm_lora_and_freeu(sd_device, strategy, sampler):
|
||||
model = ModelManager(
|
||||
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
callback=callback,
|
||||
@@ -131,7 +129,6 @@ def test_sdxl_outpainting(sd_device, rect):
|
||||
model = ModelManager(
|
||||
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
|
||||
@@ -24,7 +24,6 @@ def save_config(
|
||||
cpu_offload,
|
||||
disable_nsfw,
|
||||
sd_cpu_textencoder,
|
||||
enable_xformers,
|
||||
local_files_only,
|
||||
model_dir,
|
||||
input,
|
||||
@@ -102,9 +101,6 @@ def main(config_file: str):
|
||||
|
||||
with gr.Column():
|
||||
gui = gr.Checkbox(init_config.gui, label=f"{GUI_HELP}")
|
||||
no_gui_auto_close = gr.Checkbox(
|
||||
init_config.no_gui_auto_close, label=f"{NO_GUI_AUTO_CLOSE_HELP}"
|
||||
)
|
||||
|
||||
with gr.Column():
|
||||
model_dir = gr.Textbox(
|
||||
@@ -193,14 +189,11 @@ def main(config_file: str):
|
||||
init_config.cpu_offload, label=f"{CPU_OFFLOAD_HELP}"
|
||||
)
|
||||
sd_cpu_textencoder = gr.Checkbox(
|
||||
init_config.sd_cpu_textencoder, label=f"{SD_CPU_TEXTENCODER_HELP}"
|
||||
init_config.sd_cpu_textencoder, label=f"{CPU_TEXTENCODER_HELP}"
|
||||
)
|
||||
disable_nsfw = gr.Checkbox(
|
||||
init_config.disable_nsfw, label=f"{DISABLE_NSFW_HELP}"
|
||||
)
|
||||
enable_xformers = gr.Checkbox(
|
||||
init_config.enable_xformers, label=f"{ENABLE_XFORMERS_HELP}"
|
||||
)
|
||||
local_files_only = gr.Checkbox(
|
||||
init_config.local_files_only, label=f"{LOCAL_FILES_ONLY_HELP}"
|
||||
)
|
||||
@@ -221,7 +214,6 @@ def main(config_file: str):
|
||||
cpu_offload,
|
||||
disable_nsfw,
|
||||
sd_cpu_textencoder,
|
||||
enable_xformers,
|
||||
local_files_only,
|
||||
model_dir,
|
||||
input,
|
||||
|
||||
Reference in New Issue
Block a user