add AnyText
This commit is contained in:
@@ -6,6 +6,7 @@ from pydantic import BaseModel
|
||||
INSTRUCT_PIX2PIX_NAME = "timbrooks/instruct-pix2pix"
|
||||
KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
|
||||
POWERPAINT_NAME = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
|
||||
ANYTEXT_NAME = "Sanster/AnyText"
|
||||
|
||||
|
||||
DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
|
||||
|
||||
@@ -12,6 +12,7 @@ from iopaint.const import (
|
||||
DIFFUSERS_SD_INPAINT_CLASS_NAME,
|
||||
DIFFUSERS_SDXL_CLASS_NAME,
|
||||
DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
|
||||
ANYTEXT_NAME,
|
||||
)
|
||||
from iopaint.model_info import ModelInfo, ModelType
|
||||
|
||||
@@ -24,6 +25,10 @@ def cli_download_model(model: str):
|
||||
logger.info(f"Downloading {model}...")
|
||||
models[model].download()
|
||||
logger.info(f"Done.")
|
||||
elif model == ANYTEXT_NAME:
|
||||
logger.info(f"Downloading {model}...")
|
||||
models[model].download()
|
||||
logger.info(f"Done.")
|
||||
else:
|
||||
logger.info(f"Downloading model from Huggingface: {model}")
|
||||
from diffusers import DiffusionPipeline
|
||||
@@ -210,6 +215,7 @@ def scan_models() -> List[ModelInfo]:
|
||||
"StableDiffusionInstructPix2PixPipeline",
|
||||
"PaintByExamplePipeline",
|
||||
"KandinskyV22InpaintPipeline",
|
||||
"AnyText",
|
||||
]:
|
||||
model_type = ModelType.DIFFUSERS_OTHER
|
||||
else:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .anytext.anytext_model import AnyText
|
||||
from .controlnet import ControlNet
|
||||
from .fcf import FcF
|
||||
from .instruct_pix2pix import InstructPix2Pix
|
||||
@@ -32,4 +33,5 @@ models = {
|
||||
Kandinsky22.name: Kandinsky22,
|
||||
SDXL.name: SDXL,
|
||||
PowerPaint.name: PowerPaint,
|
||||
AnyText.name: AnyText,
|
||||
}
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from iopaint.const import ANYTEXT_NAME
|
||||
from iopaint.model.anytext.anytext_pipeline import AnyTextPipeline
|
||||
from iopaint.model.base import DiffusionInpaintModel
|
||||
from iopaint.model.utils import get_torch_dtype, is_local_files_only
|
||||
from iopaint.schema import InpaintRequest
|
||||
|
||||
|
||||
class AnyText(DiffusionInpaintModel):
|
||||
name = ANYTEXT_NAME
|
||||
pad_mod = 64
|
||||
is_erase_model = False
|
||||
|
||||
@staticmethod
|
||||
def download(local_files_only=False):
|
||||
hf_hub_download(
|
||||
repo_id=ANYTEXT_NAME,
|
||||
filename="model_index.json",
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
ckpt_path = hf_hub_download(
|
||||
repo_id=ANYTEXT_NAME,
|
||||
filename="pytorch_model.fp16.safetensors",
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
font_path = hf_hub_download(
|
||||
repo_id=ANYTEXT_NAME,
|
||||
filename="SourceHanSansSC-Medium.otf",
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
return ckpt_path, font_path
|
||||
|
||||
def init_model(self, device, **kwargs):
|
||||
local_files_only = is_local_files_only(**kwargs)
|
||||
ckpt_path, font_path = self.download(local_files_only)
|
||||
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||
self.model = AnyTextPipeline(
|
||||
ckpt_path=ckpt_path,
|
||||
font_path=font_path,
|
||||
device=device,
|
||||
use_fp16=torch_dtype == torch.float16,
|
||||
)
|
||||
self.callback = kwargs.pop("callback", None)
|
||||
|
||||
def forward(self, image, mask, config: InpaintRequest):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1] 255 means area to inpainting
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
height, width = image.shape[:2]
|
||||
mask = mask.astype("float32") / 255.0
|
||||
masked_image = image * (1 - mask)
|
||||
|
||||
# list of rgb ndarray
|
||||
results, rtn_code, rtn_warning = self.model(
|
||||
image=image,
|
||||
masked_image=masked_image,
|
||||
prompt=config.prompt,
|
||||
negative_prompt=config.negative_prompt,
|
||||
num_inference_steps=config.sd_steps,
|
||||
strength=config.sd_strength,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
height=height,
|
||||
width=width,
|
||||
seed=config.sd_seed,
|
||||
sort_priority="y",
|
||||
callback=self.callback
|
||||
)
|
||||
inpainted_rgb_image = results[0][..., ::-1]
|
||||
return inpainted_rgb_image
|
||||
|
||||
@@ -5,20 +5,21 @@ Code: https://github.com/tyxsspa/AnyText
|
||||
Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from iopaint.model.utils import set_seed
|
||||
from safetensors.torch import load_file
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
import torch
|
||||
import random
|
||||
import re
|
||||
import numpy as np
|
||||
import cv2
|
||||
import einops
|
||||
import time
|
||||
from PIL import ImageFont
|
||||
from iopaint.model.anytext.cldm.model import create_model, load_state_dict
|
||||
from iopaint.model.anytext.cldm.ddim_hacked import DDIMSampler
|
||||
from iopaint.model.anytext.utils import (
|
||||
resize_image,
|
||||
check_channels,
|
||||
draw_glyph,
|
||||
draw_glyph2,
|
||||
@@ -29,55 +30,93 @@ BBOX_MAX_NUM = 8
|
||||
PLACE_HOLDER = "*"
|
||||
max_chars = 20
|
||||
|
||||
ANYTEXT_CFG = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "anytext_sd15.yaml"
|
||||
)
|
||||
|
||||
|
||||
def check_limits(tensor):
|
||||
float16_min = torch.finfo(torch.float16).min
|
||||
float16_max = torch.finfo(torch.float16).max
|
||||
|
||||
# 检查张量中是否有值小于float16的最小值或大于float16的最大值
|
||||
is_below_min = (tensor < float16_min).any()
|
||||
is_above_max = (tensor > float16_max).any()
|
||||
|
||||
return is_below_min or is_above_max
|
||||
|
||||
|
||||
class AnyTextPipeline:
|
||||
def __init__(self, cfg_path, model_dir, font_path, device, use_fp16=True):
|
||||
self.cfg_path = cfg_path
|
||||
self.model_dir = model_dir
|
||||
def __init__(self, ckpt_path, font_path, device, use_fp16=True):
|
||||
self.cfg_path = ANYTEXT_CFG
|
||||
self.font_path = font_path
|
||||
self.use_fp16 = use_fp16
|
||||
self.device = device
|
||||
self.init_model()
|
||||
|
||||
"""
|
||||
return:
|
||||
result: list of images in numpy.ndarray format
|
||||
rst_code: 0: normal -1: error 1:warning
|
||||
rst_info: string of error or warning
|
||||
debug_info: string for debug, only valid if show_debug=True
|
||||
"""
|
||||
self.font = ImageFont.truetype(font_path, size=60)
|
||||
self.model = create_model(
|
||||
self.cfg_path,
|
||||
device=self.device,
|
||||
use_fp16=self.use_fp16,
|
||||
)
|
||||
if self.use_fp16:
|
||||
self.model = self.model.half()
|
||||
if Path(ckpt_path).suffix == ".safetensors":
|
||||
state_dict = load_file(ckpt_path, device="cpu")
|
||||
else:
|
||||
state_dict = load_state_dict(ckpt_path, location="cpu")
|
||||
self.model.load_state_dict(state_dict, strict=False)
|
||||
self.model = self.model.eval().to(self.device)
|
||||
self.ddim_sampler = DDIMSampler(self.model, device=self.device)
|
||||
|
||||
def __call__(self, input_tensor, **forward_params):
|
||||
tic = time.time()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
image: np.ndarray,
|
||||
masked_image: np.ndarray,
|
||||
num_inference_steps: int,
|
||||
strength: float,
|
||||
guidance_scale: float,
|
||||
height: int,
|
||||
width: int,
|
||||
seed: int,
|
||||
sort_priority: str = "y",
|
||||
callback=None,
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
prompt:
|
||||
negative_prompt:
|
||||
image:
|
||||
masked_image:
|
||||
num_inference_steps:
|
||||
strength:
|
||||
guidance_scale:
|
||||
height:
|
||||
width:
|
||||
seed:
|
||||
sort_priority: x: left-right, y: top-down
|
||||
|
||||
Returns:
|
||||
result: list of images in numpy.ndarray format
|
||||
rst_code: 0: normal -1: error 1:warning
|
||||
rst_info: string of error or warning
|
||||
|
||||
"""
|
||||
set_seed(seed)
|
||||
str_warning = ""
|
||||
# get inputs
|
||||
seed = input_tensor.get("seed", -1)
|
||||
if seed == -1:
|
||||
seed = random.randint(0, 99999999)
|
||||
# seed_everything(seed)
|
||||
prompt = input_tensor.get("prompt")
|
||||
draw_pos = input_tensor.get("draw_pos")
|
||||
ori_image = input_tensor.get("ori_image")
|
||||
|
||||
mode = forward_params.get("mode")
|
||||
sort_priority = forward_params.get("sort_priority", "↕")
|
||||
show_debug = forward_params.get("show_debug", False)
|
||||
revise_pos = forward_params.get("revise_pos", False)
|
||||
img_count = forward_params.get("image_count", 4)
|
||||
ddim_steps = forward_params.get("ddim_steps", 20)
|
||||
w = forward_params.get("image_width", 512)
|
||||
h = forward_params.get("image_height", 512)
|
||||
strength = forward_params.get("strength", 1.0)
|
||||
cfg_scale = forward_params.get("cfg_scale", 9.0)
|
||||
eta = forward_params.get("eta", 0.0)
|
||||
a_prompt = forward_params.get(
|
||||
"a_prompt",
|
||||
"best quality, extremely detailed,4k, HD, supper legible text, clear text edges, clear strokes, neat writing, no watermarks",
|
||||
)
|
||||
n_prompt = forward_params.get(
|
||||
"n_prompt",
|
||||
"low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture",
|
||||
)
|
||||
mode = "text-editing"
|
||||
revise_pos = False
|
||||
img_count = 1
|
||||
ddim_steps = num_inference_steps
|
||||
w = width
|
||||
h = height
|
||||
strength = strength
|
||||
cfg_scale = guidance_scale
|
||||
eta = 0.0
|
||||
|
||||
prompt, texts = self.modify_prompt(prompt)
|
||||
if prompt is None and texts is None:
|
||||
@@ -91,43 +130,44 @@ class AnyTextPipeline:
|
||||
if mode in ["text-generation", "gen"]:
|
||||
edit_image = np.ones((h, w, 3)) * 127.5 # empty mask image
|
||||
elif mode in ["text-editing", "edit"]:
|
||||
if draw_pos is None or ori_image is None:
|
||||
if masked_image is None or image is None:
|
||||
return (
|
||||
None,
|
||||
-1,
|
||||
"Reference image and position image are needed for text editing!",
|
||||
"",
|
||||
)
|
||||
if isinstance(ori_image, str):
|
||||
ori_image = cv2.imread(ori_image)[..., ::-1]
|
||||
assert (
|
||||
ori_image is not None
|
||||
), f"Can't read ori_image image from{ori_image}!"
|
||||
elif isinstance(ori_image, torch.Tensor):
|
||||
ori_image = ori_image.cpu().numpy()
|
||||
if isinstance(image, str):
|
||||
image = cv2.imread(image)[..., ::-1]
|
||||
assert image is not None, f"Can't read ori_image image from{image}!"
|
||||
elif isinstance(image, torch.Tensor):
|
||||
image = image.cpu().numpy()
|
||||
else:
|
||||
assert isinstance(
|
||||
ori_image, np.ndarray
|
||||
), f"Unknown format of ori_image: {type(ori_image)}"
|
||||
edit_image = ori_image.clip(1, 255) # for mask reason
|
||||
image, np.ndarray
|
||||
), f"Unknown format of ori_image: {type(image)}"
|
||||
edit_image = image.clip(1, 255) # for mask reason
|
||||
edit_image = check_channels(edit_image)
|
||||
edit_image = resize_image(
|
||||
edit_image, max_length=768
|
||||
) # make w h multiple of 64, resize if w or h > max_length
|
||||
# edit_image = resize_image(
|
||||
# edit_image, max_length=768
|
||||
# ) # make w h multiple of 64, resize if w or h > max_length
|
||||
h, w = edit_image.shape[:2] # change h, w by input ref_img
|
||||
# preprocess pos_imgs(if numpy, make sure it's white pos in black bg)
|
||||
if draw_pos is None:
|
||||
if masked_image is None:
|
||||
pos_imgs = np.zeros((w, h, 1))
|
||||
if isinstance(draw_pos, str):
|
||||
draw_pos = cv2.imread(draw_pos)[..., ::-1]
|
||||
assert draw_pos is not None, f"Can't read draw_pos image from{draw_pos}!"
|
||||
pos_imgs = 255 - draw_pos
|
||||
elif isinstance(draw_pos, torch.Tensor):
|
||||
pos_imgs = draw_pos.cpu().numpy()
|
||||
if isinstance(masked_image, str):
|
||||
masked_image = cv2.imread(masked_image)[..., ::-1]
|
||||
assert (
|
||||
masked_image is not None
|
||||
), f"Can't read draw_pos image from{masked_image}!"
|
||||
pos_imgs = 255 - masked_image
|
||||
elif isinstance(masked_image, torch.Tensor):
|
||||
pos_imgs = masked_image.cpu().numpy()
|
||||
else:
|
||||
assert isinstance(
|
||||
draw_pos, np.ndarray
|
||||
), f"Unknown format of draw_pos: {type(draw_pos)}"
|
||||
masked_image, np.ndarray
|
||||
), f"Unknown format of draw_pos: {type(masked_image)}"
|
||||
pos_imgs = 255 - masked_image
|
||||
pos_imgs = pos_imgs[..., 0:1]
|
||||
pos_imgs = cv2.convertScaleAbs(pos_imgs)
|
||||
_, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
|
||||
@@ -139,11 +179,8 @@ class AnyTextPipeline:
|
||||
if n_lines == 1 and texts[0] == " ":
|
||||
pass # text-to-image without text
|
||||
else:
|
||||
return (
|
||||
None,
|
||||
-1,
|
||||
f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!",
|
||||
"",
|
||||
raise RuntimeError(
|
||||
f"{n_lines} text line to draw from prompt, not enough mask area({len(pos_imgs)}) on images"
|
||||
)
|
||||
elif len(pos_imgs) > n_lines:
|
||||
str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt."
|
||||
@@ -250,12 +287,16 @@ class AnyTextPipeline:
|
||||
cond = self.model.get_learned_conditioning(
|
||||
dict(
|
||||
c_concat=[hint],
|
||||
c_crossattn=[[prompt + " , " + a_prompt] * img_count],
|
||||
c_crossattn=[[prompt] * img_count],
|
||||
text_info=info,
|
||||
)
|
||||
)
|
||||
un_cond = self.model.get_learned_conditioning(
|
||||
dict(c_concat=[hint], c_crossattn=[[n_prompt] * img_count], text_info=info)
|
||||
dict(
|
||||
c_concat=[hint],
|
||||
c_crossattn=[[negative_prompt] * img_count],
|
||||
text_info=info,
|
||||
)
|
||||
)
|
||||
shape = (4, h // 8, w // 8)
|
||||
self.model.control_scales = [strength] * 13
|
||||
@@ -268,6 +309,7 @@ class AnyTextPipeline:
|
||||
eta=eta,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=un_cond,
|
||||
callback=callback
|
||||
)
|
||||
if self.use_fp16:
|
||||
samples = samples.half()
|
||||
@@ -280,52 +322,18 @@ class AnyTextPipeline:
|
||||
.astype(np.uint8)
|
||||
)
|
||||
results = [x_samples[i] for i in range(img_count)]
|
||||
if (
|
||||
mode == "edit" and False
|
||||
): # replace backgound in text editing but not ideal yet
|
||||
results = [r * np_hint + edit_image * (1 - np_hint) for r in results]
|
||||
results = [r.clip(0, 255).astype(np.uint8) for r in results]
|
||||
if len(gly_pos_imgs) > 0 and show_debug:
|
||||
glyph_bs = np.stack(gly_pos_imgs, axis=2)
|
||||
glyph_img = np.sum(glyph_bs, axis=2) * 255
|
||||
glyph_img = glyph_img.clip(0, 255).astype(np.uint8)
|
||||
results += [np.repeat(glyph_img, 3, axis=2)]
|
||||
# debug_info
|
||||
if not show_debug:
|
||||
debug_info = ""
|
||||
else:
|
||||
input_prompt = prompt
|
||||
for t in texts:
|
||||
input_prompt = input_prompt.replace("*", f'"{t}"', 1)
|
||||
debug_info = f'<span style="color:black;font-size:18px">Prompt: </span>{input_prompt}<br> \
|
||||
<span style="color:black;font-size:18px">Size: </span>{w}x{h}<br> \
|
||||
<span style="color:black;font-size:18px">Image Count: </span>{img_count}<br> \
|
||||
<span style="color:black;font-size:18px">Seed: </span>{seed}<br> \
|
||||
<span style="color:black;font-size:18px">Use FP16: </span>{self.use_fp16}<br> \
|
||||
<span style="color:black;font-size:18px">Cost Time: </span>{(time.time()-tic):.2f}s'
|
||||
# if (
|
||||
# mode == "edit" and False
|
||||
# ): # replace backgound in text editing but not ideal yet
|
||||
# results = [r * np_hint + edit_image * (1 - np_hint) for r in results]
|
||||
# results = [r.clip(0, 255).astype(np.uint8) for r in results]
|
||||
# if len(gly_pos_imgs) > 0 and show_debug:
|
||||
# glyph_bs = np.stack(gly_pos_imgs, axis=2)
|
||||
# glyph_img = np.sum(glyph_bs, axis=2) * 255
|
||||
# glyph_img = glyph_img.clip(0, 255).astype(np.uint8)
|
||||
# results += [np.repeat(glyph_img, 3, axis=2)]
|
||||
rst_code = 1 if str_warning else 0
|
||||
return results, rst_code, str_warning, debug_info
|
||||
|
||||
def init_model(self):
|
||||
font_path = self.font_path
|
||||
self.font = ImageFont.truetype(font_path, size=60)
|
||||
cfg_path = self.cfg_path
|
||||
ckpt_path = os.path.join(self.model_dir, "anytext_v1.1.ckpt")
|
||||
clip_path = os.path.join(self.model_dir, "clip-vit-large-patch14")
|
||||
self.model = create_model(
|
||||
cfg_path,
|
||||
device=self.device,
|
||||
cond_stage_path=clip_path,
|
||||
use_fp16=self.use_fp16,
|
||||
)
|
||||
if self.use_fp16:
|
||||
self.model = self.model.half()
|
||||
self.model.load_state_dict(
|
||||
load_state_dict(ckpt_path, location=self.device), strict=False
|
||||
)
|
||||
self.model.eval()
|
||||
self.model = self.model.to(self.device)
|
||||
self.ddim_sampler = DDIMSampler(self.model, device=self.device)
|
||||
return results, rst_code, str_warning
|
||||
|
||||
def modify_prompt(self, prompt):
|
||||
prompt = prompt.replace("“", '"')
|
||||
@@ -360,9 +368,9 @@ class AnyTextPipeline:
|
||||
component = np.zeros_like(img)
|
||||
component[labels == label] = 255
|
||||
components.append((component, centroids[label]))
|
||||
if sort_priority == "↕":
|
||||
if sort_priority == "y":
|
||||
fir, sec = 1, 0 # top-down first
|
||||
elif sort_priority == "↔":
|
||||
elif sort_priority == "x":
|
||||
fir, sec = 0, 1 # left-right first
|
||||
components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap))
|
||||
sorted_components = [c[0] for c in components]
|
||||
|
||||
@@ -95,5 +95,5 @@ model:
|
||||
cond_stage_config:
|
||||
target: iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedderT3
|
||||
params:
|
||||
version: ./models/clip-vit-large-patch14
|
||||
version: openai/clip-vit-large-patch14
|
||||
use_vision: false # v6
|
||||
|
||||
@@ -254,7 +254,7 @@ class DDIMSampler(object):
|
||||
)
|
||||
img, pred_x0 = outs
|
||||
if callback:
|
||||
callback(i)
|
||||
callback(None, i, None, None)
|
||||
if img_callback:
|
||||
img_callback(pred_x0, i)
|
||||
|
||||
|
||||
@@ -26,11 +26,11 @@ def load_state_dict(ckpt_path, location="cpu"):
|
||||
|
||||
def create_model(config_path, device, cond_stage_path=None, use_fp16=False):
|
||||
config = OmegaConf.load(config_path)
|
||||
if cond_stage_path:
|
||||
config.model.params.cond_stage_config.params.version = (
|
||||
cond_stage_path # use pre-downloaded ckpts, in case blocked
|
||||
)
|
||||
config.model.params.cond_stage_config.params.device = device
|
||||
# if cond_stage_path:
|
||||
# config.model.params.cond_stage_config.params.version = (
|
||||
# cond_stage_path # use pre-downloaded ckpts, in case blocked
|
||||
# )
|
||||
config.model.params.cond_stage_config.params.device = str(device)
|
||||
if use_fp16:
|
||||
config.model.params.use_fp16 = True
|
||||
config.model.params.control_stage_config.params.use_fp16 = True
|
||||
|
||||
@@ -2,7 +2,14 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel, AutoProcessor, CLIPVisionModelWithProjection
|
||||
from transformers import (
|
||||
T5Tokenizer,
|
||||
T5EncoderModel,
|
||||
CLIPTokenizer,
|
||||
CLIPTextModel,
|
||||
AutoProcessor,
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
from iopaint.model.anytext.ldm.util import count_params
|
||||
|
||||
@@ -18,7 +25,9 @@ def _expand_mask(mask, dtype, tgt_len=None):
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
||||
return inverted_mask.masked_fill(
|
||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
|
||||
def _build_causal_attention_mask(bsz, seq_len, dtype):
|
||||
@@ -30,6 +39,7 @@ def _build_causal_attention_mask(bsz, seq_len, dtype):
|
||||
mask = mask.unsqueeze(1) # expand mask
|
||||
return mask
|
||||
|
||||
|
||||
class AbstractEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -39,13 +49,12 @@ class AbstractEncoder(nn.Module):
|
||||
|
||||
|
||||
class IdentityEncoder(AbstractEncoder):
|
||||
|
||||
def encode(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class ClassEmbedder(nn.Module):
|
||||
def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
|
||||
def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
|
||||
super().__init__()
|
||||
self.key = key
|
||||
self.embedding = nn.Embedding(n_classes, embed_dim)
|
||||
@@ -57,15 +66,17 @@ class ClassEmbedder(nn.Module):
|
||||
key = self.key
|
||||
# this is for use in crossattn
|
||||
c = batch[key][:, None]
|
||||
if self.ucg_rate > 0. and not disable_dropout:
|
||||
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
|
||||
c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
|
||||
if self.ucg_rate > 0.0 and not disable_dropout:
|
||||
mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
|
||||
c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
|
||||
c = c.long()
|
||||
c = self.embedding(c)
|
||||
return c
|
||||
|
||||
def get_unconditional_conditioning(self, bs, device="cuda"):
|
||||
uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
|
||||
uc_class = (
|
||||
self.n_classes - 1
|
||||
) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
|
||||
uc = torch.ones((bs,), device=device) * uc_class
|
||||
uc = {self.key: uc}
|
||||
return uc
|
||||
@@ -79,24 +90,34 @@ def disabled_train(self, mode=True):
|
||||
|
||||
class FrozenT5Embedder(AbstractEncoder):
|
||||
"""Uses the T5 transformer encoder for text"""
|
||||
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
||||
|
||||
def __init__(
|
||||
self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
|
||||
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
||||
super().__init__()
|
||||
self.tokenizer = T5Tokenizer.from_pretrained(version)
|
||||
self.transformer = T5EncoderModel.from_pretrained(version)
|
||||
self.device = device
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
if freeze:
|
||||
self.freeze()
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
#self.train = disabled_train
|
||||
# self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens)
|
||||
|
||||
@@ -109,13 +130,18 @@ class FrozenT5Embedder(AbstractEncoder):
|
||||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
LAYERS = [
|
||||
"last",
|
||||
"pooled",
|
||||
"hidden"
|
||||
]
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
|
||||
|
||||
LAYERS = ["last", "pooled", "hidden"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
version="openai/clip-vit-large-patch14",
|
||||
device="cuda",
|
||||
max_length=77,
|
||||
freeze=True,
|
||||
layer="last",
|
||||
layer_idx=None,
|
||||
): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
@@ -137,10 +163,19 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
|
||||
outputs = self.transformer(
|
||||
input_ids=tokens, output_hidden_states=self.layer == "hidden"
|
||||
)
|
||||
if self.layer == "last":
|
||||
z = outputs.last_hidden_state
|
||||
elif self.layer == "pooled":
|
||||
@@ -153,77 +188,24 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
||||
"""
|
||||
Uses the OpenCLIP transformer encoder for text
|
||||
"""
|
||||
LAYERS = [
|
||||
# "pooled",
|
||||
"last",
|
||||
"penultimate"
|
||||
]
|
||||
|
||||
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
|
||||
freeze=True, layer="last"):
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
|
||||
del model.visual
|
||||
self.model = model
|
||||
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
if self.layer == "last":
|
||||
self.layer_idx = 0
|
||||
elif self.layer == "penultimate":
|
||||
self.layer_idx = 1
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def freeze(self):
|
||||
self.model = self.model.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
tokens = open_clip.tokenize(text)
|
||||
z = self.encode_with_transformer(tokens.to(self.device))
|
||||
return z
|
||||
|
||||
def encode_with_transformer(self, text):
|
||||
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
|
||||
x = x + self.model.positional_embedding
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
x = self.model.ln_final(x)
|
||||
return x
|
||||
|
||||
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
|
||||
for i, r in enumerate(self.model.transformer.resblocks):
|
||||
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
||||
break
|
||||
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint(r, x, attn_mask)
|
||||
else:
|
||||
x = r(x, attn_mask=attn_mask)
|
||||
return x
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenCLIPT5Encoder(AbstractEncoder):
|
||||
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
|
||||
clip_max_length=77, t5_max_length=77):
|
||||
def __init__(
|
||||
self,
|
||||
clip_version="openai/clip-vit-large-patch14",
|
||||
t5_version="google/t5-v1_1-xl",
|
||||
device="cuda",
|
||||
clip_max_length=77,
|
||||
t5_max_length=77,
|
||||
):
|
||||
super().__init__()
|
||||
self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
|
||||
self.clip_encoder = FrozenCLIPEmbedder(
|
||||
clip_version, device, max_length=clip_max_length
|
||||
)
|
||||
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
|
||||
print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
|
||||
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
|
||||
print(
|
||||
f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
|
||||
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params."
|
||||
)
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
@@ -236,7 +218,15 @@ class FrozenCLIPT5Encoder(AbstractEncoder):
|
||||
|
||||
class FrozenCLIPEmbedderT3(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, freeze=True, use_vision=False):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
version="openai/clip-vit-large-patch14",
|
||||
device="cuda",
|
||||
max_length=77,
|
||||
freeze=True,
|
||||
use_vision=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
self.transformer = CLIPTextModel.from_pretrained(version)
|
||||
@@ -255,7 +245,11 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
|
||||
inputs_embeds=None,
|
||||
embedding_manager=None,
|
||||
):
|
||||
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
|
||||
seq_length = (
|
||||
input_ids.shape[-1]
|
||||
if input_ids is not None
|
||||
else inputs_embeds.shape[-2]
|
||||
)
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
if inputs_embeds is None:
|
||||
@@ -266,7 +260,9 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
|
||||
embeddings = inputs_embeds + position_embeddings
|
||||
return embeddings
|
||||
|
||||
self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings)
|
||||
self.transformer.text_model.embeddings.forward = embedding_forward.__get__(
|
||||
self.transformer.text_model.embeddings
|
||||
)
|
||||
|
||||
def encoder_forward(
|
||||
self,
|
||||
@@ -277,11 +273,19 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
hidden_states = inputs_embeds
|
||||
@@ -301,7 +305,9 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
return hidden_states
|
||||
|
||||
self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)
|
||||
self.transformer.text_model.encoder.forward = encoder_forward.__get__(
|
||||
self.transformer.text_model.encoder
|
||||
)
|
||||
|
||||
def text_encoder_forward(
|
||||
self,
|
||||
@@ -313,22 +319,34 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
|
||||
return_dict=None,
|
||||
embedding_manager=None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
if input_ids is None:
|
||||
raise ValueError("You have to specify either input_ids")
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager)
|
||||
hidden_states = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
embedding_manager=embedding_manager,
|
||||
)
|
||||
bsz, seq_len = input_shape
|
||||
# CLIP's text model uses causal mask, prepare it here.
|
||||
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
|
||||
causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
|
||||
hidden_states.device
|
||||
)
|
||||
causal_attention_mask = _build_causal_attention_mask(
|
||||
bsz, seq_len, hidden_states.dtype
|
||||
).to(hidden_states.device)
|
||||
# expand attention_mask
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
@@ -344,7 +362,9 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
|
||||
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
||||
return last_hidden_state
|
||||
|
||||
self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)
|
||||
self.transformer.text_model.forward = text_encoder_forward.__get__(
|
||||
self.transformer.text_model
|
||||
)
|
||||
|
||||
def transformer_forward(
|
||||
self,
|
||||
@@ -363,7 +383,7 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
embedding_manager=embedding_manager
|
||||
embedding_manager=embedding_manager,
|
||||
)
|
||||
|
||||
self.transformer.forward = transformer_forward.__get__(self.transformer)
|
||||
@@ -374,8 +394,15 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text, **kwargs):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
z = self.transformer(input_ids=tokens, **kwargs)
|
||||
return z
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
import cv2
|
||||
import os
|
||||
|
||||
from anytext_pipeline import AnyTextPipeline
|
||||
from utils import save_images
|
||||
|
||||
@@ -5,48 +8,38 @@ seed = 66273235
|
||||
# seed_everything(seed)
|
||||
|
||||
pipe = AnyTextPipeline(
|
||||
cfg_path="/Users/cwq/code/github/AnyText/anytext/models_yaMl/anytext_sd15.yaml",
|
||||
model_dir="/Users/cwq/.cache/modelscope/hub/damo/cv_anytext_text_generation_editing",
|
||||
# font_path="/Users/cwq/code/github/AnyText/anytext/font/Arial_Unicode.ttf",
|
||||
# font_path="/Users/cwq/code/github/AnyText/anytext/font/SourceHanSansSC-VF.ttf",
|
||||
ckpt_path="/Users/cwq/code/github/IOPaint/iopaint/model/anytext/anytext_v1.1_fp16.ckpt",
|
||||
font_path="/Users/cwq/code/github/AnyText/anytext/font/SourceHanSansSC-Medium.otf",
|
||||
use_fp16=False,
|
||||
device="mps",
|
||||
)
|
||||
|
||||
img_save_folder = "SaveImages"
|
||||
params = {
|
||||
"show_debug": True,
|
||||
"image_count": 2,
|
||||
"ddim_steps": 20,
|
||||
}
|
||||
rgb_image = cv2.imread(
|
||||
"/Users/cwq/code/github/AnyText/anytext/example_images/ref7.jpg"
|
||||
)[..., ::-1]
|
||||
|
||||
# # 1. text generation
|
||||
# mode = "text-generation"
|
||||
# input_data = {
|
||||
# "prompt": 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream',
|
||||
# "seed": seed,
|
||||
# "draw_pos": "/Users/cwq/code/github/AnyText/anytext/example_images/gen9.png",
|
||||
# }
|
||||
# results, rtn_code, rtn_warning, debug_info = pipe(input_data, mode=mode, **params)
|
||||
# if rtn_code >= 0:
|
||||
# save_images(results, img_save_folder)
|
||||
# print(f"Done, result images are saved in: {img_save_folder}")
|
||||
# if rtn_warning:
|
||||
# print(rtn_warning)
|
||||
#
|
||||
# exit()
|
||||
# 2. text editing
|
||||
mode = "text-editing"
|
||||
input_data = {
|
||||
"prompt": 'A cake with colorful characters that reads "EVERYDAY"',
|
||||
"seed": seed,
|
||||
"draw_pos": "/Users/cwq/code/github/AnyText/anytext/example_images/edit7.png",
|
||||
"ori_image": "/Users/cwq/code/github/AnyText/anytext/example_images/ref7.jpg",
|
||||
}
|
||||
results, rtn_code, rtn_warning, debug_info = pipe(input_data, mode=mode, **params)
|
||||
masked_image = cv2.imread(
|
||||
"/Users/cwq/code/github/AnyText/anytext/example_images/edit7.png"
|
||||
)[..., ::-1]
|
||||
|
||||
rgb_image = cv2.resize(rgb_image, (512, 512))
|
||||
masked_image = cv2.resize(masked_image, (512, 512))
|
||||
|
||||
# results: list of rgb ndarray
|
||||
results, rtn_code, rtn_warning = pipe(
|
||||
prompt='A cake with colorful characters that reads "EVERYDAY", best quality, extremely detailed,4k, HD, supper legible text, clear text edges, clear strokes, neat writing, no watermarks',
|
||||
negative_prompt="low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture",
|
||||
image=rgb_image,
|
||||
masked_image=masked_image,
|
||||
num_inference_steps=20,
|
||||
strength=1.0,
|
||||
guidance_scale=9.0,
|
||||
height=rgb_image.shape[0],
|
||||
width=rgb_image.shape[1],
|
||||
seed=seed,
|
||||
sort_priority="y",
|
||||
)
|
||||
if rtn_code >= 0:
|
||||
save_images(results, img_save_folder)
|
||||
print(f"Done, result images are saved in: {img_save_folder}")
|
||||
if rtn_warning:
|
||||
print(rtn_warning)
|
||||
|
||||
@@ -9,6 +9,7 @@ from iopaint.const import (
|
||||
INSTRUCT_PIX2PIX_NAME,
|
||||
KANDINSKY22_NAME,
|
||||
POWERPAINT_NAME,
|
||||
ANYTEXT_NAME,
|
||||
)
|
||||
from iopaint.schema import ModelType
|
||||
|
||||
@@ -31,6 +32,7 @@ class ModelInfo(BaseModel):
|
||||
INSTRUCT_PIX2PIX_NAME,
|
||||
KANDINSKY22_NAME,
|
||||
POWERPAINT_NAME,
|
||||
ANYTEXT_NAME,
|
||||
]
|
||||
|
||||
@computed_field
|
||||
@@ -58,7 +60,7 @@ class ModelInfo(BaseModel):
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
] or self.name in [POWERPAINT_NAME]
|
||||
] or self.name in [POWERPAINT_NAME, ANYTEXT_NAME]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
|
||||
BIN
iopaint/tests/anytext_mask.jpg
Normal file
BIN
iopaint/tests/anytext_mask.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 6.7 KiB |
BIN
iopaint/tests/anytext_ref.jpg
Normal file
BIN
iopaint/tests/anytext_ref.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 104 KiB |
45
iopaint/tests/test_anytext.py
Normal file
45
iopaint/tests/test_anytext.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import os
|
||||
|
||||
from iopaint.tests.utils import check_device, get_config, assert_equal
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.schema import HDStrategy
|
||||
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
save_dir = current_dir / "result"
|
||||
save_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||
def test_anytext(device):
|
||||
sd_steps = check_device(device)
|
||||
model = ModelManager(
|
||||
name="Sanster/AnyText",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
|
||||
cfg = get_config(
|
||||
strategy=HDStrategy.ORIGINAL,
|
||||
prompt='Characters written in chalk on the blackboard that says "DADDY", best quality, extremely detailed,4k, HD, supper legible text, clear text edges, clear strokes, neat writing, no watermarks',
|
||||
negative_prompt="low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture",
|
||||
sd_steps=sd_steps,
|
||||
sd_guidance_scale=9.0,
|
||||
sd_seed=66273235,
|
||||
sd_match_histograms=True
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"anytext.png",
|
||||
img_p=current_dir / "anytext_ref.jpg",
|
||||
mask_p=current_dir / "anytext_mask.jpg",
|
||||
)
|
||||
Reference in New Issue
Block a user