lots update
This commit is contained in:
@@ -4,6 +4,11 @@ from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
|
||||
DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
|
||||
DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
|
||||
DIFFUSERS_SDXL_INPAINT_CLASS_NAME = "StableDiffusionXLInpaintPipeline"
|
||||
|
||||
MPS_UNSUPPORT_MODELS = [
|
||||
"lama",
|
||||
"ldm",
|
||||
@@ -15,22 +20,8 @@ MPS_UNSUPPORT_MODELS = [
|
||||
]
|
||||
|
||||
DEFAULT_MODEL = "lama"
|
||||
AVAILABLE_MODELS = [
|
||||
"lama",
|
||||
"ldm",
|
||||
"zits",
|
||||
"mat",
|
||||
"fcf",
|
||||
"manga",
|
||||
"cv2",
|
||||
]
|
||||
DIFFUSERS_MODEL_FP16_REVERSION = [
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
"Sanster/anything-4.0-inpainting",
|
||||
"Sanster/Realistic_Vision_V1.4-inpainting",
|
||||
"stabilityai/stable-diffusion-2-inpainting",
|
||||
"timbrooks/instruct-pix2pix",
|
||||
]
|
||||
AVAILABLE_MODELS = ["lama", "ldm", "zits", "mat", "fcf", "manga", "cv2", "migan"]
|
||||
|
||||
|
||||
AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
|
||||
DEFAULT_DEVICE = "cuda"
|
||||
|
||||
@@ -5,23 +5,23 @@ from typing import List
|
||||
from loguru import logger
|
||||
from pathlib import Path
|
||||
|
||||
from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION, DEFAULT_MODEL_DIR
|
||||
from lama_cleaner.runtime import setup_model_dir
|
||||
from lama_cleaner.schema import (
|
||||
ModelInfo,
|
||||
ModelType,
|
||||
DIFFUSERS_SD_INPAINT_CLASS_NAME,
|
||||
DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
|
||||
from lama_cleaner.const import (
|
||||
DEFAULT_MODEL_DIR,
|
||||
DIFFUSERS_SD_CLASS_NAME,
|
||||
DIFFUSERS_SD_INPAINT_CLASS_NAME,
|
||||
DIFFUSERS_SDXL_CLASS_NAME,
|
||||
DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
|
||||
)
|
||||
from lama_cleaner.model.utils import handle_from_pretrained_exceptions
|
||||
from lama_cleaner.model_info import ModelInfo, ModelType
|
||||
from lama_cleaner.runtime import setup_model_dir
|
||||
|
||||
|
||||
def cli_download_model(model: str, model_dir: Path):
|
||||
setup_model_dir(model_dir)
|
||||
from lama_cleaner.model import models
|
||||
|
||||
if model in models:
|
||||
if model in models and models[model].is_erase_model:
|
||||
logger.info(f"Downloading {model}...")
|
||||
models[model].download()
|
||||
logger.info(f"Done.")
|
||||
@@ -29,9 +29,10 @@ def cli_download_model(model: str, model_dir: Path):
|
||||
logger.info(f"Downloading model from Huggingface: {model}")
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
downloaded_path = DiffusionPipeline.download(
|
||||
downloaded_path = handle_from_pretrained_exceptions(
|
||||
DiffusionPipeline.download,
|
||||
pretrained_model_name=model,
|
||||
variant="fp16" if model in DIFFUSERS_MODEL_FP16_REVERSION else "main",
|
||||
variant="fp16",
|
||||
resume_download=True,
|
||||
)
|
||||
logger.info(f"Done. Downloaded to {downloaded_path}")
|
||||
@@ -43,21 +44,33 @@ def folder_name_to_show_name(name: str) -> str:
|
||||
|
||||
def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
|
||||
cache_dir = Path(cache_dir)
|
||||
stable_diffusion_dir = cache_dir / "stable_diffusion"
|
||||
stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
|
||||
# logger.info(f"Scanning single file sd/sdxl models in {cache_dir}")
|
||||
res = []
|
||||
for it in cache_dir.glob(f"*.*"):
|
||||
for it in stable_diffusion_dir.glob(f"*.*"):
|
||||
if it.suffix not in [".safetensors", ".ckpt"]:
|
||||
continue
|
||||
if "inpaint" in str(it).lower():
|
||||
if "sdxl" in str(it).lower():
|
||||
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
|
||||
else:
|
||||
model_type = ModelType.DIFFUSERS_SD_INPAINT
|
||||
model_type = ModelType.DIFFUSERS_SD_INPAINT
|
||||
else:
|
||||
if "sdxl" in str(it).lower():
|
||||
model_type = ModelType.DIFFUSERS_SDXL
|
||||
else:
|
||||
model_type = ModelType.DIFFUSERS_SD
|
||||
model_type = ModelType.DIFFUSERS_SD
|
||||
res.append(
|
||||
ModelInfo(
|
||||
name=it.name,
|
||||
path=str(it.absolute()),
|
||||
model_type=model_type,
|
||||
is_single_file_diffusers=True,
|
||||
)
|
||||
)
|
||||
|
||||
for it in stable_diffusion_xl_dir.glob(f"*.*"):
|
||||
if it.suffix not in [".safetensors", ".ckpt"]:
|
||||
continue
|
||||
if "inpaint" in str(it).lower():
|
||||
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
|
||||
else:
|
||||
model_type = ModelType.DIFFUSERS_SDXL
|
||||
res.append(
|
||||
ModelInfo(
|
||||
name=it.name,
|
||||
@@ -104,8 +117,9 @@ def scan_models() -> List[ModelInfo]:
|
||||
name = folder_name_to_show_name(it.parent.parent.parent.name)
|
||||
if name in diffusers_model_names:
|
||||
continue
|
||||
|
||||
if _class_name == DIFFUSERS_SD_CLASS_NAME:
|
||||
if "PowerPaint" in name:
|
||||
model_type = ModelType.DIFFUSERS_OTHER
|
||||
elif _class_name == DIFFUSERS_SD_CLASS_NAME:
|
||||
model_type = ModelType.DIFFUSERS_SD
|
||||
elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
|
||||
model_type = ModelType.DIFFUSERS_SD_INPAINT
|
||||
|
||||
@@ -290,3 +290,7 @@ def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]:
|
||||
return cv2.drawContours(new_mask, contours, max_index, 255, -1)
|
||||
else:
|
||||
return mask
|
||||
|
||||
|
||||
def is_mac():
|
||||
return sys.platform == "darwin"
|
||||
|
||||
@@ -9,6 +9,7 @@ from .mat import MAT
|
||||
from .mi_gan import MIGAN
|
||||
from .opencv2 import OpenCV2
|
||||
from .paint_by_example import PaintByExample
|
||||
from .power_paint.power_paint import PowerPaint
|
||||
from .sd import SD15, SD2, Anything4, RealisticVision14, SD
|
||||
from .sdxl import SDXL
|
||||
from .zits import ZITS
|
||||
@@ -30,4 +31,5 @@ models = {
|
||||
InstructPix2Pix.name: InstructPix2Pix,
|
||||
Kandinsky22.name: Kandinsky22,
|
||||
SDXL.name: SDXL,
|
||||
PowerPaint.name: PowerPaint,
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ from lama_cleaner.helper import (
|
||||
)
|
||||
from lama_cleaner.model.helper.g_diffuser_bot import expand_image
|
||||
from lama_cleaner.model.utils import get_scheduler
|
||||
from lama_cleaner.schema import Config, HDStrategy, SDSampler, ModelInfo
|
||||
from lama_cleaner.schema import Config, HDStrategy, SDSampler
|
||||
|
||||
|
||||
class InpaintModel:
|
||||
@@ -271,7 +271,7 @@ class InpaintModel:
|
||||
|
||||
class DiffusionInpaintModel(InpaintModel):
|
||||
def __init__(self, device, **kwargs):
|
||||
self.model_info: ModelInfo = kwargs["model_info"]
|
||||
self.model_info = kwargs["model_info"]
|
||||
self.model_id_or_path = self.model_info.path
|
||||
super().__init__(device, **kwargs)
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import torch
|
||||
from diffusers import ControlNetModel, DiffusionPipeline
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
|
||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||
from lama_cleaner.model.helper.controlnet_preprocess import (
|
||||
make_canny_control_image,
|
||||
@@ -14,8 +13,8 @@ from lama_cleaner.model.helper.controlnet_preprocess import (
|
||||
make_inpaint_control_image,
|
||||
)
|
||||
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||
from lama_cleaner.model.utils import get_scheduler
|
||||
from lama_cleaner.schema import Config, ModelInfo, ModelType
|
||||
from lama_cleaner.model.utils import get_scheduler, handle_from_pretrained_exceptions
|
||||
from lama_cleaner.schema import Config, ModelType
|
||||
|
||||
|
||||
class ControlNet(DiffusionInpaintModel):
|
||||
@@ -39,11 +38,11 @@ class ControlNet(DiffusionInpaintModel):
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
fp16 = not kwargs.get("no_half", False)
|
||||
model_info: ModelInfo = kwargs["model_info"]
|
||||
sd_controlnet_method = kwargs["sd_controlnet_method"]
|
||||
model_info = kwargs["model_info"]
|
||||
controlnet_method = kwargs["controlnet_method"]
|
||||
|
||||
self.model_info = model_info
|
||||
self.sd_controlnet_method = sd_controlnet_method
|
||||
self.controlnet_method = controlnet_method
|
||||
|
||||
model_kwargs = {}
|
||||
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
||||
@@ -76,7 +75,8 @@ class ControlNet(DiffusionInpaintModel):
|
||||
)
|
||||
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
sd_controlnet_method, torch_dtype=torch_dtype, resume_download=True
|
||||
pretrained_model_name_or_path=controlnet_method,
|
||||
resume_download=True,
|
||||
)
|
||||
if model_info.is_single_file_diffusers:
|
||||
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
||||
@@ -88,17 +88,12 @@ class ControlNet(DiffusionInpaintModel):
|
||||
model_info.path, controlnet=controlnet, **model_kwargs
|
||||
).to(torch_dtype)
|
||||
else:
|
||||
self.model = PipeClass.from_pretrained(
|
||||
model_info.path,
|
||||
self.model = handle_from_pretrained_exceptions(
|
||||
PipeClass.from_pretrained,
|
||||
pretrained_model_name_or_path=model_info.path,
|
||||
controlnet=controlnet,
|
||||
revision="fp16"
|
||||
if (
|
||||
model_info.path in DIFFUSERS_MODEL_FP16_REVERSION
|
||||
and use_gpu
|
||||
and fp16
|
||||
)
|
||||
else "main",
|
||||
torch_dtype=torch_dtype,
|
||||
variant="fp16",
|
||||
dtype=torch_dtype,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@@ -116,23 +111,23 @@ class ControlNet(DiffusionInpaintModel):
|
||||
self.callback = kwargs.pop("callback", None)
|
||||
|
||||
def switch_controlnet_method(self, new_method: str):
|
||||
self.sd_controlnet_method = new_method
|
||||
self.controlnet_method = new_method
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
new_method, torch_dtype=self.torch_dtype, resume_download=True
|
||||
).to(self.model.device)
|
||||
self.model.controlnet = controlnet
|
||||
|
||||
def _get_control_image(self, image, mask):
|
||||
if "canny" in self.sd_controlnet_method:
|
||||
if "canny" in self.controlnet_method:
|
||||
control_image = make_canny_control_image(image)
|
||||
elif "openpose" in self.sd_controlnet_method:
|
||||
elif "openpose" in self.controlnet_method:
|
||||
control_image = make_openpose_control_image(image)
|
||||
elif "depth" in self.sd_controlnet_method:
|
||||
elif "depth" in self.controlnet_method:
|
||||
control_image = make_depth_control_image(image)
|
||||
elif "inpaint" in self.sd_controlnet_method:
|
||||
elif "inpaint" in self.controlnet_method:
|
||||
control_image = make_inpaint_control_image(image, mask)
|
||||
else:
|
||||
raise NotImplementedError(f"{self.sd_controlnet_method} not implemented")
|
||||
raise NotImplementedError(f"{self.controlnet_method} not implemented")
|
||||
return control_image
|
||||
|
||||
def forward(self, image, mask, config: Config):
|
||||
|
||||
@@ -24,7 +24,7 @@ class Kandinsky(DiffusionInpaintModel):
|
||||
}
|
||||
|
||||
self.model = AutoPipelineForInpainting.from_pretrained(
|
||||
self.model_id_or_path, **model_kwargs
|
||||
self.name, **model_kwargs
|
||||
).to(device)
|
||||
|
||||
self.callback = kwargs.pop("callback", None)
|
||||
@@ -66,4 +66,3 @@ class Kandinsky(DiffusionInpaintModel):
|
||||
|
||||
class Kandinsky22(Kandinsky):
|
||||
name = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
|
||||
model_id_or_path = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
|
||||
|
||||
@@ -16,7 +16,7 @@ from lama_cleaner.model.base import InpaintModel
|
||||
|
||||
MIGAN_MODEL_URL = os.environ.get(
|
||||
"MIGAN_MODEL_URL",
|
||||
"/Users/cwq/code/github/MI-GAN/exported_models/migan_places512/models/migan_traced.pt",
|
||||
"https://github.com/Sanster/models/releases/download/migan/migan_traced.pt",
|
||||
)
|
||||
MIGAN_MODEL_MD5 = os.environ.get("MIGAN_MODEL_MD5", "76eb3b1a71c400ee3290524f7a11b89c")
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ class PaintByExample(DiffusionInpaintModel):
|
||||
)
|
||||
|
||||
self.model = DiffusionPipeline.from_pretrained(
|
||||
"Fantasy-Studio/Paint-by-Example", torch_dtype=torch_dtype, **model_kwargs
|
||||
self.name, torch_dtype=torch_dtype, **model_kwargs
|
||||
)
|
||||
|
||||
# TODO: gpu_id
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from .pipeline_stable_diffusion_controlnet_inpaint import (
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
)
|
||||
@@ -1,638 +0,0 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import gc
|
||||
from typing import Union, List, Optional, Callable, Dict, Any
|
||||
|
||||
# Copy from https://github.com/mikonvergence/ControlNetInpaint/blob/main/src/pipeline_stable_diffusion_controlnet_inpaint.py
|
||||
|
||||
import torch
|
||||
import PIL.Image
|
||||
|
||||
from diffusers.pipelines.controlnet.pipeline_controlnet import *
|
||||
from diffusers.utils import replace_example_docstring
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> # !pip install opencv-python transformers accelerate
|
||||
>>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> import numpy as np
|
||||
>>> import torch
|
||||
|
||||
>>> import cv2
|
||||
>>> from PIL import Image
|
||||
>>> # download an image
|
||||
>>> image = load_image(
|
||||
... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
|
||||
... )
|
||||
>>> image = np.array(image)
|
||||
>>> mask_image = load_image(
|
||||
... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
... )
|
||||
>>> mask_image = np.array(mask_image)
|
||||
>>> # get canny image
|
||||
>>> canny_image = cv2.Canny(image, 100, 200)
|
||||
>>> canny_image = canny_image[:, :, None]
|
||||
>>> canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
|
||||
>>> canny_image = Image.fromarray(canny_image)
|
||||
|
||||
>>> # load control net and stable diffusion v1-5
|
||||
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
|
||||
>>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
||||
... "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
|
||||
... )
|
||||
|
||||
>>> # speed up diffusion process with faster scheduler and memory optimization
|
||||
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
>>> # remove following line if xformers is not installed
|
||||
>>> pipe.enable_xformers_memory_efficient_attention()
|
||||
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> # generate image
|
||||
>>> generator = torch.manual_seed(0)
|
||||
>>> image = pipe(
|
||||
... "futuristic-looking doggo",
|
||||
... num_inference_steps=20,
|
||||
... generator=generator,
|
||||
... image=image,
|
||||
... control_image=canny_image,
|
||||
... mask_image=mask_image
|
||||
... ).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def prepare_mask_and_masked_image(image, mask):
|
||||
"""
|
||||
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
|
||||
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
|
||||
``image`` and ``1`` for the ``mask``.
|
||||
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
|
||||
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
|
||||
Args:
|
||||
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
|
||||
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
|
||||
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
|
||||
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
|
||||
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
|
||||
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
|
||||
Raises:
|
||||
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
|
||||
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
|
||||
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
|
||||
(ot the other way around).
|
||||
Returns:
|
||||
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
|
||||
dimensions: ``batch x channels x height x width``.
|
||||
"""
|
||||
if isinstance(image, torch.Tensor):
|
||||
if not isinstance(mask, torch.Tensor):
|
||||
raise TypeError(
|
||||
f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not"
|
||||
)
|
||||
|
||||
# Batch single image
|
||||
if image.ndim == 3:
|
||||
assert (
|
||||
image.shape[0] == 3
|
||||
), "Image outside a batch should be of shape (3, H, W)"
|
||||
image = image.unsqueeze(0)
|
||||
|
||||
# Batch and add channel dim for single mask
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# Batch single mask or add channel dim
|
||||
if mask.ndim == 3:
|
||||
# Single batched mask, no channel dim or single mask not batched but channel dim
|
||||
if mask.shape[0] == 1:
|
||||
mask = mask.unsqueeze(0)
|
||||
|
||||
# Batched masks no channel dim
|
||||
else:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
assert (
|
||||
image.ndim == 4 and mask.ndim == 4
|
||||
), "Image and Mask must have 4 dimensions"
|
||||
assert (
|
||||
image.shape[-2:] == mask.shape[-2:]
|
||||
), "Image and Mask must have the same spatial dimensions"
|
||||
assert (
|
||||
image.shape[0] == mask.shape[0]
|
||||
), "Image and Mask must have the same batch size"
|
||||
|
||||
# Check image is in [-1, 1]
|
||||
if image.min() < -1 or image.max() > 1:
|
||||
raise ValueError("Image should be in [-1, 1] range")
|
||||
|
||||
# Check mask is in [0, 1]
|
||||
if mask.min() < 0 or mask.max() > 1:
|
||||
raise ValueError("Mask should be in [0, 1] range")
|
||||
|
||||
# Binarize mask
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
|
||||
# Image as float32
|
||||
image = image.to(dtype=torch.float32)
|
||||
elif isinstance(mask, torch.Tensor):
|
||||
raise TypeError(
|
||||
f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not"
|
||||
)
|
||||
else:
|
||||
# preprocess image
|
||||
if isinstance(image, (PIL.Image.Image, np.ndarray)):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
|
||||
image = [np.array(i.convert("RGB"))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate([i[None, :] for i in image], axis=0)
|
||||
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
# preprocess mask
|
||||
if isinstance(mask, (PIL.Image.Image, np.ndarray)):
|
||||
mask = [mask]
|
||||
|
||||
if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
|
||||
mask = np.concatenate(
|
||||
[np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
|
||||
)
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
|
||||
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
|
||||
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline):
|
||||
r"""
|
||||
Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance.
|
||||
|
||||
This model inherits from [`StableDiffusionControlNetPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
controlnet ([`ControlNetModel`]):
|
||||
Provides additional conditioning to the unet during the denoising process
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
download_from_original_stable_diffusion_ckpt,
|
||||
)
|
||||
|
||||
controlnet = kwargs.pop("controlnet", None)
|
||||
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
pretrained_model_link_or_path,
|
||||
num_in_channels=9,
|
||||
from_safetensors=pretrained_model_link_or_path.endswith("safetensors"),
|
||||
device="cpu",
|
||||
load_safety_checker=False,
|
||||
)
|
||||
|
||||
inpaint_pipe = cls(
|
||||
vae=pipe.vae,
|
||||
text_encoder=pipe.text_encoder,
|
||||
tokenizer=pipe.tokenizer,
|
||||
unet=pipe.unet,
|
||||
controlnet=controlnet,
|
||||
scheduler=pipe.scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
|
||||
del pipe
|
||||
gc.collect()
|
||||
return inpaint_pipe
|
||||
|
||||
def prepare_mask_latents(
|
||||
self,
|
||||
mask,
|
||||
masked_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
do_classifier_free_guidance,
|
||||
):
|
||||
# resize the mask to latents shape as we concatenate the mask to the latents
|
||||
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
||||
# and half precision
|
||||
mask = torch.nn.functional.interpolate(
|
||||
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
)
|
||||
mask = mask.to(device=device, dtype=dtype)
|
||||
|
||||
masked_image = masked_image.to(device=device, dtype=dtype)
|
||||
|
||||
# encode the mask image into latents space so we can concatenate it to the latents
|
||||
if isinstance(generator, list):
|
||||
masked_image_latents = [
|
||||
self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(
|
||||
generator=generator[i]
|
||||
)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
masked_image_latents = torch.cat(masked_image_latents, dim=0)
|
||||
else:
|
||||
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(
|
||||
generator=generator
|
||||
)
|
||||
masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
if mask.shape[0] < batch_size:
|
||||
if not batch_size % mask.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
|
||||
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
|
||||
" of masks that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
|
||||
if masked_image_latents.shape[0] < batch_size:
|
||||
if not batch_size % masked_image_latents.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
||||
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
|
||||
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
masked_image_latents = masked_image_latents.repeat(
|
||||
batch_size // masked_image_latents.shape[0], 1, 1, 1
|
||||
)
|
||||
|
||||
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
|
||||
masked_image_latents = (
|
||||
torch.cat([masked_image_latents] * 2)
|
||||
if do_classifier_free_guidance
|
||||
else masked_image_latents
|
||||
)
|
||||
|
||||
# aligning device to prevent device errors when concating it with the latent model input
|
||||
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
|
||||
return mask, masked_image_latents
|
||||
|
||||
def _default_height_width(self, height, width, image):
|
||||
if isinstance(image, list):
|
||||
image = image[0]
|
||||
|
||||
if height is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
height = image.height
|
||||
elif isinstance(image, torch.Tensor):
|
||||
height = image.shape[3]
|
||||
|
||||
height = (height // 8) * 8 # round down to nearest multiple of 8
|
||||
|
||||
if width is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
width = image.width
|
||||
elif isinstance(image, torch.Tensor):
|
||||
width = image.shape[2]
|
||||
|
||||
width = (width // 8) * 8 # round down to nearest multiple of 8
|
||||
|
||||
return height, width
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
||||
control_image: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
] = None,
|
||||
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_conditioning_scale: float = 1.0,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
image (`PIL.Image.Image`):
|
||||
`Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
|
||||
be masked out with `mask_image` and repainted according to `prompt`.
|
||||
control_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
|
||||
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
|
||||
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can
|
||||
also be accepted as an image. The control image is automatically resized to fit the output image.
|
||||
mask_image (`PIL.Image.Image`):
|
||||
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
||||
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
|
||||
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
|
||||
instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
||||
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
||||
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
|
||||
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
|
||||
to the residual in the original unet.
|
||||
Examples:
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height, width = self._default_height_width(height, width, control_image)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
image=control_image,
|
||||
callback_steps=callback_steps,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
|
||||
# 4. Prepare image
|
||||
control_image = self.prepare_image(
|
||||
control_image,
|
||||
width,
|
||||
height,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
self.controlnet.dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
control_image = torch.cat([control_image] * 2)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.controlnet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# EXTRA: prepare mask latents
|
||||
mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask,
|
||||
masked_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = (
|
||||
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
)
|
||||
latent_model_input = self.scheduler.scale_model_input(
|
||||
latent_model_input, t
|
||||
)
|
||||
|
||||
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
controlnet_cond=control_image,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
down_block_res_samples = [
|
||||
down_block_res_sample * controlnet_conditioning_scale
|
||||
for down_block_res_sample in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample *= controlnet_conditioning_scale
|
||||
|
||||
# predict the noise residual
|
||||
latent_model_input = torch.cat(
|
||||
[latent_model_input, mask, masked_image_latents], dim=1
|
||||
)
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, **extra_step_kwargs
|
||||
).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or (
|
||||
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
||||
):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# If we do sequential model offloading, let's offload unet and controlnet
|
||||
# manually for max memory savings
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.unet.to("cpu")
|
||||
self.controlnet.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
has_nsfw_concept = None
|
||||
elif output_type == "pil":
|
||||
# 8. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
# 9. Run safety checker
|
||||
image, has_nsfw_concept = self.run_safety_checker(
|
||||
image, device, prompt_embeds.dtype
|
||||
)
|
||||
|
||||
# 10. Convert to PIL
|
||||
image = self.numpy_to_pil(image)
|
||||
else:
|
||||
# 8. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
# 9. Run safety checker
|
||||
image, has_nsfw_concept = self.run_safety_checker(
|
||||
image, device, prompt_embeds.dtype
|
||||
)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(
|
||||
images=image, nsfw_content_detected=has_nsfw_concept
|
||||
)
|
||||
0
lama_cleaner/model/power_paint/__init__.py
Normal file
0
lama_cleaner/model/power_paint/__init__.py
Normal file
1243
lama_cleaner/model/power_paint/pipeline_powerpaint.py
Normal file
1243
lama_cleaner/model/power_paint/pipeline_powerpaint.py
Normal file
File diff suppressed because it is too large
Load Diff
1775
lama_cleaner/model/power_paint/pipeline_powerpaint_controlnet.py
Normal file
1775
lama_cleaner/model/power_paint/pipeline_powerpaint_controlnet.py
Normal file
File diff suppressed because it is too large
Load Diff
96
lama_cleaner/model/power_paint/power_paint.py
Normal file
96
lama_cleaner/model/power_paint/power_paint.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from PIL import Image
|
||||
import PIL.Image
|
||||
import cv2
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||
from lama_cleaner.model.utils import handle_from_pretrained_exceptions
|
||||
from lama_cleaner.schema import Config
|
||||
from .powerpaint_tokenizer import add_task_to_prompt
|
||||
|
||||
|
||||
class PowerPaint(DiffusionInpaintModel):
|
||||
name = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
|
||||
pad_mod = 8
|
||||
min_size = 512
|
||||
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
from .pipeline_powerpaint import StableDiffusionInpaintPipeline
|
||||
from .powerpaint_tokenizer import PowerPaintTokenizer
|
||||
|
||||
fp16 = not kwargs.get("no_half", False)
|
||||
model_kwargs = {}
|
||||
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||
model_kwargs.update(
|
||||
dict(
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
)
|
||||
|
||||
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
|
||||
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
||||
|
||||
self.model = handle_from_pretrained_exceptions(
|
||||
StableDiffusionInpaintPipeline.from_pretrained,
|
||||
pretrained_model_name_or_path=self.name,
|
||||
variant="fp16",
|
||||
torch_dtype=torch_dtype,
|
||||
**model_kwargs,
|
||||
)
|
||||
self.model.tokenizer = PowerPaintTokenizer(self.model.tokenizer)
|
||||
|
||||
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||
logger.info("Enable sequential cpu offload")
|
||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||
else:
|
||||
self.model = self.model.to(device)
|
||||
if kwargs["sd_cpu_textencoder"]:
|
||||
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||
self.model.text_encoder = CPUTextEncoderWrapper(
|
||||
self.model.text_encoder, torch_dtype
|
||||
)
|
||||
|
||||
self.callback = kwargs.pop("callback", None)
|
||||
|
||||
def forward(self, image, mask, config: Config):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1] 255 means area to repaint
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
self.set_scheduler(config)
|
||||
|
||||
img_h, img_w = image.shape[:2]
|
||||
promptA, promptB, negative_promptA, negative_promptB = add_task_to_prompt(
|
||||
config.prompt, config.negative_prompt, config.powerpaint_task
|
||||
)
|
||||
|
||||
output = self.model(
|
||||
image=PIL.Image.fromarray(image),
|
||||
promptA=promptA,
|
||||
promptB=promptB,
|
||||
tradoff=config.fitting_degree,
|
||||
tradoff_nag=config.fitting_degree,
|
||||
negative_promptA=negative_promptA,
|
||||
negative_promptB=negative_promptB,
|
||||
mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
|
||||
num_inference_steps=config.sd_steps,
|
||||
strength=config.sd_strength,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
output_type="np",
|
||||
callback=self.callback,
|
||||
height=img_h,
|
||||
width=img_w,
|
||||
generator=torch.manual_seed(config.sd_seed),
|
||||
callback_steps=1,
|
||||
).images[0]
|
||||
|
||||
output = (output * 255).round().astype("uint8")
|
||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
return output
|
||||
540
lama_cleaner/model/power_paint/powerpaint_tokenizer.py
Normal file
540
lama_cleaner/model/power_paint/powerpaint_tokenizer.py
Normal file
@@ -0,0 +1,540 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import copy
|
||||
import random
|
||||
from typing import Any, List, Optional, Union
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from lama_cleaner.schema import PowerPaintTask
|
||||
|
||||
|
||||
def add_task_to_prompt(prompt, negative_prompt, task: PowerPaintTask):
|
||||
if task == PowerPaintTask.object_remove:
|
||||
promptA = prompt + " P_ctxt"
|
||||
promptB = prompt + " P_ctxt"
|
||||
negative_promptA = negative_prompt + " P_obj"
|
||||
negative_promptB = negative_prompt + " P_obj"
|
||||
elif task == PowerPaintTask.shape_guided:
|
||||
promptA = prompt + " P_shape"
|
||||
promptB = prompt + " P_ctxt"
|
||||
negative_promptA = negative_prompt
|
||||
negative_promptB = negative_prompt
|
||||
elif task == PowerPaintTask.outpainting:
|
||||
promptA = prompt + " P_ctxt"
|
||||
promptB = prompt + " P_ctxt"
|
||||
negative_promptA = negative_prompt + " P_obj"
|
||||
negative_promptB = negative_prompt + " P_obj"
|
||||
else:
|
||||
promptA = prompt + " P_obj"
|
||||
promptB = prompt + " P_obj"
|
||||
negative_promptA = negative_prompt
|
||||
negative_promptB = negative_prompt
|
||||
|
||||
return promptA, promptB, negative_promptA, negative_promptB
|
||||
|
||||
|
||||
class PowerPaintTokenizer:
|
||||
def __init__(self, tokenizer: CLIPTokenizer):
|
||||
self.wrapped = tokenizer
|
||||
self.token_map = {}
|
||||
placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"]
|
||||
num_vec_per_token = 10
|
||||
for placeholder_token in placeholder_tokens:
|
||||
output = []
|
||||
for i in range(num_vec_per_token):
|
||||
ith_token = placeholder_token + f"_{i}"
|
||||
output.append(ith_token)
|
||||
self.token_map[placeholder_token] = output
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name == "wrapped":
|
||||
return super().__getattr__("wrapped")
|
||||
|
||||
try:
|
||||
return getattr(self.wrapped, name)
|
||||
except AttributeError:
|
||||
try:
|
||||
return super().__getattr__(name)
|
||||
except AttributeError:
|
||||
raise AttributeError(
|
||||
"'name' cannot be found in both "
|
||||
f"'{self.__class__.__name__}' and "
|
||||
f"'{self.__class__.__name__}.tokenizer'."
|
||||
)
|
||||
|
||||
def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
|
||||
"""Attempt to add tokens to the tokenizer.
|
||||
|
||||
Args:
|
||||
tokens (Union[str, List[str]]): The tokens to be added.
|
||||
"""
|
||||
num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
|
||||
assert num_added_tokens != 0, (
|
||||
f"The tokenizer already contains the token {tokens}. Please pass "
|
||||
"a different `placeholder_token` that is not already in the "
|
||||
"tokenizer."
|
||||
)
|
||||
|
||||
def get_token_info(self, token: str) -> dict:
|
||||
"""Get the information of a token, including its start and end index in
|
||||
the current tokenizer.
|
||||
|
||||
Args:
|
||||
token (str): The token to be queried.
|
||||
|
||||
Returns:
|
||||
dict: The information of the token, including its start and end
|
||||
index in current tokenizer.
|
||||
"""
|
||||
token_ids = self.__call__(token).input_ids
|
||||
start, end = token_ids[1], token_ids[-2] + 1
|
||||
return {"name": token, "start": start, "end": end}
|
||||
|
||||
def add_placeholder_token(
|
||||
self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs
|
||||
):
|
||||
"""Add placeholder tokens to the tokenizer.
|
||||
|
||||
Args:
|
||||
placeholder_token (str): The placeholder token to be added.
|
||||
num_vec_per_token (int, optional): The number of vectors of
|
||||
the added placeholder token.
|
||||
*args, **kwargs: The arguments for `self.wrapped.add_tokens`.
|
||||
"""
|
||||
output = []
|
||||
if num_vec_per_token == 1:
|
||||
self.try_adding_tokens(placeholder_token, *args, **kwargs)
|
||||
output.append(placeholder_token)
|
||||
else:
|
||||
output = []
|
||||
for i in range(num_vec_per_token):
|
||||
ith_token = placeholder_token + f"_{i}"
|
||||
self.try_adding_tokens(ith_token, *args, **kwargs)
|
||||
output.append(ith_token)
|
||||
|
||||
for token in self.token_map:
|
||||
if token in placeholder_token:
|
||||
raise ValueError(
|
||||
f"The tokenizer already has placeholder token {token} "
|
||||
f"that can get confused with {placeholder_token} "
|
||||
"keep placeholder tokens independent"
|
||||
)
|
||||
self.token_map[placeholder_token] = output
|
||||
|
||||
def replace_placeholder_tokens_in_text(
|
||||
self,
|
||||
text: Union[str, List[str]],
|
||||
vector_shuffle: bool = False,
|
||||
prop_tokens_to_load: float = 1.0,
|
||||
) -> Union[str, List[str]]:
|
||||
"""Replace the keywords in text with placeholder tokens. This function
|
||||
will be called in `self.__call__` and `self.encode`.
|
||||
|
||||
Args:
|
||||
text (Union[str, List[str]]): The text to be processed.
|
||||
vector_shuffle (bool, optional): Whether to shuffle the vectors.
|
||||
Defaults to False.
|
||||
prop_tokens_to_load (float, optional): The proportion of tokens to
|
||||
be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
Union[str, List[str]]: The processed text.
|
||||
"""
|
||||
if isinstance(text, list):
|
||||
output = []
|
||||
for i in range(len(text)):
|
||||
output.append(
|
||||
self.replace_placeholder_tokens_in_text(
|
||||
text[i], vector_shuffle=vector_shuffle
|
||||
)
|
||||
)
|
||||
return output
|
||||
|
||||
for placeholder_token in self.token_map:
|
||||
if placeholder_token in text:
|
||||
tokens = self.token_map[placeholder_token]
|
||||
tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
|
||||
if vector_shuffle:
|
||||
tokens = copy.copy(tokens)
|
||||
random.shuffle(tokens)
|
||||
text = text.replace(placeholder_token, " ".join(tokens))
|
||||
return text
|
||||
|
||||
def replace_text_with_placeholder_tokens(
|
||||
self, text: Union[str, List[str]]
|
||||
) -> Union[str, List[str]]:
|
||||
"""Replace the placeholder tokens in text with the original keywords.
|
||||
This function will be called in `self.decode`.
|
||||
|
||||
Args:
|
||||
text (Union[str, List[str]]): The text to be processed.
|
||||
|
||||
Returns:
|
||||
Union[str, List[str]]: The processed text.
|
||||
"""
|
||||
if isinstance(text, list):
|
||||
output = []
|
||||
for i in range(len(text)):
|
||||
output.append(self.replace_text_with_placeholder_tokens(text[i]))
|
||||
return output
|
||||
|
||||
for placeholder_token, tokens in self.token_map.items():
|
||||
merged_tokens = " ".join(tokens)
|
||||
if merged_tokens in text:
|
||||
text = text.replace(merged_tokens, placeholder_token)
|
||||
return text
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[str, List[str]],
|
||||
*args,
|
||||
vector_shuffle: bool = False,
|
||||
prop_tokens_to_load: float = 1.0,
|
||||
**kwargs,
|
||||
):
|
||||
"""The call function of the wrapper.
|
||||
|
||||
Args:
|
||||
text (Union[str, List[str]]): The text to be tokenized.
|
||||
vector_shuffle (bool, optional): Whether to shuffle the vectors.
|
||||
Defaults to False.
|
||||
prop_tokens_to_load (float, optional): The proportion of tokens to
|
||||
be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0
|
||||
*args, **kwargs: The arguments for `self.wrapped.__call__`.
|
||||
"""
|
||||
replaced_text = self.replace_placeholder_tokens_in_text(
|
||||
text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
|
||||
)
|
||||
|
||||
return self.wrapped.__call__(replaced_text, *args, **kwargs)
|
||||
|
||||
def encode(self, text: Union[str, List[str]], *args, **kwargs):
|
||||
"""Encode the passed text to token index.
|
||||
|
||||
Args:
|
||||
text (Union[str, List[str]]): The text to be encode.
|
||||
*args, **kwargs: The arguments for `self.wrapped.__call__`.
|
||||
"""
|
||||
replaced_text = self.replace_placeholder_tokens_in_text(text)
|
||||
return self.wrapped(replaced_text, *args, **kwargs)
|
||||
|
||||
def decode(
|
||||
self, token_ids, return_raw: bool = False, *args, **kwargs
|
||||
) -> Union[str, List[str]]:
|
||||
"""Decode the token index to text.
|
||||
|
||||
Args:
|
||||
token_ids: The token index to be decoded.
|
||||
return_raw: Whether keep the placeholder token in the text.
|
||||
Defaults to False.
|
||||
*args, **kwargs: The arguments for `self.wrapped.decode`.
|
||||
|
||||
Returns:
|
||||
Union[str, List[str]]: The decoded text.
|
||||
"""
|
||||
text = self.wrapped.decode(token_ids, *args, **kwargs)
|
||||
if return_raw:
|
||||
return text
|
||||
replaced_text = self.replace_text_with_placeholder_tokens(text)
|
||||
return replaced_text
|
||||
|
||||
|
||||
class EmbeddingLayerWithFixes(nn.Module):
|
||||
"""The revised embedding layer to support external embeddings. This design
|
||||
of this class is inspired by https://github.com/AUTOMATIC1111/stable-
|
||||
diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hi
|
||||
jack.py#L224 # noqa.
|
||||
|
||||
Args:
|
||||
wrapped (nn.Emebdding): The embedding layer to be wrapped.
|
||||
external_embeddings (Union[dict, List[dict]], optional): The external
|
||||
embeddings added to this layer. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
wrapped: nn.Embedding,
|
||||
external_embeddings: Optional[Union[dict, List[dict]]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.wrapped = wrapped
|
||||
self.num_embeddings = wrapped.weight.shape[0]
|
||||
|
||||
self.external_embeddings = []
|
||||
if external_embeddings:
|
||||
self.add_embeddings(external_embeddings)
|
||||
|
||||
self.trainable_embeddings = nn.ParameterDict()
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
"""Get the weight of wrapped embedding layer."""
|
||||
return self.wrapped.weight
|
||||
|
||||
def check_duplicate_names(self, embeddings: List[dict]):
|
||||
"""Check whether duplicate names exist in list of 'external
|
||||
embeddings'.
|
||||
|
||||
Args:
|
||||
embeddings (List[dict]): A list of embedding to be check.
|
||||
"""
|
||||
names = [emb["name"] for emb in embeddings]
|
||||
assert len(names) == len(set(names)), (
|
||||
"Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'"
|
||||
)
|
||||
|
||||
def check_ids_overlap(self, embeddings):
|
||||
"""Check whether overlap exist in token ids of 'external_embeddings'.
|
||||
|
||||
Args:
|
||||
embeddings (List[dict]): A list of embedding to be check.
|
||||
"""
|
||||
ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings]
|
||||
ids_range.sort() # sort by 'start'
|
||||
# check if 'end' has overlapping
|
||||
for idx in range(len(ids_range) - 1):
|
||||
name1, name2 = ids_range[idx][-1], ids_range[idx + 1][-1]
|
||||
assert ids_range[idx][1] <= ids_range[idx + 1][0], (
|
||||
f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'."
|
||||
)
|
||||
|
||||
def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
|
||||
"""Add external embeddings to this layer.
|
||||
|
||||
Use case:
|
||||
|
||||
>>> 1. Add token to tokenizer and get the token id.
|
||||
>>> tokenizer = TokenizerWrapper('openai/clip-vit-base-patch32')
|
||||
>>> # 'how much' in kiswahili
|
||||
>>> tokenizer.add_placeholder_tokens('ngapi', num_vec_per_token=4)
|
||||
>>>
|
||||
>>> 2. Add external embeddings to the model.
|
||||
>>> new_embedding = {
|
||||
>>> 'name': 'ngapi', # 'how much' in kiswahili
|
||||
>>> 'embedding': torch.ones(1, 15) * 4,
|
||||
>>> 'start': tokenizer.get_token_info('kwaheri')['start'],
|
||||
>>> 'end': tokenizer.get_token_info('kwaheri')['end'],
|
||||
>>> 'trainable': False # if True, will registry as a parameter
|
||||
>>> }
|
||||
>>> embedding_layer = nn.Embedding(10, 15)
|
||||
>>> embedding_layer_wrapper = EmbeddingLayerWithFixes(embedding_layer)
|
||||
>>> embedding_layer_wrapper.add_embeddings(new_embedding)
|
||||
>>>
|
||||
>>> 3. Forward tokenizer and embedding layer!
|
||||
>>> input_text = ['hello, ngapi!', 'hello my friend, ngapi?']
|
||||
>>> input_ids = tokenizer(
|
||||
>>> input_text, padding='max_length', truncation=True,
|
||||
>>> return_tensors='pt')['input_ids']
|
||||
>>> out_feat = embedding_layer_wrapper(input_ids)
|
||||
>>>
|
||||
>>> 4. Let's validate the result!
|
||||
>>> assert (out_feat[0, 3: 7] == 2.3).all()
|
||||
>>> assert (out_feat[2, 5: 9] == 2.3).all()
|
||||
|
||||
Args:
|
||||
embeddings (Union[dict, list[dict]]): The external embeddings to
|
||||
be added. Each dict must contain the following 4 fields: 'name'
|
||||
(the name of this embedding), 'embedding' (the embedding
|
||||
tensor), 'start' (the start token id of this embedding), 'end'
|
||||
(the end token id of this embedding). For example:
|
||||
`{name: NAME, start: START, end: END, embedding: torch.Tensor}`
|
||||
"""
|
||||
if isinstance(embeddings, dict):
|
||||
embeddings = [embeddings]
|
||||
|
||||
self.external_embeddings += embeddings
|
||||
self.check_duplicate_names(self.external_embeddings)
|
||||
self.check_ids_overlap(self.external_embeddings)
|
||||
|
||||
# set for trainable
|
||||
added_trainable_emb_info = []
|
||||
for embedding in embeddings:
|
||||
trainable = embedding.get("trainable", False)
|
||||
if trainable:
|
||||
name = embedding["name"]
|
||||
embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
|
||||
self.trainable_embeddings[name] = embedding["embedding"]
|
||||
added_trainable_emb_info.append(name)
|
||||
|
||||
added_emb_info = [emb["name"] for emb in embeddings]
|
||||
added_emb_info = ", ".join(added_emb_info)
|
||||
print(f"Successfully add external embeddings: {added_emb_info}.", "current")
|
||||
|
||||
if added_trainable_emb_info:
|
||||
added_trainable_emb_info = ", ".join(added_trainable_emb_info)
|
||||
print(
|
||||
"Successfully add trainable external embeddings: "
|
||||
f"{added_trainable_emb_info}",
|
||||
"current",
|
||||
)
|
||||
|
||||
def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
"""Replace external input ids to 0.
|
||||
|
||||
Args:
|
||||
input_ids (torch.Tensor): The input ids to be replaced.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The replaced input ids.
|
||||
"""
|
||||
input_ids_fwd = input_ids.clone()
|
||||
input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
|
||||
return input_ids_fwd
|
||||
|
||||
def replace_embeddings(
|
||||
self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
|
||||
) -> torch.Tensor:
|
||||
"""Replace external embedding to the embedding layer. Noted that, in
|
||||
this function we use `torch.cat` to avoid inplace modification.
|
||||
|
||||
Args:
|
||||
input_ids (torch.Tensor): The original token ids. Shape like
|
||||
[LENGTH, ].
|
||||
embedding (torch.Tensor): The embedding of token ids after
|
||||
`replace_input_ids` function.
|
||||
external_embedding (dict): The external embedding to be replaced.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The replaced embedding.
|
||||
"""
|
||||
new_embedding = []
|
||||
|
||||
name = external_embedding["name"]
|
||||
start = external_embedding["start"]
|
||||
end = external_embedding["end"]
|
||||
target_ids_to_replace = [i for i in range(start, end)]
|
||||
ext_emb = external_embedding["embedding"]
|
||||
|
||||
# do not need to replace
|
||||
if not (input_ids == start).any():
|
||||
return embedding
|
||||
|
||||
# start replace
|
||||
s_idx, e_idx = 0, 0
|
||||
while e_idx < len(input_ids):
|
||||
if input_ids[e_idx] == start:
|
||||
if e_idx != 0:
|
||||
# add embedding do not need to replace
|
||||
new_embedding.append(embedding[s_idx:e_idx])
|
||||
|
||||
# check if the next embedding need to replace is valid
|
||||
actually_ids_to_replace = [
|
||||
int(i) for i in input_ids[e_idx : e_idx + end - start]
|
||||
]
|
||||
assert actually_ids_to_replace == target_ids_to_replace, (
|
||||
f"Invalid 'input_ids' in position: {s_idx} to {e_idx}. "
|
||||
f"Expect '{target_ids_to_replace}' for embedding "
|
||||
f"'{name}' but found '{actually_ids_to_replace}'."
|
||||
)
|
||||
|
||||
new_embedding.append(ext_emb)
|
||||
|
||||
s_idx = e_idx + end - start
|
||||
e_idx = s_idx + 1
|
||||
else:
|
||||
e_idx += 1
|
||||
|
||||
if e_idx == len(input_ids):
|
||||
new_embedding.append(embedding[s_idx:e_idx])
|
||||
|
||||
return torch.cat(new_embedding, dim=0)
|
||||
|
||||
def forward(
|
||||
self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None
|
||||
):
|
||||
"""The forward function.
|
||||
|
||||
Args:
|
||||
input_ids (torch.Tensor): The token ids shape like [bz, LENGTH] or
|
||||
[LENGTH, ].
|
||||
external_embeddings (Optional[List[dict]]): The external
|
||||
embeddings. If not passed, only `self.external_embeddings`
|
||||
will be used. Defaults to None.
|
||||
|
||||
input_ids: shape like [bz, LENGTH] or [LENGTH].
|
||||
"""
|
||||
assert input_ids.ndim in [1, 2]
|
||||
if input_ids.ndim == 1:
|
||||
input_ids = input_ids.unsqueeze(0)
|
||||
|
||||
if external_embeddings is None and not self.external_embeddings:
|
||||
return self.wrapped(input_ids)
|
||||
|
||||
input_ids_fwd = self.replace_input_ids(input_ids)
|
||||
inputs_embeds = self.wrapped(input_ids_fwd)
|
||||
|
||||
vecs = []
|
||||
|
||||
if external_embeddings is None:
|
||||
external_embeddings = []
|
||||
elif isinstance(external_embeddings, dict):
|
||||
external_embeddings = [external_embeddings]
|
||||
embeddings = self.external_embeddings + external_embeddings
|
||||
|
||||
for input_id, embedding in zip(input_ids, inputs_embeds):
|
||||
new_embedding = embedding
|
||||
for external_embedding in embeddings:
|
||||
new_embedding = self.replace_embeddings(
|
||||
input_id, new_embedding, external_embedding
|
||||
)
|
||||
vecs.append(new_embedding)
|
||||
|
||||
return torch.stack(vecs)
|
||||
|
||||
|
||||
def add_tokens(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
placeholder_tokens: list,
|
||||
initialize_tokens: list = None,
|
||||
num_vectors_per_token: int = 1,
|
||||
):
|
||||
"""Add token for training.
|
||||
|
||||
# TODO: support add tokens as dict, then we can load pretrained tokens.
|
||||
"""
|
||||
if initialize_tokens is not None:
|
||||
assert len(initialize_tokens) == len(
|
||||
placeholder_tokens
|
||||
), "placeholder_token should be the same length as initialize_token"
|
||||
for ii in range(len(placeholder_tokens)):
|
||||
tokenizer.add_placeholder_token(
|
||||
placeholder_tokens[ii], num_vec_per_token=num_vectors_per_token
|
||||
)
|
||||
|
||||
# text_encoder.set_embedding_layer()
|
||||
embedding_layer = text_encoder.text_model.embeddings.token_embedding
|
||||
text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes(
|
||||
embedding_layer
|
||||
)
|
||||
embedding_layer = text_encoder.text_model.embeddings.token_embedding
|
||||
|
||||
assert embedding_layer is not None, (
|
||||
"Do not support get embedding layer for current text encoder. "
|
||||
"Please check your configuration."
|
||||
)
|
||||
initialize_embedding = []
|
||||
if initialize_tokens is not None:
|
||||
for ii in range(len(placeholder_tokens)):
|
||||
init_id = tokenizer(initialize_tokens[ii]).input_ids[1]
|
||||
temp_embedding = embedding_layer.weight[init_id]
|
||||
initialize_embedding.append(
|
||||
temp_embedding[None, ...].repeat(num_vectors_per_token, 1)
|
||||
)
|
||||
else:
|
||||
for ii in range(len(placeholder_tokens)):
|
||||
init_id = tokenizer("a").input_ids[1]
|
||||
temp_embedding = embedding_layer.weight[init_id]
|
||||
len_emb = temp_embedding.shape[0]
|
||||
init_weight = (torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0
|
||||
initialize_embedding.append(init_weight)
|
||||
|
||||
# initialize_embedding = torch.cat(initialize_embedding,dim=0)
|
||||
|
||||
token_info_all = []
|
||||
for ii in range(len(placeholder_tokens)):
|
||||
token_info = tokenizer.get_token_info(placeholder_tokens[ii])
|
||||
token_info["embedding"] = initialize_embedding[ii]
|
||||
token_info["trainable"] = True
|
||||
token_info_all.append(token_info)
|
||||
embedding_layer.add_embeddings(token_info_all)
|
||||
@@ -3,9 +3,9 @@ import cv2
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
|
||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||
from lama_cleaner.model.utils import handle_from_pretrained_exceptions
|
||||
from lama_cleaner.schema import Config, ModelType
|
||||
|
||||
|
||||
@@ -40,20 +40,18 @@ class SD(DiffusionInpaintModel):
|
||||
model_kwargs["num_in_channels"] = 9
|
||||
|
||||
self.model = StableDiffusionInpaintPipeline.from_single_file(
|
||||
self.model_id_or_path, torch_dtype=torch_dtype, **model_kwargs
|
||||
self.model_id_or_path, dtype=torch_dtype, **model_kwargs
|
||||
)
|
||||
else:
|
||||
self.model = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
self.model_id_or_path,
|
||||
revision="fp16"
|
||||
if self.model_id_or_path in DIFFUSERS_MODEL_FP16_REVERSION
|
||||
else "main",
|
||||
torch_dtype=torch_dtype,
|
||||
self.model = handle_from_pretrained_exceptions(
|
||||
StableDiffusionInpaintPipeline.from_pretrained,
|
||||
pretrained_model_name_or_path=self.model_id_or_path,
|
||||
variant="fp16",
|
||||
dtype=torch_dtype,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||
# TODO: gpu_id
|
||||
logger.info("Enable sequential cpu offload")
|
||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||
else:
|
||||
@@ -98,20 +96,20 @@ class SD(DiffusionInpaintModel):
|
||||
|
||||
|
||||
class SD15(SD):
|
||||
name = "sd1.5"
|
||||
name = "runwayml/stable-diffusion-inpainting"
|
||||
model_id_or_path = "runwayml/stable-diffusion-inpainting"
|
||||
|
||||
|
||||
class Anything4(SD):
|
||||
name = "anything4"
|
||||
name = "Sanster/anything-4.0-inpainting"
|
||||
model_id_or_path = "Sanster/anything-4.0-inpainting"
|
||||
|
||||
|
||||
class RealisticVision14(SD):
|
||||
name = "realisticVision1.4"
|
||||
name = "Sanster/Realistic_Vision_V1.4-inpainting"
|
||||
model_id_or_path = "Sanster/Realistic_Vision_V1.4-inpainting"
|
||||
|
||||
|
||||
class SD2(SD):
|
||||
name = "sd2"
|
||||
name = "stabilityai/stable-diffusion-2-inpainting"
|
||||
model_id_or_path = "stabilityai/stable-diffusion-2-inpainting"
|
||||
|
||||
@@ -8,11 +8,12 @@ from diffusers import AutoencoderKL
|
||||
from loguru import logger
|
||||
|
||||
from lama_cleaner.model.base import DiffusionInpaintModel
|
||||
from lama_cleaner.model.utils import handle_from_pretrained_exceptions
|
||||
from lama_cleaner.schema import Config, ModelType
|
||||
|
||||
|
||||
class SDXL(DiffusionInpaintModel):
|
||||
name = "sdxl"
|
||||
name = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
|
||||
pad_mod = 8
|
||||
min_size = 512
|
||||
lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
|
||||
@@ -34,18 +35,19 @@ class SDXL(DiffusionInpaintModel):
|
||||
if os.path.isfile(self.model_id_or_path):
|
||||
self.model = StableDiffusionXLInpaintPipeline.from_single_file(
|
||||
self.model_id_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
dtype=torch_dtype,
|
||||
num_in_channels=num_in_channels,
|
||||
)
|
||||
else:
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
|
||||
)
|
||||
self.model = StableDiffusionXLInpaintPipeline.from_pretrained(
|
||||
self.model_id_or_path,
|
||||
revision="main",
|
||||
self.model = handle_from_pretrained_exceptions(
|
||||
StableDiffusionXLInpaintPipeline.from_pretrained,
|
||||
pretrained_model_name_or_path=self.model_id_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
vae=vae,
|
||||
variant="fp16",
|
||||
)
|
||||
|
||||
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import gc
|
||||
import math
|
||||
import random
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -16,8 +17,11 @@ from diffusers import (
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
LCMScheduler
|
||||
LCMScheduler,
|
||||
)
|
||||
from huggingface_hub.utils import RevisionNotFoundError
|
||||
from loguru import logger
|
||||
from requests import HTTPError
|
||||
|
||||
from lama_cleaner.schema import SDSampler
|
||||
from torch import conv2d, conv_transpose2d
|
||||
@@ -944,3 +948,20 @@ def get_scheduler(sd_sampler, scheduler_config):
|
||||
return LCMScheduler.from_config(scheduler_config)
|
||||
else:
|
||||
raise ValueError(sd_sampler)
|
||||
|
||||
|
||||
def handle_from_pretrained_exceptions(func, **kwargs):
|
||||
try:
|
||||
return func(**kwargs)
|
||||
except ValueError as e:
|
||||
# 处理异常的逻辑
|
||||
if "You are trying to load the model files of the `variant=fp16`" in str(e):
|
||||
logger.info("variant=fp16 not found, try revision=fp16")
|
||||
return func(**{**kwargs, "variant": None, "revision": "fp16"})
|
||||
except OSError as e:
|
||||
previous_traceback = traceback.format_exc()
|
||||
if "RevisionNotFoundError: 404 Client Error." in previous_traceback:
|
||||
logger.info("revision=fp16 not found, try revision=main")
|
||||
return func(**{**kwargs, "variant": None, "revision": "main"})
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
100
lama_cleaner/model_info.py
Normal file
100
lama_cleaner/model_info.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
from pydantic import computed_field, BaseModel
|
||||
|
||||
from lama_cleaner.const import (
|
||||
SDXL_CONTROLNET_CHOICES,
|
||||
SD2_CONTROLNET_CHOICES,
|
||||
SD_CONTROLNET_CHOICES,
|
||||
)
|
||||
from lama_cleaner.model import InstructPix2Pix, Kandinsky22, PowerPaint, SD2
|
||||
from lama_cleaner.schema import ModelType
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
name: str
|
||||
path: str
|
||||
model_type: ModelType
|
||||
is_single_file_diffusers: bool = False
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def need_prompt(self) -> bool:
|
||||
return self.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
] or self.name in [
|
||||
InstructPix2Pix.name,
|
||||
Kandinsky22.name,
|
||||
PowerPaint.name,
|
||||
]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def controlnets(self) -> List[str]:
|
||||
if self.model_type in [
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
]:
|
||||
return SDXL_CONTROLNET_CHOICES
|
||||
if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
|
||||
if self.name in [SD2.name]:
|
||||
return SD2_CONTROLNET_CHOICES
|
||||
else:
|
||||
return SD_CONTROLNET_CHOICES
|
||||
if self.name == PowerPaint.name:
|
||||
return SD_CONTROLNET_CHOICES
|
||||
return []
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def support_strength(self) -> bool:
|
||||
return self.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def support_outpainting(self) -> bool:
|
||||
return self.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
] or self.name in [Kandinsky22.name, PowerPaint.name]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def support_lcm_lora(self) -> bool:
|
||||
return self.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def support_controlnet(self) -> bool:
|
||||
return self.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
] or self.name in [PowerPaint.name]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def support_freeu(self) -> bool:
|
||||
return self.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
] or self.name in [InstructPix2Pix.name]
|
||||
@@ -7,7 +7,8 @@ from lama_cleaner.download import scan_models
|
||||
from lama_cleaner.helper import switch_mps_device
|
||||
from lama_cleaner.model import models, ControlNet, SD, SDXL
|
||||
from lama_cleaner.model.utils import torch_gc
|
||||
from lama_cleaner.schema import Config, ModelInfo, ModelType
|
||||
from lama_cleaner.model_info import ModelInfo, ModelType
|
||||
from lama_cleaner.schema import Config
|
||||
|
||||
|
||||
class ModelManager:
|
||||
@@ -18,13 +19,20 @@ class ModelManager:
|
||||
self.available_models: Dict[str, ModelInfo] = {}
|
||||
self.scan_models()
|
||||
|
||||
self.sd_controlnet = False
|
||||
self.sd_controlnet_method = ""
|
||||
self.enable_controlnet = kwargs.get("enable_controlnet", False)
|
||||
controlnet_method = kwargs.get("controlnet_method", None)
|
||||
if (
|
||||
controlnet_method is None
|
||||
and name in self.available_models
|
||||
and self.available_models[name].support_controlnet
|
||||
):
|
||||
controlnet_method = self.available_models[name].controlnets[0]
|
||||
self.controlnet_method = controlnet_method
|
||||
self.model = self.init_model(name, device, **kwargs)
|
||||
|
||||
@property
|
||||
def current_model(self) -> Dict:
|
||||
return self.available_models[name].model_dump()
|
||||
return self.available_models[self.name].model_dump()
|
||||
|
||||
def init_model(self, name: str, device, **kwargs):
|
||||
logger.info(f"Loading model: {name}")
|
||||
@@ -35,15 +43,14 @@ class ModelManager:
|
||||
kwargs = {
|
||||
**kwargs,
|
||||
"model_info": model_info,
|
||||
"sd_controlnet": self.sd_controlnet,
|
||||
"sd_controlnet_method": self.sd_controlnet_method,
|
||||
"enable_controlnet": self.enable_controlnet,
|
||||
"controlnet_method": self.controlnet_method,
|
||||
}
|
||||
|
||||
if model_info.model_type in [ModelType.INPAINT, ModelType.DIFFUSERS_OTHER]:
|
||||
return models[name](device, **kwargs)
|
||||
|
||||
if self.sd_controlnet:
|
||||
if model_info.support_controlnet and self.enable_controlnet:
|
||||
return ControlNet(device, **kwargs)
|
||||
elif model_info.name in models:
|
||||
return models[name](device, **kwargs)
|
||||
else:
|
||||
if model_info.model_type in [
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
@@ -75,15 +82,15 @@ class ModelManager:
|
||||
return
|
||||
|
||||
old_name = self.name
|
||||
old_sd_controlnet_method = self.sd_controlnet_method
|
||||
old_controlnet_method = self.controlnet_method
|
||||
self.name = new_name
|
||||
|
||||
if (
|
||||
self.available_models[new_name].support_controlnet
|
||||
and self.sd_controlnet_method
|
||||
and self.controlnet_method
|
||||
not in self.available_models[new_name].controlnets
|
||||
):
|
||||
self.sd_controlnet_method = self.available_models[new_name].controlnets[0]
|
||||
self.controlnet_method = self.available_models[new_name].controlnets[0]
|
||||
try:
|
||||
# TODO: enable/disable controlnet without reload model
|
||||
del self.model
|
||||
@@ -94,7 +101,7 @@ class ModelManager:
|
||||
)
|
||||
except Exception as e:
|
||||
self.name = old_name
|
||||
self.sd_controlnet_method = old_sd_controlnet_method
|
||||
self.controlnet_method = old_controlnet_method
|
||||
logger.info(f"Switch model from {old_name} to {new_name} failed, rollback")
|
||||
self.model = self.init_model(
|
||||
old_name, switch_mps_device(old_name, self.device), **self.kwargs
|
||||
@@ -106,24 +113,24 @@ class ModelManager:
|
||||
return
|
||||
|
||||
if (
|
||||
self.sd_controlnet
|
||||
self.enable_controlnet
|
||||
and config.controlnet_method
|
||||
and self.sd_controlnet_method != config.controlnet_method
|
||||
and self.controlnet_method != config.controlnet_method
|
||||
):
|
||||
old_sd_controlnet_method = self.sd_controlnet_method
|
||||
self.sd_controlnet_method = config.controlnet_method
|
||||
old_controlnet_method = self.controlnet_method
|
||||
self.controlnet_method = config.controlnet_method
|
||||
self.model.switch_controlnet_method(config.controlnet_method)
|
||||
logger.info(
|
||||
f"Switch Controlnet method from {old_sd_controlnet_method} to {config.controlnet_method}"
|
||||
f"Switch Controlnet method from {old_controlnet_method} to {config.controlnet_method}"
|
||||
)
|
||||
elif self.sd_controlnet != config.controlnet_enabled:
|
||||
self.sd_controlnet = config.controlnet_enabled
|
||||
self.sd_controlnet_method = config.controlnet_method
|
||||
elif self.enable_controlnet != config.enable_controlnet:
|
||||
self.enable_controlnet = config.enable_controlnet
|
||||
self.controlnet_method = config.controlnet_method
|
||||
|
||||
self.model = self.init_model(
|
||||
self.name, switch_mps_device(self.name, self.device), **self.kwargs
|
||||
)
|
||||
if not config.controlnet_enabled:
|
||||
if not config.enable_controlnet:
|
||||
logger.info(f"Disable controlnet")
|
||||
else:
|
||||
logger.info(f"Enable controlnet: {config.controlnet_method}")
|
||||
|
||||
@@ -1,19 +1,8 @@
|
||||
from typing import Optional, List
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from PIL.Image import Image
|
||||
from pydantic import BaseModel, computed_field
|
||||
|
||||
from lama_cleaner.const import (
|
||||
SDXL_CONTROLNET_CHOICES,
|
||||
SD2_CONTROLNET_CHOICES,
|
||||
SD_CONTROLNET_CHOICES,
|
||||
)
|
||||
|
||||
DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
|
||||
DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
|
||||
DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
|
||||
DIFFUSERS_SDXL_INPAINT_CLASS_NAME = "StableDiffusionXLInpaintPipeline"
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
@@ -25,103 +14,6 @@ class ModelType(str, Enum):
|
||||
DIFFUSERS_OTHER = "diffusers_other"
|
||||
|
||||
|
||||
FREEU_DEFAULT_CONFIGS = {
|
||||
ModelType.DIFFUSERS_SD: dict(s1=0.9, s2=0.2, b1=1.2, b2=1.4),
|
||||
ModelType.DIFFUSERS_SDXL: dict(s1=0.6, s2=0.4, b1=1.1, b2=1.2),
|
||||
}
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
name: str
|
||||
path: str
|
||||
model_type: ModelType
|
||||
is_single_file_diffusers: bool = False
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def need_prompt(self) -> bool:
|
||||
return self.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
] or self.name in [
|
||||
"timbrooks/instruct-pix2pix",
|
||||
"kandinsky-community/kandinsky-2-2-decoder-inpaint",
|
||||
]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def controlnets(self) -> List[str]:
|
||||
if self.model_type in [
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
]:
|
||||
return SDXL_CONTROLNET_CHOICES
|
||||
if self.model_type in [ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT]:
|
||||
if self.name in ["stabilityai/stable-diffusion-2-inpainting"]:
|
||||
return SD2_CONTROLNET_CHOICES
|
||||
else:
|
||||
return SD_CONTROLNET_CHOICES
|
||||
return []
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def support_strength(self) -> bool:
|
||||
return self.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def support_outpainting(self) -> bool:
|
||||
return self.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
] or self.name in [
|
||||
"kandinsky-community/kandinsky-2-2-decoder-inpaint",
|
||||
]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def support_lcm_lora(self) -> bool:
|
||||
return self.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def support_controlnet(self) -> bool:
|
||||
return self.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def support_freeu(self) -> bool:
|
||||
return (
|
||||
self.model_type
|
||||
in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
]
|
||||
or "timbrooks/instruct-pix2pix" in self.name
|
||||
)
|
||||
|
||||
|
||||
class HDStrategy(str, Enum):
|
||||
# Use original image size
|
||||
ORIGINAL = "Original"
|
||||
@@ -157,6 +49,13 @@ class FREEUConfig(BaseModel):
|
||||
b2: float = 1.4
|
||||
|
||||
|
||||
class PowerPaintTask(str, Enum):
|
||||
text_guided = "text-guided"
|
||||
shape_guided = "shape-guided"
|
||||
object_remove = "object-remove"
|
||||
outpainting = "outpainting"
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@@ -239,6 +138,11 @@ class Config(BaseModel):
|
||||
p2p_image_guidance_scale: float = 1.5
|
||||
|
||||
# ControlNet
|
||||
controlnet_enabled: bool = False
|
||||
enable_controlnet: bool = False
|
||||
controlnet_conditioning_scale: float = 0.4
|
||||
controlnet_method: str = "control_v11p_sd15_canny"
|
||||
controlnet_method: str = "lllyasviel/control_v11p_sd15_canny"
|
||||
|
||||
# PowerPaint
|
||||
powerpaint_task: PowerPaintTask = PowerPaintTask.text_guided
|
||||
# control the fitting degree of the generated objects to the mask shape.
|
||||
fitting_degree: float = 1.0
|
||||
|
||||
@@ -63,6 +63,7 @@ from lama_cleaner.helper import (
|
||||
numpy_to_bytes,
|
||||
resize_max_size,
|
||||
pil_to_bytes,
|
||||
is_mac,
|
||||
)
|
||||
|
||||
NUM_THREADS = str(multiprocessing.cpu_count())
|
||||
@@ -285,9 +286,10 @@ def process():
|
||||
cv2_radius=form["cv2Radius"],
|
||||
paint_by_example_example_image=paint_by_example_example_image,
|
||||
p2p_image_guidance_scale=form["p2pImageGuidanceScale"],
|
||||
controlnet_enabled=form["controlnet_enabled"],
|
||||
enable_controlnet=form["enable_controlnet"],
|
||||
controlnet_conditioning_scale=form["controlnet_conditioning_scale"],
|
||||
controlnet_method=form["controlnet_method"],
|
||||
powerpaint_task=form["powerpaintTask"],
|
||||
)
|
||||
|
||||
if config.sd_seed == -1:
|
||||
@@ -305,6 +307,8 @@ def process():
|
||||
if "CUDA out of memory. " in str(e):
|
||||
# NOTE: the string may change?
|
||||
return "CUDA out of memory", 500
|
||||
elif "Invalid buffer size" in str(e) and is_mac():
|
||||
return "Out of memory", 500
|
||||
else:
|
||||
logger.exception(e)
|
||||
return f"{str(e)}", 500
|
||||
@@ -423,8 +427,8 @@ def get_server_config():
|
||||
"plugins": list(global_config.plugins.keys()),
|
||||
"enableFileManager": global_config.enable_file_manager,
|
||||
"enableAutoSaving": global_config.enable_auto_saving,
|
||||
"enableControlnet": global_config.model_manager.sd_controlnet,
|
||||
"controlnetMethod": global_config.model_manager.sd_controlnet_method,
|
||||
"enableControlnet": global_config.model_manager.enable_controlnet,
|
||||
"controlnetMethod": global_config.model_manager.controlnet_method,
|
||||
"disableModelSwitch": global_config.disable_model_switch,
|
||||
"isDesktop": global_config.is_desktop,
|
||||
}, 200
|
||||
|
||||
0
lama_cleaner/tests/utils.py
Normal file
0
lama_cleaner/tests/utils.py
Normal file
@@ -15,8 +15,8 @@ def save_config(
|
||||
port,
|
||||
model,
|
||||
sd_local_model_path,
|
||||
sd_controlnet,
|
||||
sd_controlnet_method,
|
||||
enable_controlnet,
|
||||
controlnet_method,
|
||||
device,
|
||||
gui,
|
||||
no_gui_auto_close,
|
||||
@@ -176,13 +176,13 @@ def main(config_file: str):
|
||||
sd_local_model_path = gr.Textbox(
|
||||
init_config.sd_local_model_path, label=f"{SD_LOCAL_MODEL_HELP}"
|
||||
)
|
||||
sd_controlnet = gr.Checkbox(
|
||||
init_config.sd_controlnet, label=f"{SD_CONTROLNET_HELP}"
|
||||
enable_controlnet = gr.Checkbox(
|
||||
init_config.enable_controlnet, label=f"{SD_CONTROLNET_HELP}"
|
||||
)
|
||||
sd_controlnet_method = gr.Radio(
|
||||
controlnet_method = gr.Radio(
|
||||
SD_CONTROLNET_CHOICES,
|
||||
label="ControlNet method",
|
||||
value=init_config.sd_controlnet_method,
|
||||
value=init_config.controlnet_method,
|
||||
)
|
||||
no_half = gr.Checkbox(init_config.no_half, label=f"{NO_HALF_HELP}")
|
||||
cpu_offload = gr.Checkbox(
|
||||
@@ -205,8 +205,8 @@ def main(config_file: str):
|
||||
port,
|
||||
model,
|
||||
sd_local_model_path,
|
||||
sd_controlnet,
|
||||
sd_controlnet_method,
|
||||
enable_controlnet,
|
||||
controlnet_method,
|
||||
device,
|
||||
gui,
|
||||
no_gui_auto_close,
|
||||
|
||||
Reference in New Issue
Block a user