add AnyText
@@ -5,20 +5,21 @@ Code: https://github.com/tyxsspa/AnyText
 Copyright (c) Alibaba, Inc. and its affiliates.
 """
 import os
+from pathlib import Path
+
+from iopaint.model.utils import set_seed
+from safetensors.torch import load_file
 
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 import torch
-import random
 import re
 import numpy as np
 import cv2
 import einops
-import time
 from PIL import ImageFont
 from iopaint.model.anytext.cldm.model import create_model, load_state_dict
 from iopaint.model.anytext.cldm.ddim_hacked import DDIMSampler
 from iopaint.model.anytext.utils import (
     resize_image,
     check_channels,
     draw_glyph,
     draw_glyph2,
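The `TF_CPP_MIN_LOG_LEVEL` assignment is kept between the two import groups on purpose: it only silences TensorFlow's C++ logging if it is set before TensorFlow is first loaded, including transitively by later imports. A minimal sketch of the pattern, where the ordering itself is the point:

```python
import os

# Must be set before anything imports TensorFlow, even indirectly;
# "3" suppresses INFO, WARNING, and ERROR output from the C++ backend.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import torch  # heavy frameworks only after the environment is prepared
```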
@@ -29,55 +30,93 @@ BBOX_MAX_NUM = 8
 PLACE_HOLDER = "*"
 max_chars = 20
 
+ANYTEXT_CFG = os.path.join(
+    os.path.dirname(os.path.abspath(__file__)), "anytext_sd15.yaml"
+)
+
 
 def check_limits(tensor):
     float16_min = torch.finfo(torch.float16).min
     float16_max = torch.finfo(torch.float16).max
 
     # Check whether the tensor has values below float16's minimum or above its maximum
     is_below_min = (tensor < float16_min).any()
     is_above_max = (tensor > float16_max).any()
 
     return is_below_min or is_above_max
 
 
 class AnyTextPipeline:
-    def __init__(self, cfg_path, model_dir, font_path, device, use_fp16=True):
-        self.cfg_path = cfg_path
-        self.model_dir = model_dir
+    def __init__(self, ckpt_path, font_path, device, use_fp16=True):
+        self.cfg_path = ANYTEXT_CFG
         self.font_path = font_path
         self.use_fp16 = use_fp16
         self.device = device
-        self.init_model()
 
-    """
-    return:
-        result: list of images in numpy.ndarray format
-        rst_code: 0: normal -1: error 1:warning
-        rst_info: string of error or warning
-        debug_info: string for debug, only valid if show_debug=True
-    """
+        self.font = ImageFont.truetype(font_path, size=60)
+        self.model = create_model(
+            self.cfg_path,
+            device=self.device,
+            use_fp16=self.use_fp16,
+        )
+        if self.use_fp16:
+            self.model = self.model.half()
+        if Path(ckpt_path).suffix == ".safetensors":
+            state_dict = load_file(ckpt_path, device="cpu")
+        else:
+            state_dict = load_state_dict(ckpt_path, location="cpu")
+        self.model.load_state_dict(state_dict, strict=False)
+        self.model = self.model.eval().to(self.device)
+        self.ddim_sampler = DDIMSampler(self.model, device=self.device)
 
-    def __call__(self, input_tensor, **forward_params):
-        tic = time.time()
+    def __call__(
+        self,
+        prompt: str,
+        negative_prompt: str,
+        image: np.ndarray,
+        masked_image: np.ndarray,
+        num_inference_steps: int,
+        strength: float,
+        guidance_scale: float,
+        height: int,
+        width: int,
+        seed: int,
+        sort_priority: str = "y",
+        callback=None,
+    ):
+        """
+
+        Args:
+            prompt:
+            negative_prompt:
+            image:
+            masked_image:
+            num_inference_steps:
+            strength:
+            guidance_scale:
+            height:
+            width:
+            seed:
+            sort_priority: x: left-right, y: top-down
+
+        Returns:
+            result: list of images in numpy.ndarray format
+            rst_code: 0: normal -1: error 1:warning
+            rst_info: string of error or warning
+
+        """
+        set_seed(seed)
+        str_warning = ""
-        # get inputs
-        seed = input_tensor.get("seed", -1)
-        if seed == -1:
-            seed = random.randint(0, 99999999)
-        # seed_everything(seed)
-        prompt = input_tensor.get("prompt")
-        draw_pos = input_tensor.get("draw_pos")
-        ori_image = input_tensor.get("ori_image")
-
-        mode = forward_params.get("mode")
-        sort_priority = forward_params.get("sort_priority", "↕")
-        show_debug = forward_params.get("show_debug", False)
-        revise_pos = forward_params.get("revise_pos", False)
-        img_count = forward_params.get("image_count", 4)
-        ddim_steps = forward_params.get("ddim_steps", 20)
-        w = forward_params.get("image_width", 512)
-        h = forward_params.get("image_height", 512)
-        strength = forward_params.get("strength", 1.0)
-        cfg_scale = forward_params.get("cfg_scale", 9.0)
-        eta = forward_params.get("eta", 0.0)
-        a_prompt = forward_params.get(
-            "a_prompt",
-            "best quality, extremely detailed,4k, HD, supper legible text, clear text edges, clear strokes, neat writing, no watermarks",
-        )
-        n_prompt = forward_params.get(
-            "n_prompt",
-            "low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture",
-        )
+        mode = "text-editing"
+        revise_pos = False
+        img_count = 1
+        ddim_steps = num_inference_steps
+        w = width
+        h = height
+        strength = strength
+        cfg_scale = guidance_scale
+        eta = 0.0
 
         prompt, texts = self.modify_prompt(prompt)
         if prompt is None and texts is None:
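The rewritten constructor drops the `model_dir`/`init_model` indirection: the config path is pinned to the bundled `anytext_sd15.yaml`, and the checkpoint loader is picked by file extension. A standalone sketch of that dispatch, using plain `torch.load` as a stand-in for the repo's `load_state_dict` helper:

```python
from pathlib import Path

import torch
from safetensors.torch import load_file


def load_weights(ckpt_path: str) -> dict:
    # .safetensors files never execute pickle code and load as a flat
    # tensor dict; anything else is treated as a pickle-based checkpoint.
    if Path(ckpt_path).suffix == ".safetensors":
        return load_file(ckpt_path, device="cpu")
    return torch.load(ckpt_path, map_location="cpu")
```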
@@ -91,43 +130,44 @@ class AnyTextPipeline:
         if mode in ["text-generation", "gen"]:
             edit_image = np.ones((h, w, 3)) * 127.5  # empty mask image
         elif mode in ["text-editing", "edit"]:
-            if draw_pos is None or ori_image is None:
+            if masked_image is None or image is None:
                 return (
                     None,
                     -1,
                     "Reference image and position image are needed for text editing!",
                     "",
                 )
-            if isinstance(ori_image, str):
-                ori_image = cv2.imread(ori_image)[..., ::-1]
-                assert (
-                    ori_image is not None
-                ), f"Can't read ori_image image from{ori_image}!"
-            elif isinstance(ori_image, torch.Tensor):
-                ori_image = ori_image.cpu().numpy()
+            if isinstance(image, str):
+                image = cv2.imread(image)[..., ::-1]
+                assert image is not None, f"Can't read ori_image image from{image}!"
+            elif isinstance(image, torch.Tensor):
+                image = image.cpu().numpy()
             else:
                 assert isinstance(
-                    ori_image, np.ndarray
-                ), f"Unknown format of ori_image: {type(ori_image)}"
-            edit_image = ori_image.clip(1, 255)  # for mask reason
+                    image, np.ndarray
+                ), f"Unknown format of ori_image: {type(image)}"
+            edit_image = image.clip(1, 255)  # for mask reason
             edit_image = check_channels(edit_image)
-            edit_image = resize_image(
-                edit_image, max_length=768
-            )  # make w h multiple of 64, resize if w or h > max_length
+            # edit_image = resize_image(
+            #     edit_image, max_length=768
+            # )  # make w h multiple of 64, resize if w or h > max_length
             h, w = edit_image.shape[:2]  # change h, w by input ref_img
         # preprocess pos_imgs(if numpy, make sure it's white pos in black bg)
-        if draw_pos is None:
+        if masked_image is None:
             pos_imgs = np.zeros((w, h, 1))
-        if isinstance(draw_pos, str):
-            draw_pos = cv2.imread(draw_pos)[..., ::-1]
-            assert draw_pos is not None, f"Can't read draw_pos image from{draw_pos}!"
-            pos_imgs = 255 - draw_pos
-        elif isinstance(draw_pos, torch.Tensor):
-            pos_imgs = draw_pos.cpu().numpy()
+        if isinstance(masked_image, str):
+            masked_image = cv2.imread(masked_image)[..., ::-1]
+            assert (
+                masked_image is not None
+            ), f"Can't read draw_pos image from{masked_image}!"
+            pos_imgs = 255 - masked_image
+        elif isinstance(masked_image, torch.Tensor):
+            pos_imgs = masked_image.cpu().numpy()
         else:
             assert isinstance(
-                draw_pos, np.ndarray
-            ), f"Unknown format of draw_pos: {type(draw_pos)}"
+                masked_image, np.ndarray
+            ), f"Unknown format of draw_pos: {type(masked_image)}"
+            pos_imgs = 255 - masked_image
         pos_imgs = pos_imgs[..., 0:1]
         pos_imgs = cv2.convertScaleAbs(pos_imgs)
         _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
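Throughout this hunk `draw_pos`/`ori_image` become `masked_image`/`image`, and the mask is inverted up front: `255 - masked_image` turns a black-text-on-white mask into the white-position-on-black form the comment above asks for, before being reduced to one channel and binarized. A small sketch of the same preprocessing on a synthetic mask:

```python
import cv2
import numpy as np

# Hypothetical input: white background (255) with a black text region (0).
masked_image = np.full((64, 64, 3), 255, dtype=np.uint8)
masked_image[20:40, 10:50] = 0

pos_imgs = 255 - masked_image      # invert: the text region becomes white
pos_imgs = pos_imgs[..., 0:1]      # keep a single channel
pos_imgs = cv2.convertScaleAbs(pos_imgs)
# Values above 254 become 255, everything else 0 -> strictly binary mask.
_, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
```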
@@ -139,11 +179,8 @@ class AnyTextPipeline:
             if n_lines == 1 and texts[0] == " ":
                 pass  # text-to-image without text
             else:
-                return (
-                    None,
-                    -1,
-                    f"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!",
-                    "",
+                raise RuntimeError(
+                    f"{n_lines} text line to draw from prompt, not enough mask area({len(pos_imgs)}) on images"
                 )
         elif len(pos_imgs) > n_lines:
             str_warning = f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt."
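Failure here switches from an in-band `(None, -1, msg, "")` return to a raised `RuntimeError`, so callers can no longer silently ignore a mask/prompt mismatch. A toy sketch of the contract change (names hypothetical, not from this diff):

```python
def render_old(n_lines: int, n_positions: int):
    # Old contract: errors travel in-band as a sentinel tuple.
    if n_positions < n_lines:
        return None, -1, f"Found {n_positions} positions, needed {n_lines}", ""
    return ["<image>"], 0, "", ""


def render_new(n_lines: int, n_positions: int):
    # New contract: raise, forcing the caller to handle the failure.
    if n_positions < n_lines:
        raise RuntimeError(
            f"{n_lines} text line to draw from prompt, "
            f"not enough mask area({n_positions}) on images"
        )
    return ["<image>"], 0, ""


try:
    render_new(2, 1)
except RuntimeError as e:
    print(e)
```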
@@ -250,12 +287,16 @@ class AnyTextPipeline:
         cond = self.model.get_learned_conditioning(
             dict(
                 c_concat=[hint],
-                c_crossattn=[[prompt + " , " + a_prompt] * img_count],
+                c_crossattn=[[prompt] * img_count],
                 text_info=info,
             )
         )
         un_cond = self.model.get_learned_conditioning(
-            dict(c_concat=[hint], c_crossattn=[[n_prompt] * img_count], text_info=info)
+            dict(
+                c_concat=[hint],
+                c_crossattn=[[negative_prompt] * img_count],
+                text_info=info,
+            )
         )
         shape = (4, h // 8, w // 8)
         self.model.control_scales = [strength] * 13
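With `a_prompt`/`n_prompt` gone, the conditional branch carries just the caller's `prompt` and the unconditional branch the caller's `negative_prompt`; both share the same ControlNet `hint` and glyph `text_info`. The two branches are later merged by standard classifier-free guidance; a sketch of that combination, with random tensors standing in for the UNet's two noise predictions:

```python
import torch


def cfg_combine(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, scale: float) -> torch.Tensor:
    # Classifier-free guidance: move `scale` times further along the
    # direction from the unconditional to the conditional prediction.
    return eps_uncond + scale * (eps_cond - eps_uncond)


eps_cond = torch.randn(1, 4, 64, 64)    # toy stand-ins, shape (b, 4, h//8, w//8)
eps_uncond = torch.randn(1, 4, 64, 64)
guided = cfg_combine(eps_cond, eps_uncond, scale=9.0)  # cfg_scale above
```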
@@ -268,6 +309,7 @@ class AnyTextPipeline:
             eta=eta,
             unconditional_guidance_scale=cfg_scale,
             unconditional_conditioning=un_cond,
+            callback=callback
         )
         if self.use_fp16:
             samples = samples.half()
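The one-line addition threads the caller's `callback` into the DDIM sampling loop so progress can be reported per denoising step. Assuming the sampler invokes it with the current step index (the exact signature comes from `DDIMSampler`, not this diff), a minimal hook might look like:

```python
def on_step(step: int) -> None:
    # Hypothetical progress hook, called once per DDIM step.
    print(f"denoising step {step + 1}")
```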
@@ -280,52 +322,18 @@ class AnyTextPipeline:
             .astype(np.uint8)
         )
         results = [x_samples[i] for i in range(img_count)]
-        if (
-            mode == "edit" and False
-        ):  # replace backgound in text editing but not ideal yet
-            results = [r * np_hint + edit_image * (1 - np_hint) for r in results]
-            results = [r.clip(0, 255).astype(np.uint8) for r in results]
-        if len(gly_pos_imgs) > 0 and show_debug:
-            glyph_bs = np.stack(gly_pos_imgs, axis=2)
-            glyph_img = np.sum(glyph_bs, axis=2) * 255
-            glyph_img = glyph_img.clip(0, 255).astype(np.uint8)
-            results += [np.repeat(glyph_img, 3, axis=2)]
-        # debug_info
-        if not show_debug:
-            debug_info = ""
-        else:
-            input_prompt = prompt
-            for t in texts:
-                input_prompt = input_prompt.replace("*", f'"{t}"', 1)
-            debug_info = f'<span style="color:black;font-size:18px">Prompt: </span>{input_prompt}<br> \
-                <span style="color:black;font-size:18px">Size: </span>{w}x{h}<br> \
-                <span style="color:black;font-size:18px">Image Count: </span>{img_count}<br> \
-                <span style="color:black;font-size:18px">Seed: </span>{seed}<br> \
-                <span style="color:black;font-size:18px">Use FP16: </span>{self.use_fp16}<br> \
-                <span style="color:black;font-size:18px">Cost Time: </span>{(time.time()-tic):.2f}s'
+        # if (
+        #     mode == "edit" and False
+        # ):  # replace backgound in text editing but not ideal yet
+        #     results = [r * np_hint + edit_image * (1 - np_hint) for r in results]
+        #     results = [r.clip(0, 255).astype(np.uint8) for r in results]
+        # if len(gly_pos_imgs) > 0 and show_debug:
+        #     glyph_bs = np.stack(gly_pos_imgs, axis=2)
+        #     glyph_img = np.sum(glyph_bs, axis=2) * 255
+        #     glyph_img = glyph_img.clip(0, 255).astype(np.uint8)
+        #     results += [np.repeat(glyph_img, 3, axis=2)]
         rst_code = 1 if str_warning else 0
-        return results, rst_code, str_warning, debug_info
-
-    def init_model(self):
-        font_path = self.font_path
-        self.font = ImageFont.truetype(font_path, size=60)
-        cfg_path = self.cfg_path
-        ckpt_path = os.path.join(self.model_dir, "anytext_v1.1.ckpt")
-        clip_path = os.path.join(self.model_dir, "clip-vit-large-patch14")
-        self.model = create_model(
-            cfg_path,
-            device=self.device,
-            cond_stage_path=clip_path,
-            use_fp16=self.use_fp16,
-        )
-        if self.use_fp16:
-            self.model = self.model.half()
-        self.model.load_state_dict(
-            load_state_dict(ckpt_path, location=self.device), strict=False
-        )
-        self.model.eval()
-        self.model = self.model.to(self.device)
-        self.ddim_sampler = DDIMSampler(self.model, device=self.device)
+        return results, rst_code, str_warning
 
     def modify_prompt(self, prompt):
         prompt = prompt.replace("“", '"')
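With the debug path deleted, `__call__` now returns a three-tuple and the old compositing/glyph-debug code is kept only as comments. For context, the `.astype(np.uint8)` at the top of this hunk is the tail of the usual conversion from VAE-decoded samples in `[-1, 1]` to displayable `uint8` images; a sketch of that standard conversion (reconstructed, not copied from the diff):

```python
import einops
import numpy as np
import torch

x = torch.rand(1, 3, 512, 512) * 2 - 1  # toy decoded sample in [-1, 1]
x_samples = (
    (einops.rearrange(x, "b c h w -> b h w c") * 127.5 + 127.5)
    .cpu()
    .numpy()
    .clip(0, 255)
    .astype(np.uint8)
)
print(x_samples.shape, x_samples.dtype)  # (1, 512, 512, 3) uint8
```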
@@ -360,9 +368,9 @@ class AnyTextPipeline:
             component = np.zeros_like(img)
             component[labels == label] = 255
             components.append((component, centroids[label]))
-        if sort_priority == "↕":
+        if sort_priority == "y":
             fir, sec = 1, 0  # top-down first
-        elif sort_priority == "↔":
+        elif sort_priority == "x":
             fir, sec = 0, 1  # left-right first
         components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap))
         sorted_components = [c[0] for c in components]
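Replacing the arrow glyphs `↕`/`↔` with plain `y`/`x` makes `sort_priority` easy to pass from ordinary code; the ordering logic is untouched. A toy run of the same centroid sort, where `gap` quantizes the primary axis so nearly-aligned regions compare equal and fall through to the secondary axis:

```python
# Toy centroids as (x, y); sort_priority "y" means top-down, then left-right.
centroids = [(300, 110), (40, 120), (60, 240)]
gap = 50
fir, sec = 1, 0  # sort_priority == "y"

ordered = sorted(centroids, key=lambda c: (c[fir] // gap, c[sec] // gap))
print(ordered)  # [(40, 120), (300, 110), (60, 240)]: top row left-to-right, then the lower region
```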