wip
@@ -4,16 +4,14 @@ from enum import Enum
from pydantic import BaseModel


MPS_SUPPORT_MODELS = [
    "instruct_pix2pix",
    "sd1.5",
    "anything4",
    "realisticVision1.4",
    "sd2",
    "paint_by_example",
    "controlnet",
    "kandinsky2.2",
    "sdxl",
MPS_UNSUPPORT_MODELS = [
    "lama",
    "ldm",
    "zits",
    "mat",
    "fcf",
    "cv2",
    "manga",
]

DEFAULT_MODEL = "lama"
@@ -36,18 +34,13 @@ AVAILABLE_MODELS = [
    "sdxl",
]
SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"]
MODELS_SUPPORT_FREEU = SD15_MODELS + ["sd2", "sdxl", "instruct_pix2pix"]
MODELS_SUPPORT_LCM_LORA = SD15_MODELS + ["sdxl"]

FREEU_DEFAULT_CONFIGS = {
    "sd2": dict(s1=0.9, s2=0.2, b1=1.1, b2=1.2),
    "sdxl": dict(s1=0.6, s2=0.4, b1=1.1, b2=1.2),
    "sd1.5": dict(s1=0.9, s2=0.2, b1=1.2, b2=1.4),
    "anything4": dict(s1=0.9, s2=0.2, b1=1.2, b2=1.4),
    "realisticVision1.4": dict(s1=0.9, s2=0.2, b1=1.2, b2=1.4),
    "instruct_pix2pix": dict(s1=0.9, s2=0.2, b1=1.2, b2=1.4),
}

DIFFUSERS_MODEL_FP16_REVERSION = [
    "runwayml/stable-diffusion-inpainting",
    "Sanster/anything-4.0-inpainting",
    "Sanster/Realistic_Vision_V1.4-inpainting",
    "stabilityai/stable-diffusion-2-inpainting",
    "timbrooks/instruct-pix2pix",
]

AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
DEFAULT_DEVICE = "cuda"
@@ -70,14 +63,29 @@ Run Stable Diffusion text encoder model on CPU to save GPU memory.
"""

SD_CONTROLNET_HELP = """
Run Stable Diffusion inpainting model with ControlNet. You can switch control method in webui.
Run Stable Diffusion normal or inpainting model with ControlNet.
"""
DEFAULT_CONTROLNET_METHOD = "control_v11p_sd15_canny"
DEFAULT_SD_CONTROLNET_METHOD = "thibaud/controlnet-sd21-openpose-diffusers"
SD_CONTROLNET_CHOICES = [
    "control_v11p_sd15_canny",
    "control_v11p_sd15_openpose",
    "control_v11p_sd15_inpaint",
    "control_v11f1p_sd15_depth",
    "lllyasviel/control_v11p_sd15_canny",
    # "lllyasviel/control_v11p_sd15_seg",
    "lllyasviel/control_v11p_sd15_openpose",
    "lllyasviel/control_v11p_sd15_inpaint",
    "lllyasviel/control_v11f1p_sd15_depth",
]

DEFAULT_SD2_CONTROLNET_METHOD = "thibaud/controlnet-sd21-canny-diffusers"
SD2_CONTROLNET_CHOICES = [
    "thibaud/controlnet-sd21-canny-diffusers",
    "thibaud/controlnet-sd21-depth-diffusers",
    "thibaud/controlnet-sd21-openpose-diffusers",
]

DEFAULT_SDXL_CONTROLNET_METHOD = "diffusers/controlnet-canny-sdxl-1.0"
SDXL_CONTROLNET_CHOICES = [
    "thibaud/controlnet-openpose-sdxl-1.0",
    "diffusers/controlnet-canny-sdxl-1.0",
    "diffusers/controlnet-depth-sdxl-1.0",
]

SD_LOCAL_MODEL_HELP = """
@@ -152,7 +160,7 @@ class Config(BaseModel):
    model: str = DEFAULT_MODEL
    sd_local_model_path: str = None
    sd_controlnet: bool = False
    sd_controlnet_method: str = DEFAULT_CONTROLNET_METHOD
    sd_controlnet_method: str = DEFAULT_SD_CONTROLNET_METHOD
    device: str = DEFAULT_DEVICE
    gui: bool = False
    no_gui_auto_close: bool = False
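FREEU_DEFAULT_CONFIGS above stores the per-model s1/s2/b1/b2 scaling factors that FreeU applies to the UNet's backbone and skip features. A minimal sketch of how such a config is applied to a diffusers pipeline (the model id is only an illustration):

    from diffusers import StableDiffusionInpaintPipeline

    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting"  # example id from the list above
    )
    pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)  # the "sd1.5" entry above
    # ... run inference ...
    pipe.disable_freeu()  # restore the default UNet behavior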
@@ -1,41 +0,0 @@
import json
from pathlib import Path
from typing import Dict, List


def folder_name_to_show_name(name: str) -> str:
    return name.replace("models--", "").replace("--", "/")


def _scan_models(cache_dir, class_name: List[str]) -> List[str]:
    cache_dir = Path(cache_dir)
    res = []
    for it in cache_dir.glob("**/*/model_index.json"):
        with open(it, "r", encoding="utf-8") as f:
            data = json.load(f)
            if data["_class_name"] in class_name:
                name = folder_name_to_show_name(it.parent.parent.parent.name)
                if name not in res:
                    res.append(name)
    return res


def scan_models(cache_dir) -> Dict[str, List[str]]:
    return {
        "sd": _scan_models(cache_dir, ["StableDiffusionPipeline"]),
        "sd_inpaint": _scan_models(
            cache_dir,
            [
                "StableDiffusionInpaintPipeline",
                "StableDiffusionXLInpaintPipeline",
                "KandinskyV22InpaintPipeline",
            ],
        ),
        "other": _scan_models(
            cache_dir,
            [
                "StableDiffusionInstructPix2PixPipeline",
                "PaintByExamplePipeline",
            ],
        ),
    }
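The cache layout that _scan_models relies on names each repository models--<org>--<repo>; folder_name_to_show_name undoes that mangling:

    folder_name_to_show_name("models--runwayml--stable-diffusion-inpainting")
    # -> "runwayml/stable-diffusion-inpainting"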
@@ -1,8 +1,20 @@
import json
import os
from typing import List

from loguru import logger
from pathlib import Path

from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
from lama_cleaner.schema import (
    ModelInfo,
    ModelType,
    DIFFUSERS_SD_INPAINT_CLASS_NAME,
    DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
    DIFFUSERS_SD_CLASS_NAME,
    DIFFUSERS_SDXL_CLASS_NAME,
)


def cli_download_model(model: str, model_dir: str):
    if os.path.isfile(model_dir):
@@ -14,7 +26,7 @@ def cli_download_model(model: str, model_dir: str):

    os.environ["XDG_CACHE_HOME"] = model_dir

    from lama_cleaner.model_manager import models
    from lama_cleaner.model import models

    if model in models:
        logger.info(f"Downloading {model}...")
@@ -22,3 +34,127 @@ def cli_download_model(model: str, model_dir: str):
        logger.info(f"Done.")
    else:
        logger.info(f"Downloading model from Huggingface: {model}")
        from diffusers import DiffusionPipeline

        downloaded_path = DiffusionPipeline.download(
            pretrained_model_name=model,
            revision="fp16" if model in DIFFUSERS_MODEL_FP16_REVERSION else "main",
            resume_download=True,
        )
        logger.info(f"Done. Downloaded to {downloaded_path}")


def folder_name_to_show_name(name: str) -> str:
    return name.replace("models--", "").replace("--", "/")


def scan_diffusers_models(
    cache_dir, class_name: List[str], model_type: ModelType
) -> List[ModelInfo]:
    cache_dir = Path(cache_dir)
    res = []
    for it in cache_dir.glob("**/*/model_index.json"):
        with open(it, "r", encoding="utf-8") as f:
            data = json.load(f)
            if data["_class_name"] in class_name:
                name = folder_name_to_show_name(it.parent.parent.parent.name)
                if name not in res:
                    res.append(
                        ModelInfo(
                            name=name,
                            path=name,
                            model_type=model_type,
                        )
                    )
    return res


def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
    cache_dir = Path(cache_dir)
    res = []
    for it in cache_dir.glob(f"*.*"):
        if it.suffix not in [".safetensors", ".ckpt"]:
            continue
        if "inpaint" in str(it).lower():
            if "sdxl" in str(it).lower():
                model_type = ModelType.DIFFUSERS_SDXL_INPAINT
            else:
                model_type = ModelType.DIFFUSERS_SD_INPAINT
        else:
            if "sdxl" in str(it).lower():
                model_type = ModelType.DIFFUSERS_SDXL
            else:
                model_type = ModelType.DIFFUSERS_SD
        res.append(
            ModelInfo(
                name=it.name,
                path=str(it.absolute()),
                model_type=model_type,
                is_single_file_diffusers=True,
            )
        )
    return res


def scan_inpaint_models() -> List[ModelInfo]:
    res = []
    from lama_cleaner.model import models

    for name, m in models.items():
        if m.is_erase_model:
            res.append(
                ModelInfo(
                    name=name,
                    path=name,
                    model_type=ModelType.INPAINT,
                )
            )
    return res


def scan_models() -> List[ModelInfo]:
    from diffusers.utils import DIFFUSERS_CACHE

    available_models = []
    available_models.extend(scan_inpaint_models())
    available_models.extend(
        scan_single_file_diffusion_models(os.environ["XDG_CACHE_HOME"])
    )

    cache_dir = Path(DIFFUSERS_CACHE)
    diffusers_model_names = []
    for it in cache_dir.glob("**/*/model_index.json"):
        with open(it, "r", encoding="utf-8") as f:
            data = json.load(f)
            _class_name = data["_class_name"]
            name = folder_name_to_show_name(it.parent.parent.parent.name)
            if name in diffusers_model_names:
                continue

            if _class_name == DIFFUSERS_SD_CLASS_NAME:
                model_type = ModelType.DIFFUSERS_SD
            elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
                model_type = ModelType.DIFFUSERS_SD_INPAINT
            elif _class_name == DIFFUSERS_SDXL_CLASS_NAME:
                model_type = ModelType.DIFFUSERS_SDXL
            elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME:
                model_type = ModelType.DIFFUSERS_SDXL_INPAINT
            elif _class_name in [
                "StableDiffusionInstructPix2PixPipeline",
                "PaintByExamplePipeline",
                "KandinskyV22InpaintPipeline",
            ]:
                model_type = ModelType.DIFFUSERS_OTHER
            else:
                continue

            diffusers_model_names.append(name)
            available_models.append(
                ModelInfo(
                    name=name,
                    path=name,
                    model_type=model_type,
                )
            )

    return available_models
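Taken together, scan_models() merges three sources: the erase-model registry, single-file checkpoints under XDG_CACHE_HOME, and diffusers repos in the hub cache. A usage sketch (output illustrative; XDG_CACHE_HOME must be set, as parse_args does):

    from lama_cleaner.download import scan_models

    for info in scan_models():
        print(info.name, info.model_type.value)
    # lama inpaint
    # runwayml/stable-diffusion-inpainting diffusers_sd_inpaint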
@@ -7,6 +7,7 @@ import time
from io import BytesIO
from pathlib import Path
import numpy as np

# from watchdog.events import FileSystemEventHandler
# from watchdog.observers import Observer

@@ -149,6 +150,7 @@ class FileManager:
    def get_thumbnail(
        self, directory: Path, original_filename: str, width, height, **options
    ):
        directory = Path(directory)
        storage = FilesystemStorageBackend(self.app)
        crop = options.get("crop", "fit")
        background = options.get("background")
@@ -167,6 +169,7 @@ class FileManager:
        thumbnail_size = (width, height)

        thumbnail_filename = generate_filename(
            directory,
            original_filename,
            aspect_to_string(thumbnail_size),
            crop,
@@ -1,19 +1,17 @@
# Copy from: https://github.com/silentsokolov/flask-thumbnails/blob/master/flask_thumbnails/utils.py
import importlib
import os
import hashlib
from pathlib import Path

from typing import Union


def generate_filename(original_filename, *options):
    name, ext = os.path.splitext(original_filename)
def generate_filename(directory: Path, original_filename, *options) -> str:
    text = str(directory.absolute()) + original_filename
    for v in options:
        if v:
            name += "_%s" % v
    name += ext

    return name
            text += "%s" % v
    md5_hash = hashlib.md5()
    md5_hash.update(text.encode("utf-8"))
    return md5_hash.hexdigest() + ".jpg"


def parse_size(size):
@@ -48,7 +46,7 @@ def aspect_to_string(size):
    return "x".join(map(str, size))


IMG_SUFFIX = {'.jpg', '.jpeg', '.png', '.JPG', '.JPEG', '.PNG'}
IMG_SUFFIX = {".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"}


def glob_img(p: Union[Path, str], recursive: bool = False):
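The new generate_filename derives the thumbnail name from an md5 over the absolute directory, the original filename, and every truthy option, so identically named files in different folders no longer collide. A sketch using the function defined above:

    from pathlib import Path

    generate_filename(Path("/tmp/photos"), "cat.png", "128x128", "fit")
    # -> md5 of "/tmp/photos" + "cat.png" + "128x128" + "fit", as hex, plus ".jpg"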
@@ -8,7 +8,7 @@ import cv2
from PIL import Image, ImageOps, PngImagePlugin
import numpy as np
import torch
from lama_cleaner.const import MPS_SUPPORT_MODELS
from lama_cleaner.const import MPS_UNSUPPORT_MODELS
from loguru import logger
from torch.hub import download_url_to_file, get_dir
import hashlib
@@ -23,7 +23,7 @@ def md5sum(filename):


def switch_mps_device(model_name, device):
    if model_name not in MPS_SUPPORT_MODELS and str(device) == "mps":
    if model_name in MPS_UNSUPPORT_MODELS and str(device) == "mps":
        logger.info(f"{model_name} not support mps, switch to cpu")
        return torch.device("cpu")
    return device
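The membership test is inverted: models are now assumed to work on mps unless they appear in MPS_UNSUPPORT_MODELS. A sketch of the resulting behavior:

    import torch

    switch_mps_device("lama", torch.device("mps"))   # -> device("cpu"), lama is in the block list
    switch_mps_device("sdxl", torch.device("mps"))   # -> device("mps"), unchanged
    switch_mps_device("lama", torch.device("cuda"))  # -> device("cuda"), only mps is checked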
@@ -0,0 +1,33 @@
from .controlnet import ControlNet
from .fcf import FcF
from .instruct_pix2pix import InstructPix2Pix
from .kandinsky import Kandinsky22
from .lama import LaMa
from .ldm import LDM
from .manga import Manga
from .mat import MAT
from .mi_gan import MIGAN
from .opencv2 import OpenCV2
from .paint_by_example import PaintByExample
from .sd import SD15, SD2, Anything4, RealisticVision14, SD
from .sdxl import SDXL
from .zits import ZITS

models = {
    LaMa.name: LaMa,
    LDM.name: LDM,
    ZITS.name: ZITS,
    MAT.name: MAT,
    FcF.name: FcF,
    OpenCV2.name: OpenCV2,
    Manga.name: Manga,
    MIGAN.name: MIGAN,
    SD15.name: SD15,
    Anything4.name: Anything4,
    RealisticVision14.name: RealisticVision14,
    SD2.name: SD2,
    PaintByExample.name: PaintByExample,
    InstructPix2Pix.name: InstructPix2Pix,
    Kandinsky22.name: Kandinsky22,
    SDXL.name: SDXL,
}
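Because the registry is keyed by each class's own name attribute, a model is resolved and constructed with a single lookup; a minimal sketch:

    import torch
    from lama_cleaner.model import models

    model_cls = models["lama"]  # LaMa, since LaMa.name == "lama"
    inpaint_model = model_cls(torch.device("cpu"))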
@@ -12,7 +12,7 @@ from lama_cleaner.helper import (
    pad_img_to_modulo,
    switch_mps_device,
)
from lama_cleaner.model.g_diffuser_bot import expand_image, np_img_grey_to_rgb
from lama_cleaner.model.helper.g_diffuser_bot import expand_image
from lama_cleaner.model.utils import get_scheduler
from lama_cleaner.schema import Config, HDStrategy, SDSampler

@@ -22,6 +22,7 @@ class InpaintModel:
    min_size: Optional[int] = None
    pad_mod = 8
    pad_to_square = False
    is_erase_model = False

    def __init__(self, device, **kwargs):
        """
@@ -264,6 +265,12 @@ class InpaintModel:


class DiffusionInpaintModel(InpaintModel):
    def __init__(self, device, **kwargs):
        if kwargs.get("model_id_or_path"):
            # for custom diffusers models
            self.model_id_or_path = kwargs["model_id_or_path"]
        super().__init__(device, **kwargs)

    @torch.no_grad()
    def __call__(self, image, mask, config: Config):
        """
@@ -1,5 +1,3 @@
import gc

import PIL.Image
import cv2
import numpy as np
@@ -7,107 +5,26 @@ import torch
from diffusers import ControlNetModel
from loguru import logger

from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import torch_gc, get_scheduler
from lama_cleaner.schema import Config
from lama_cleaner.model.helper.controlnet_preprocess import (
    make_canny_control_image,
    make_openpose_control_image,
    make_depth_control_image,
    make_inpaint_control_image,
)
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
from lama_cleaner.model.utils import get_scheduler
from lama_cleaner.schema import Config, ModelInfo, ModelType


class CPUTextEncoderWrapper(torch.nn.Module):
    def __init__(self, text_encoder, torch_dtype):
        super().__init__()
        self.config = text_encoder.config
        self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
        self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
        self.torch_dtype = torch_dtype
        del text_encoder
        torch_gc()

    def __call__(self, x, **kwargs):
        input_device = x.device
        return [
            self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0]
            .to(input_device)
            .to(self.torch_dtype)
        ]

    @property
    def dtype(self):
        return self.torch_dtype


NAMES_MAP = {
    "sd1.5": "runwayml/stable-diffusion-inpainting",
    "anything4": "Sanster/anything-4.0-inpainting",
    "realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting",
# for backward compatibility
controlnet_name_map = {
    "control_v11p_sd15_canny": "lllyasviel/control_v11p_sd15_canny",
    "control_v11p_sd15_openpose": "lllyasviel/control_v11p_sd15_openpose",
    "control_v11p_sd15_inpaint": "lllyasviel/control_v11p_sd15_inpaint",
    "control_v11f1p_sd15_depth": "lllyasviel/control_v11f1p_sd15_depth",
}

NATIVE_NAMES_MAP = {
    "sd1.5": "runwayml/stable-diffusion-v1-5",
    "anything4": "andite/anything-v4.0",
    "realisticVision1.4": "SG161222/Realistic_Vision_V1.4",
}


def make_inpaint_condition(image, image_mask):
    """
    image: [H, W, C] RGB
    mask: [H, W, 1] 255 means area to repaint
    """
    image = image.astype(np.float32) / 255.0
    image[image_mask[:, :, -1] > 128] = -1.0  # set as masked pixel
    image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return image


def load_from_local_model(
    local_model_path, torch_dtype, controlnet, pipe_class, is_native_control_inpaint
):
    from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
        download_from_original_stable_diffusion_ckpt,
    )

    logger.info(f"Converting {local_model_path} to diffusers controlnet pipeline")

    try:
        pipe = download_from_original_stable_diffusion_ckpt(
            local_model_path,
            num_in_channels=4 if is_native_control_inpaint else 9,
            from_safetensors=local_model_path.endswith("safetensors"),
            device="cpu",
            load_safety_checker=False,
        )
    except Exception as e:
        err_msg = str(e)
        logger.exception(e)
        if is_native_control_inpaint and "[320, 9, 3, 3]" in err_msg:
            logger.error(
                "control_v11p_sd15_inpaint method requires normal SD model, not inpainting SD model"
            )
        if not is_native_control_inpaint and "[320, 4, 3, 3]" in err_msg:
            logger.error(
                f"{controlnet.config['_name_or_path']} method requires inpainting SD model, "
                f"you can convert any SD model to inpainting model in AUTO1111: \n"
                f"https://www.reddit.com/r/StableDiffusion/comments/zyi24j/how_to_turn_any_model_into_an_inpainting_model/"
            )
        exit(-1)

    inpaint_pipe = pipe_class(
        vae=pipe.vae,
        text_encoder=pipe.text_encoder,
        tokenizer=pipe.tokenizer,
        unet=pipe.unet,
        controlnet=controlnet,
        scheduler=pipe.scheduler,
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker=False,
    )

    del pipe
    gc.collect()
    return inpaint_pipe.to(torch_dtype=torch_dtype)


class ControlNet(DiffusionInpaintModel):
    name = "controlnet"
@@ -116,10 +33,16 @@ class ControlNet(DiffusionInpaintModel):

    def init_model(self, device: torch.device, **kwargs):
        fp16 = not kwargs.get("no_half", False)
        model_info: ModelInfo = kwargs["model_info"]
        sd_controlnet_method = kwargs["sd_controlnet_method"]
        sd_controlnet_method = controlnet_name_map.get(
            sd_controlnet_method, sd_controlnet_method
        )

        model_kwargs = {
            "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"])
        }
        self.model_info = model_info
        self.sd_controlnet_method = sd_controlnet_method

        model_kwargs = {}
        if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
            logger.info("Disable Stable Diffusion Model NSFW checker")
            model_kwargs.update(
@@ -133,41 +56,39 @@ class ControlNet(DiffusionInpaintModel):
        use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
        torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32

        sd_controlnet_method = kwargs["sd_controlnet_method"]
        self.sd_controlnet_method = sd_controlnet_method

        if sd_controlnet_method == "control_v11p_sd15_inpaint":
            from diffusers import StableDiffusionControlNetPipeline as PipeClass

            self.is_native_control_inpaint = True
        else:
            from .pipeline import StableDiffusionControlNetInpaintPipeline as PipeClass

            self.is_native_control_inpaint = False

        if self.is_native_control_inpaint:
            model_id = NATIVE_NAMES_MAP[kwargs["name"]]
        else:
            model_id = NAMES_MAP[kwargs["name"]]
        if model_info.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SD_INPAINT,
        ]:
            from diffusers import (
                StableDiffusionControlNetInpaintPipeline as PipeClass,
            )
        elif model_info.model_type in [
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ]:
            from diffusers import (
                StableDiffusionXLControlNetInpaintPipeline as PipeClass,
            )

        controlnet = ControlNetModel.from_pretrained(
            f"lllyasviel/{sd_controlnet_method}", torch_dtype=torch_dtype
            sd_controlnet_method, torch_dtype=torch_dtype
        )
        self.is_local_sd_model = False
        if kwargs.get("sd_local_model_path", None):
            self.is_local_sd_model = True
            self.model = load_from_local_model(
                kwargs["sd_local_model_path"],
                torch_dtype=torch_dtype,
                controlnet=controlnet,
                pipe_class=PipeClass,
                is_native_control_inpaint=self.is_native_control_inpaint,
            )
        if model_info.is_single_file_diffusers:
            self.model = PipeClass.from_single_file(
                model_info.path, controlnet=controlnet
            ).to(torch_dtype)
        else:
            self.model = PipeClass.from_pretrained(
                model_id,
                model_info.path,
                controlnet=controlnet,
                revision="fp16" if use_gpu and fp16 else "main",
                revision="fp16"
                if (
                    model_info.path in DIFFUSERS_MODEL_FP16_REVERSION
                    and use_gpu
                    and fp16
                )
                else "main",
                torch_dtype=torch_dtype,
                **model_kwargs,
            )
@@ -191,6 +112,19 @@ class ControlNet(DiffusionInpaintModel):

        self.callback = kwargs.pop("callback", None)

    def _get_control_image(self, image, mask):
        if "canny" in self.sd_controlnet_method:
            control_image = make_canny_control_image(image)
        elif "openpose" in self.sd_controlnet_method:
            control_image = make_openpose_control_image(image)
        elif "depth" in self.sd_controlnet_method:
            control_image = make_depth_control_image(image)
        elif "inpaint" in self.sd_controlnet_method:
            control_image = make_inpaint_control_image(image, mask)
        else:
            raise NotImplementedError(f"{self.sd_controlnet_method} not implemented")
        return control_image

    def forward(self, image, mask, config: Config):
        """Input image and output image have same size
        image: [H, W, C] RGB
@@ -206,84 +140,30 @@ class ControlNet(DiffusionInpaintModel):
            mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]

        img_h, img_w = image.shape[:2]
        control_image = self._get_control_image(image, mask)
        mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
        image = PIL.Image.fromarray(image)

        if self.is_native_control_inpaint:
            control_image = make_inpaint_condition(image, mask)
            output = self.model(
                prompt=config.prompt,
                image=control_image,
                height=img_h,
                width=img_w,
                num_inference_steps=config.sd_steps,
                guidance_scale=config.sd_guidance_scale,
                controlnet_conditioning_scale=config.controlnet_conditioning_scale,
                negative_prompt=config.negative_prompt,
                generator=torch.manual_seed(config.sd_seed),
                output_type="np",
                callback=self.callback,
            ).images[0]
        else:
            if "canny" in self.sd_controlnet_method:
                canny_image = cv2.Canny(image, 100, 200)
                canny_image = canny_image[:, :, None]
                canny_image = np.concatenate(
                    [canny_image, canny_image, canny_image], axis=2
                )
                canny_image = PIL.Image.fromarray(canny_image)
                control_image = canny_image
            elif "openpose" in self.sd_controlnet_method:
                from controlnet_aux import OpenposeDetector

                processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
                control_image = processor(image, hand_and_face=True)
            elif "depth" in self.sd_controlnet_method:
                from transformers import pipeline

                depth_estimator = pipeline("depth-estimation")
                depth_image = depth_estimator(PIL.Image.fromarray(image))["depth"]
                depth_image = np.array(depth_image)
                depth_image = depth_image[:, :, None]
                depth_image = np.concatenate(
                    [depth_image, depth_image, depth_image], axis=2
                )
                control_image = PIL.Image.fromarray(depth_image)
            else:
                raise NotImplementedError(
                    f"{self.sd_controlnet_method} not implemented"
                )

            mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
            image = PIL.Image.fromarray(image)

            output = self.model(
                image=image,
                control_image=control_image,
                prompt=config.prompt,
                negative_prompt=config.negative_prompt,
                mask_image=mask_image,
                num_inference_steps=config.sd_steps,
                guidance_scale=config.sd_guidance_scale,
                output_type="np",
                callback=self.callback,
                height=img_h,
                width=img_w,
                generator=torch.manual_seed(config.sd_seed),
                controlnet_conditioning_scale=config.controlnet_conditioning_scale,
            ).images[0]
        output = self.model(
            image=image,
            mask_image=mask_image,
            control_image=control_image,
            prompt=config.prompt,
            negative_prompt=config.negative_prompt,
            num_inference_steps=config.sd_steps,
            guidance_scale=config.sd_guidance_scale,
            output_type="np",
            callback=self.callback,
            height=img_h,
            width=img_w,
            generator=torch.manual_seed(config.sd_seed),
            controlnet_conditioning_scale=config.controlnet_conditioning_scale,
        ).images[0]

        output = (output * 255).round().astype("uint8")
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return output

    def forward_post_process(self, result, image, mask, config):
        if config.sd_match_histograms:
            result = self._match_histograms(result, image[:, :, ::-1], mask)

        if config.sd_mask_blur != 0:
            k = 2 * config.sd_mask_blur + 1
            mask = cv2.GaussianBlur(mask, (k, k), 0)
        return result, image, mask

    @staticmethod
    def is_downloaded() -> bool:
        # model will be downloaded when app start, and can't switch in frontend settings
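After this rewrite the ControlNet wrapper is configured by a ModelInfo plus a controlnet repo id (legacy short names are mapped through controlnet_name_map) instead of the old hard-coded NAMES_MAP tables. A construction sketch with illustrative values; the wrapper also consumes further CLI-derived kwargs not shown exhaustively here:

    import torch
    from lama_cleaner.model.controlnet import ControlNet
    from lama_cleaner.schema import ModelInfo, ModelType

    info = ModelInfo(
        name="runwayml/stable-diffusion-inpainting",
        path="runwayml/stable-diffusion-inpainting",
        model_type=ModelType.DIFFUSERS_SD_INPAINT,
    )
    model = ControlNet(
        torch.device("cpu"),
        model_info=info,
        sd_controlnet_method="control_v11p_sd15_canny",  # resolved to the lllyasviel/ repo id
        no_half=True,
        disable_nsfw=True,
    )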
@@ -1626,6 +1626,7 @@ class FcF(InpaintModel):
    min_size = 512
    pad_mod = 512
    pad_to_square = True
    is_erase_model = True

    def init_model(self, device, **kwargs):
        seed = 0
lama_cleaner/model/helper/controlnet_preprocess.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import torch
import PIL
import cv2
from PIL import Image
import numpy as np


def make_canny_control_image(image: np.ndarray) -> Image:
    canny_image = cv2.Canny(image, 100, 200)
    canny_image = canny_image[:, :, None]
    canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
    canny_image = PIL.Image.fromarray(canny_image)
    control_image = canny_image
    return control_image


def make_openpose_control_image(image: np.ndarray) -> Image:
    from controlnet_aux import OpenposeDetector

    processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
    control_image = processor(image, hand_and_face=True)
    return control_image


def make_depth_control_image(image: np.ndarray) -> Image:
    from transformers import pipeline

    depth_estimator = pipeline("depth-estimation")
    depth_image = depth_estimator(PIL.Image.fromarray(image))["depth"]
    depth_image = np.array(depth_image)
    depth_image = depth_image[:, :, None]
    depth_image = np.concatenate([depth_image, depth_image, depth_image], axis=2)
    control_image = PIL.Image.fromarray(depth_image)
    return control_image


def make_inpaint_control_image(image: np.ndarray, mask: np.ndarray) -> torch.Tensor:
    """
    image: [H, W, C] RGB
    mask: [H, W, 1] 255 means area to repaint
    """
    image = image.astype(np.float32) / 255.0
    image[mask[:, :, -1] > 128] = -1.0  # set as masked pixel
    image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return image
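Each helper maps an RGB np.ndarray to the conditioning input its ControlNet variant expects; for example the canny preprocessor, sketched on a placeholder image:

    import numpy as np
    from lama_cleaner.model.helper.controlnet_preprocess import make_canny_control_image

    rgb = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder image
    control = make_canny_control_image(rgb)  # PIL.Image with edges replicated to 3 channels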
lama_cleaner/model/helper/cpu_text_encoder.py (new file, 25 lines)
@@ -0,0 +1,25 @@
import torch
from lama_cleaner.model.utils import torch_gc


class CPUTextEncoderWrapper(torch.nn.Module):
    def __init__(self, text_encoder, torch_dtype):
        super().__init__()
        self.config = text_encoder.config
        self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
        self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
        self.torch_dtype = torch_dtype
        del text_encoder
        torch_gc()

    def __call__(self, x, **kwargs):
        input_device = x.device
        return [
            self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0]
            .to(input_device)
            .to(self.torch_dtype)
        ]

    @property
    def dtype(self):
        return self.torch_dtype
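The wrapper keeps the text encoder on the CPU in float32 while still reporting the pipeline's dtype, so prompt encoding stops holding GPU memory. A sketch of how a loaded pipeline's encoder might be swapped, mirroring how sd.py uses it:

    import torch

    # assuming `pipe` is an already-loaded StableDiffusionInpaintPipeline running in fp16
    pipe.text_encoder = CPUTextEncoderWrapper(pipe.text_encoder, torch.float16)
    # prompts are now encoded on CPU in fp32 and cast back to fp16 for the UNet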
@@ -17,7 +17,7 @@ class InstructPix2Pix(DiffusionInpaintModel):

        fp16 = not kwargs.get("no_half", False)

        model_kwargs = {"local_files_only": kwargs.get("local_files_only", False)}
        model_kwargs = {}
        if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
            logger.info("Disable Stable Diffusion Model NSFW checker")
            model_kwargs.update(
@@ -77,16 +77,6 @@ class InstructPix2Pix(DiffusionInpaintModel):
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return output

    #
    # def forward_post_process(self, result, image, mask, config):
    #     if config.sd_match_histograms:
    #         result = self._match_histograms(result, image[:, :, ::-1], mask)
    #
    #     if config.sd_mask_blur != 0:
    #         k = 2 * config.sd_mask_blur + 1
    #         mask = cv2.GaussianBlur(mask, (k, k), 0)
    #     return result, image, mask

    @staticmethod
    def is_downloaded() -> bool:
        # model will be downloaded when app start, and can't switch in frontend settings
@@ -20,7 +20,6 @@ class Kandinsky(DiffusionInpaintModel):
        torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32

        model_kwargs = {
            "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"]),
            "torch_dtype": torch_dtype,
        }

@@ -23,6 +23,7 @@ LAMA_MODEL_MD5 = os.environ.get("LAMA_MODEL_MD5", "e3aa4aaa15225a33ec84f9f4bc47e
class LaMa(InpaintModel):
    name = "lama"
    pad_mod = 8
    is_erase_model = True

    @staticmethod
    def download():

@@ -237,6 +237,7 @@ class LatentDiffusion(DDPM):
class LDM(InpaintModel):
    name = "ldm"
    pad_mod = 32
    is_erase_model = True

    def __init__(self, device, fp16: bool = True, **kwargs):
        self.fp16 = fp16

@@ -32,6 +32,7 @@ MANGA_LINE_MODEL_MD5 = os.environ.get(
class Manga(InpaintModel):
    name = "manga"
    pad_mod = 16
    is_erase_model = True

    def init_model(self, device, **kwargs):
        self.inpaintor_model = load_jit_model(

@@ -1880,6 +1880,7 @@ class MAT(InpaintModel):
    min_size = 512
    pad_mod = 512
    pad_to_square = True
    is_erase_model = True

    def init_model(self, device, **kwargs):
        seed = 240  # pick up a random number

@@ -26,6 +26,7 @@ class MIGAN(InpaintModel):
    min_size = 512
    pad_mod = 512
    pad_to_square = True
    is_erase_model = True

    def init_model(self, device, **kwargs):
        self.model = load_jit_model(MIGAN_MODEL_URL, device, MIGAN_MODEL_MD5).eval()

@@ -8,6 +8,7 @@ flag_map = {"INPAINT_NS": cv2.INPAINT_NS, "INPAINT_TELEA": cv2.INPAINT_TELEA}
class OpenCV2(InpaintModel):
    name = "cv2"
    pad_mod = 1
    is_erase_model = True

    @staticmethod
    def is_downloaded() -> bool:
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
from typing import Union, List, Optional, Callable, Dict, Any

# Copy from https://github.com/mikonvergence/ControlNetInpaint/blob/main/src/pipeline_stable_diffusion_controlnet_inpaint.py
@@ -217,6 +218,38 @@ class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    @classmethod
    def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
        from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
            download_from_original_stable_diffusion_ckpt,
        )

        controlnet = kwargs.pop("controlnet", None)

        pipe = download_from_original_stable_diffusion_ckpt(
            pretrained_model_link_or_path,
            num_in_channels=9,
            from_safetensors=pretrained_model_link_or_path.endswith("safetensors"),
            device="cpu",
            load_safety_checker=False,
        )

        inpaint_pipe = cls(
            vae=pipe.vae,
            text_encoder=pipe.text_encoder,
            tokenizer=pipe.tokenizer,
            unet=pipe.unet,
            controlnet=controlnet,
            scheduler=pipe.scheduler,
            safety_checker=None,
            feature_extractor=None,
            requires_safety_checker=False,
        )

        del pipe
        gc.collect()
        return inpaint_pipe

    def prepare_mask_latents(
        self,
        mask,
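from_single_file converts a raw .ckpt/.safetensors checkpoint with diffusers' convert_from_ckpt helper, forcing the 9-channel inpainting UNet, then rebuilds the pipeline around the caller's ControlNet. A hypothetical call (the local path is an example):

    import torch
    from diffusers import ControlNetModel

    controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
    pipe = StableDiffusionControlNetInpaintPipeline.from_single_file(
        "/models/my-inpainting-model.safetensors",  # hypothetical local checkpoint
        controlnet=controlnet,
    )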
@@ -1,4 +1,4 @@
import gc
import os

import PIL.Image
import cv2
@@ -6,34 +6,12 @@ import numpy as np
import torch
from loguru import logger

from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
from lama_cleaner.schema import Config


class CPUTextEncoderWrapper(torch.nn.Module):
    def __init__(self, text_encoder, torch_dtype):
        super().__init__()
        self.config = text_encoder.config
        self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
        self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
        self.torch_dtype = torch_dtype
        del text_encoder
        torch_gc()

    def __call__(self, x, **kwargs):
        input_device = x.device
        return [
            self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0]
            .to(input_device)
            .to(self.torch_dtype)
        ]

    @property
    def dtype(self):
        return self.torch_dtype


class SD(DiffusionInpaintModel):
    pad_mod = 8
    min_size = 512
@@ -44,9 +22,7 @@ class SD(DiffusionInpaintModel):

        fp16 = not kwargs.get("no_half", False)

        model_kwargs = {
            "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"])
        }
        model_kwargs = {}
        if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
            logger.info("Disable Stable Diffusion Model NSFW checker")
            model_kwargs.update(
@@ -60,14 +36,20 @@ class SD(DiffusionInpaintModel):
        use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
        torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32

        if kwargs.get("sd_local_model_path", None):
        if os.path.isfile(self.model_id_or_path):
            self.model = StableDiffusionInpaintPipeline.from_single_file(
                kwargs["sd_local_model_path"], torch_dtype=torch_dtype, **model_kwargs
                self.model_id_or_path, torch_dtype=torch_dtype, **model_kwargs
            )
        else:
            self.model = StableDiffusionInpaintPipeline.from_pretrained(
                self.model_id_or_path,
                revision="fp16" if use_gpu and fp16 else "main",
                revision="fp16"
                if (
                    self.model_id_or_path in DIFFUSERS_MODEL_FP16_REVERSION
                    and use_gpu
                    and fp16
                )
                else "main",
                torch_dtype=torch_dtype,
                use_auth_token=kwargs["hf_access_token"],
                **model_kwargs,
@@ -1,7 +1,10 @@
import os

import PIL.Image
import cv2
import numpy as np
import torch
from diffusers import AutoencoderKL
from loguru import logger

from lama_cleaner.model.base import DiffusionInpaintModel
@@ -13,26 +16,31 @@ class SDXL(DiffusionInpaintModel):
    pad_mod = 8
    min_size = 512
    lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
    model_id_or_path = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"

    def init_model(self, device: torch.device, **kwargs):
        from diffusers.pipelines import AutoPipelineForInpainting
        from diffusers.pipelines import StableDiffusionXLInpaintPipeline

        fp16 = not kwargs.get("no_half", False)

        model_kwargs = {
            "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"])
        }

        use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
        torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32

        self.model = AutoPipelineForInpainting.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
            revision="main",
            torch_dtype=torch_dtype,
            use_auth_token=kwargs["hf_access_token"],
            **model_kwargs,
        )
        if os.path.isfile(self.model_id_or_path):
            self.model = StableDiffusionXLInpaintPipeline.from_single_file(
                self.model_id_or_path, torch_dtype=torch_dtype
            )
        else:
            vae = AutoencoderKL.from_pretrained(
                "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
            )
            self.model = StableDiffusionXLInpaintPipeline.from_pretrained(
                self.model_id_or_path,
                revision="main",
                torch_dtype=torch_dtype,
                use_auth_token=kwargs["hf_access_token"],
                vae=vae,
            )

        # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
        self.model.enable_attention_slicing()
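The hub branch swaps in madebyollin/sdxl-vae-fp16-fix because the stock SDXL VAE is prone to numerical overflow in float16; the fixed VAE is a drop-in replacement. A condensed standalone sketch of the same branching:

    import os
    import torch
    from diffusers import AutoencoderKL, StableDiffusionXLInpaintPipeline

    model_id_or_path = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
    if os.path.isfile(model_id_or_path):  # local single-file checkpoint
        pipe = StableDiffusionXLInpaintPipeline.from_single_file(
            model_id_or_path, torch_dtype=torch.float16
        )
    else:
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
        )
        pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
            model_id_or_path, torch_dtype=torch.float16, vae=vae
        )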
@@ -226,6 +226,7 @@ class ZITS(InpaintModel):
    min_size = 256
    pad_mod = 32
    pad_to_square = True
    is_erase_model = True

    def __init__(self, device, **kwargs):
        """
@@ -1,49 +1,14 @@
import torch
import gc
from typing import List, Dict

import torch
from loguru import logger

from lama_cleaner.const import (
    SD15_MODELS,
    MODELS_SUPPORT_FREEU,
    MODELS_SUPPORT_LCM_LORA,
)
from lama_cleaner.download import scan_models
from lama_cleaner.helper import switch_mps_device
from lama_cleaner.model.controlnet import ControlNet
from lama_cleaner.model.fcf import FcF
from lama_cleaner.model.kandinsky import Kandinsky22
from lama_cleaner.model.lama import LaMa
from lama_cleaner.model.ldm import LDM
from lama_cleaner.model.manga import Manga
from lama_cleaner.model.mat import MAT
from lama_cleaner.model.mi_gan import MIGAN
from lama_cleaner.model.paint_by_example import PaintByExample
from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix
from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14
from lama_cleaner.model.sdxl import SDXL
from lama_cleaner.model import models, ControlNet, SD, SDXL
from lama_cleaner.model.utils import torch_gc
from lama_cleaner.model.zits import ZITS
from lama_cleaner.model.opencv2 import OpenCV2
from lama_cleaner.schema import Config

models = {
    "lama": LaMa,
    "ldm": LDM,
    "zits": ZITS,
    "mat": MAT,
    "fcf": FcF,
    SD15.name: SD15,
    Anything4.name: Anything4,
    RealisticVision14.name: RealisticVision14,
    "cv2": OpenCV2,
    "manga": Manga,
    "sd2": SD2,
    "paint_by_example": PaintByExample,
    "instruct_pix2pix": InstructPix2Pix,
    Kandinsky22.name: Kandinsky22,
    SDXL.name: SDXL,
    MIGAN.name: MIGAN,
}
from lama_cleaner.schema import Config, ModelInfo, ModelType


class ModelManager:
@@ -51,23 +16,39 @@ class ModelManager:
        self.name = name
        self.device = device
        self.kwargs = kwargs
        self.available_models: Dict[str, ModelInfo] = {}
        self.scan_models()
        self.model = self.init_model(name, device, **kwargs)

    def init_model(self, name: str, device, **kwargs):
        if name in SD15_MODELS and kwargs.get("sd_controlnet", False):
            return ControlNet(device, **{**kwargs, "name": name})
        for old_name, model_cls in models.items():
            if name == old_name and hasattr(model_cls, "model_id_or_path"):
                name = model_cls.model_id_or_path
        if name not in self.available_models:
            raise NotImplementedError(f"Unsupported model: {name}")

        if name in models:
            model = models[name](device, **kwargs)
        else:
            raise NotImplementedError(f"Not supported model: {name}")
        return model
        sd_controlnet_enabled = kwargs.get("sd_controlnet", False)
        model_info = self.available_models[name]
        if model_info.model_type in [ModelType.INPAINT, ModelType.DIFFUSERS_OTHER]:
            return models[name](device, **kwargs)

    def is_downloaded(self, name: str) -> bool:
        if name in models:
            return models[name].is_downloaded()
        if sd_controlnet_enabled:
            return ControlNet(device, **{**kwargs, "model_info": model_info})
        else:
            raise NotImplementedError(f"Not supported model: {name}")
            if model_info.model_type in [
                ModelType.DIFFUSERS_SD,
                ModelType.DIFFUSERS_SDXL,
            ]:
                raise NotImplementedError(
                    f"When using non inpaint Stable Diffusion model, you must enable controlnet"
                )
            if model_info.model_type == ModelType.DIFFUSERS_SD_INPAINT:
                return SD(device, model_id_or_path=model_info.path, **kwargs)

            if model_info.model_type == ModelType.DIFFUSERS_SDXL_INPAINT:
                return SDXL(device, model_id_or_path=model_info.path, **kwargs)

            raise NotImplementedError(f"Unsupported model: {name}")

    def __call__(self, image, mask, config: Config):
        self.switch_controlnet_method(control_method=config.controlnet_method)
@@ -75,9 +56,18 @@ class ModelManager:
        self.enable_disable_lcm_lora(config)
        return self.model(image, mask, config)

    def switch(self, new_name: str, **kwargs):
    def scan_models(self) -> List[ModelInfo]:
        available_models = scan_models()
        self.available_models = {it.name: it for it in available_models}
        return available_models

    def switch(self, new_name: str):
        if new_name == self.name:
            return

        old_name = self.name
        self.name = new_name

        try:
            if torch.cuda.memory_allocated() > 0:
                # Clear current loaded model from memory
@@ -88,8 +78,8 @@ class ModelManager:
            self.model = self.init_model(
                new_name, switch_mps_device(new_name, self.device), **self.kwargs
            )
            self.name = new_name
        except NotImplementedError as e:
        except Exception as e:
            self.name = old_name
            raise e

    def switch_controlnet_method(self, control_method: str):
@@ -97,27 +87,9 @@ class ModelManager:
            return
        if self.kwargs["sd_controlnet_method"] == control_method:
            return
        if not hasattr(self.model, "is_local_sd_model"):
            return

        if self.model.is_local_sd_model:
            # is_native_control_inpaint means a plain (non-inpainting) SD model was loaded
            if (
                self.model.is_native_control_inpaint
                and control_method != "control_v11p_sd15_inpaint"
            ):
                raise RuntimeError(
                    f"--sd-local-model-path load a normal SD model, "
                    f"to use {control_method} you should load an inpainting SD model"
                )
            elif (
                not self.model.is_native_control_inpaint
                and control_method == "control_v11p_sd15_inpaint"
            ):
                raise RuntimeError(
                    f"--sd-local-model-path load an inpainting SD model, "
                    f"to use {control_method} you should load a normal SD model"
                )
        if not self.available_models[self.name].support_controlnet():
            return

        del self.model
        torch_gc()
@@ -133,7 +105,7 @@ class ModelManager:
        if str(self.model.device) == "mps":
            return

        if self.name in MODELS_SUPPORT_FREEU:
        if self.available_models[self.name].support_freeu():
            if config.sd_freeu:
                freeu_config = config.sd_freeu_config
                self.model.model.enable_freeu(
@@ -146,7 +118,7 @@ class ModelManager:
            self.model.model.disable_freeu()

    def enable_disable_lcm_lora(self, config: Config):
        if self.name in MODELS_SUPPORT_LCM_LORA:
        if self.available_models[self.name].support_lcm_lora():
            if config.sd_lcm_lora:
                if not self.model.model.pipe.get_list_adapters():
                    self.model.model.load_lora_weights(self.model.lcm_lora_id)
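ModelManager now resolves every name through the scanned ModelInfo table: legacy registry names with a model_id_or_path attribute are translated to their diffusers repo id, plain (non-inpaint) SD/SDXL checkpoints are rejected unless ControlNet is enabled, and inpainting checkpoints dispatch to the SD or SDXL wrappers. A usage sketch:

    import torch
    from lama_cleaner.model_manager import ModelManager

    manager = ModelManager(name="lama", device=torch.device("cpu"))
    # e.g. triggered from the /switch_model route; the name must be in available_models
    manager.switch("runwayml/stable-diffusion-inpainting")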
@@ -6,7 +6,7 @@ from pathlib import Path
from loguru import logger

from lama_cleaner.const import *
from lama_cleaner.download import cli_download_model
from lama_cleaner.download import cli_download_model, scan_models
from lama_cleaner.runtime import dump_environment_info

DOWNLOAD_SUBCOMMAND = "download"
@@ -46,7 +46,11 @@ def parse_args():
        "--installer-config", default=None, help="Config file for windows installer"
    )

    parser.add_argument("--model", default=DEFAULT_MODEL, choices=AVAILABLE_MODELS)
    parser.add_argument(
        "--model",
        default=DEFAULT_MODEL,
        help=f"Available models: [{', '.join(AVAILABLE_MODELS)}], or model id on huggingface",
    )
    parser.add_argument("--no-half", action="store_true", help=NO_HALF_HELP)
    parser.add_argument("--cpu-offload", action="store_true", help=CPU_OFFLOAD_HELP)
    parser.add_argument("--disable-nsfw", action="store_true", help=DISABLE_NSFW_HELP)
@@ -56,10 +60,9 @@ def parse_args():
    parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP)
    parser.add_argument(
        "--sd-controlnet-method",
        default=DEFAULT_CONTROLNET_METHOD,
        default=DEFAULT_SD_CONTROLNET_METHOD,
        choices=SD_CONTROLNET_CHOICES,
    )
    parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP)
    parser.add_argument(
        "--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP
    )
@@ -170,7 +173,8 @@ def parse_args():
    )
    #########

    # useless args
    ### useless args ###
    parser.add_argument("--sd-local-model-path", default=None, help=argparse.SUPPRESS)
    parser.add_argument("--debug", action="store_true", help=argparse.SUPPRESS)
    parser.add_argument("--hf_access_token", default="", help=argparse.SUPPRESS)
    parser.add_argument(
@@ -180,6 +184,7 @@ def parse_args():
    parser.add_argument(
        "--sd-enable-xformers", action="store_true", help=argparse.SUPPRESS
    )
    ### end useless args ###

    args = parser.parse_args()
    # collect system info to help debug
@@ -251,6 +256,17 @@ def parse_args():
        os.environ["XDG_CACHE_HOME"] = args.model_dir
        os.environ["U2NET_HOME"] = args.model_dir

    if args.sd_run_local or args.local_files_only:
        os.environ["TRANSFORMERS_OFFLINE"] = "1"
        os.environ["HF_HUB_OFFLINE"] = "1"

    if args.model not in AVAILABLE_MODELS:
        scanned_models = scan_models()
        if args.model not in [it.name for it in scanned_models]:
            parser.error(
                f"invalid --model: {args.model} not exists. Available models: {AVAILABLE_MODELS} or {scanned_models}"
            )

    if args.input and args.input is not None:
        if not os.path.exists(args.input):
            parser.error(f"invalid --input: {args.input} not exists")
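--model now accepts either a built-in name or a Hugging Face model id, and unknown values are validated against scan_models() instead of an argparse choices list. Hypothetical invocations:

    lama-cleaner --model lama --device cpu --port 8080
    lama-cleaner --model Sanster/anything-4.0-inpainting --sd-controlnet --sd-controlnet-method lllyasviel/control_v11p_sd15_canny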
@@ -4,6 +4,61 @@ from enum import Enum
from PIL.Image import Image
from pydantic import BaseModel

DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
DIFFUSERS_SDXL_INPAINT_CLASS_NAME = "StableDiffusionXLInpaintPipeline"


class ModelType(str, Enum):
    INPAINT = "inpaint"  # LaMa, MAT...
    DIFFUSERS_SD = "diffusers_sd"
    DIFFUSERS_SD_INPAINT = "diffusers_sd_inpaint"
    DIFFUSERS_SDXL = "diffusers_sdxl"
    DIFFUSERS_SDXL_INPAINT = "diffusers_sdxl_inpaint"
    DIFFUSERS_OTHER = "diffusers_other"


FREEU_DEFAULT_CONFIGS = {
    ModelType.DIFFUSERS_SD: dict(s1=0.9, s2=0.2, b1=1.2, b2=1.4),
    ModelType.DIFFUSERS_SDXL: dict(s1=0.6, s2=0.4, b1=1.1, b2=1.2),
}


class ModelInfo(BaseModel):
    name: str
    path: str
    model_type: ModelType
    is_single_file_diffusers: bool = False

    def support_lcm_lora(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ]

    def support_controlnet(self) -> bool:
        return self.model_type in [
            ModelType.DIFFUSERS_SD,
            ModelType.DIFFUSERS_SDXL,
            ModelType.DIFFUSERS_SD_INPAINT,
            ModelType.DIFFUSERS_SDXL_INPAINT,
        ]

    def support_freeu(self) -> bool:
        return (
            self.model_type
            in [
                ModelType.DIFFUSERS_SD,
                ModelType.DIFFUSERS_SDXL,
                ModelType.DIFFUSERS_SD_INPAINT,
                ModelType.DIFFUSERS_SDXL_INPAINT,
            ]
            or "instruct-pix2pix" in self.name
        )


class HDStrategy(str, Enum):
    # Use original image size
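ModelInfo centralizes the capability checks that previously lived in const.py lists (MODELS_SUPPORT_FREEU, MODELS_SUPPORT_LCM_LORA). A quick sketch:

    info = ModelInfo(name="lama", path="lama", model_type=ModelType.INPAINT)
    info.support_controlnet()  # False: only diffusers SD/SDXL types qualify
    info.support_freeu()       # False: True only for SD/SDXL types or instruct-pix2pix names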
@@ -2,8 +2,6 @@
import os
import hashlib

from lama_cleaner.diffusers_utils import scan_models

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import imghdr
@@ -22,9 +20,9 @@ from loguru import logger

from lama_cleaner.const import (
    SD15_MODELS,
    FREEU_DEFAULT_CONFIGS,
    MODELS_SUPPORT_FREEU,
    MODELS_SUPPORT_LCM_LORA,
    SD_CONTROLNET_CHOICES,
    SDXL_CONTROLNET_CHOICES,
    SD2_CONTROLNET_CHOICES,
)
from lama_cleaner.file_manager import FileManager
from lama_cleaner.model.utils import torch_gc
@@ -118,8 +116,8 @@ input_image_path: str = None
is_disable_model_switch: bool = False
is_controlnet: bool = False
controlnet_method: str = "control_v11p_sd15_canny"
is_enable_file_manager: bool = False
is_enable_auto_saving: bool = False
enable_file_manager: bool = False
enable_auto_saving: bool = False
is_desktop: bool = False
image_quality: int = 95
plugins = {}
@@ -421,34 +419,35 @@ def run_plugin():

@app.route("/server_config", methods=["GET"])
def get_server_config():
    controlnet = {
        "SD": SD_CONTROLNET_CHOICES,
        "SD2": SD2_CONTROLNET_CHOICES,
        "SDXL": SDXL_CONTROLNET_CHOICES,
    }
    return {
        "isControlNet": is_controlnet,
        "controlNetMethod": controlnet_method,
        "isDisableModelSwitchState": is_disable_model_switch,
        "isEnableAutoSaving": is_enable_auto_saving,
        "enableFileManager": is_enable_file_manager,
        "plugins": list(plugins.keys()),
        "freeSupportedModels": MODELS_SUPPORT_FREEU,
        "freeuDefaultConfigs": FREEU_DEFAULT_CONFIGS,
        "lcmLoraSupportedModels": MODELS_SUPPORT_LCM_LORA,
        "availableControlNet": controlnet,
        "enableFileManager": enable_file_manager,
        "enableAutoSaving": enable_auto_saving,
    }, 200


@app.route("/sd_models", methods=["GET"])
def get_diffusers_models():
    from diffusers.utils import DIFFUSERS_CACHE

    return scan_models(DIFFUSERS_CACHE)
@app.route("/models", methods=["GET"])
def get_models():
    return [
        {
            **it.dict(),
            "support_lcm_lora": it.support_lcm_lora(),
            "support_controlnet": it.support_controlnet(),
            "support_freeu": it.support_freeu(),
        }
        for it in model.scan_models()
    ]


@app.route("/model")
def current_model():
    return model.name, 200


@app.route("/model_downloaded/<name>")
def model_downloaded(name):
    return str(model.is_downloaded(name)), 200
    return model.available_models[model.name].dict(), 200


@app.route("/is_desktop")
@@ -467,8 +466,10 @@ def switch_model():

    try:
        model.switch(new_name)
    except NotImplementedError:
        return f"{new_name} not implemented", 403
    except Exception as e:
        error_message = str(e)
        logger.error(error_message)
        return f"Switch model failed: {error_message}", 500
    return f"ok, switch to {new_name}", 200


@@ -478,7 +479,7 @@ def index():


@app.route("/inputimage")
def set_input_photo():
def get_cli_input_image():
    if input_image_path:
        with open(input_image_path, "rb") as f:
            image_in_bytes = f.read()
@@ -547,11 +548,10 @@ def main(args):
    global device
    global input_image_path
    global is_disable_model_switch
    global is_enable_file_manager
    global enable_file_manager
    global is_desktop
    global thumb
    global output_dir
    global is_enable_auto_saving
    global is_controlnet
    global controlnet_method
    global image_quality
@@ -566,7 +566,9 @@ def main(args):

    output_dir = args.output_dir
    if output_dir:
        is_enable_auto_saving = True
        output_dir = os.path.abspath(output_dir)
        logger.info(f"Output dir: {output_dir}")
        enable_auto_saving = True

    device = torch.device(args.device)
    is_disable_model_switch = args.disable_model_switch
@@ -579,12 +581,12 @@ def main(args):
    if args.input and os.path.isdir(args.input):
        logger.info(f"Initialize file manager")
        thumb = FileManager(app)
        is_enable_file_manager = True
        enable_file_manager = True
        app.config["THUMBNAIL_MEDIA_ROOT"] = args.input
        app.config["THUMBNAIL_MEDIA_THUMBNAIL_ROOT"] = os.path.join(
            args.output_dir, "lama_cleaner_thumbnails"
            output_dir, "lama_cleaner_thumbnails"
        )
        thumb.output_dir = Path(args.output_dir)
        thumb.output_dir = Path(output_dir)
        # thumb.start()
        # try:
        #     while True:
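The /models route serializes each ModelInfo together with its capability flags, which the frontend can use to gate the FreeU, LCM-LoRA, and ControlNet controls. One element of the JSON response might look like this (values illustrative):

    {
        "name": "runwayml/stable-diffusion-inpainting",
        "path": "runwayml/stable-diffusion-inpainting",
        "model_type": "diffusers_sd_inpaint",
        "is_single_file_diffusers": false,
        "support_lcm_lora": true,
        "support_controlnet": true,
        "support_freeu": true
    }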