anytext init
This commit is contained in:
630
iopaint/model/anytext/cldm/cldm.py
Normal file
630
iopaint/model/anytext/cldm/cldm.py
Normal file
@@ -0,0 +1,630 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import copy
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
|
||||
conv_nd,
|
||||
linear,
|
||||
zero_module,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from iopaint.model.anytext.ldm.modules.attention import SpatialTransformer
|
||||
from iopaint.model.anytext.ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
|
||||
from iopaint.model.anytext.ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from iopaint.model.anytext.ldm.util import log_txt_as_img, exists, instantiate_from_config
|
||||
from iopaint.model.anytext.ldm.models.diffusion.ddim import DDIMSampler
|
||||
from iopaint.model.anytext.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
from .recognizer import TextRecognizer, create_predictor
|
||||
|
||||
CURRENT_DIR = Path(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
class ControlledUnetModel(UNetModel):
|
||||
def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
|
||||
hs = []
|
||||
with torch.no_grad():
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||
if self.use_fp16:
|
||||
t_emb = t_emb.half()
|
||||
emb = self.time_embed(t_emb)
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb, context)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context)
|
||||
|
||||
if control is not None:
|
||||
h += control.pop()
|
||||
|
||||
for i, module in enumerate(self.output_blocks):
|
||||
if only_mid_control or control is None:
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
else:
|
||||
h = torch.cat([h, hs.pop() + control.pop()], dim=1)
|
||||
h = module(h, emb, context)
|
||||
|
||||
h = h.type(x.dtype)
|
||||
return self.out(h)
|
||||
|
||||
|
||||
class ControlNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
in_channels,
|
||||
model_channels,
|
||||
glyph_channels,
|
||||
position_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=-1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
use_spatial_transformer=False, # custom transformer support
|
||||
transformer_depth=1, # custom transformer support
|
||||
context_dim=None, # custom transformer support
|
||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||
legacy=True,
|
||||
disable_self_attentions=None,
|
||||
num_attention_blocks=None,
|
||||
disable_middle_self_attn=False,
|
||||
use_linear_in_transformer=False,
|
||||
):
|
||||
super().__init__()
|
||||
if use_spatial_transformer:
|
||||
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
||||
|
||||
if context_dim is not None:
|
||||
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
||||
from omegaconf.listconfig import ListConfig
|
||||
if type(context_dim) == ListConfig:
|
||||
context_dim = list(context_dim)
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
num_heads_upsample = num_heads
|
||||
|
||||
if num_heads == -1:
|
||||
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
if num_head_channels == -1:
|
||||
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
self.dims = dims
|
||||
self.image_size = image_size
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
if isinstance(num_res_blocks, int):
|
||||
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||
else:
|
||||
if len(num_res_blocks) != len(channel_mult):
|
||||
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
||||
"as a list/tuple (per-level) with the same length as channel_mult")
|
||||
self.num_res_blocks = num_res_blocks
|
||||
if disable_self_attentions is not None:
|
||||
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||
assert len(disable_self_attentions) == len(channel_mult)
|
||||
if num_attention_blocks is not None:
|
||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
||||
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
||||
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
||||
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
||||
f"attention will still not be set.")
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.use_fp16 = use_fp16
|
||||
self.dtype = th.float16 if use_fp16 else th.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
self.predict_codebook_ids = n_embed is not None
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||
)
|
||||
]
|
||||
)
|
||||
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
|
||||
|
||||
self.glyph_block = TimestepEmbedSequential(
|
||||
conv_nd(dims, glyph_channels, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 8, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 8, 16, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 16, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 16, 32, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 32, 32, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 32, 96, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 96, 96, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 96, 256, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
)
|
||||
|
||||
self.position_block = TimestepEmbedSequential(
|
||||
conv_nd(dims, position_channels, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 8, 8, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 8, 16, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 16, 16, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 16, 32, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 32, 32, 3, padding=1),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, 32, 64, 3, padding=1, stride=2),
|
||||
nn.SiLU(),
|
||||
)
|
||||
|
||||
self.fuse_block = zero_module(conv_nd(dims, 256+64+4, model_channels, 3, padding=1))
|
||||
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for nr in range(self.num_res_blocks[level]):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
if exists(disable_self_attentions):
|
||||
disabled_sa = disable_self_attentions[level]
|
||||
else:
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self.zero_convs.append(self.make_zero_conv(ch))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch
|
||||
)
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
self.zero_convs.append(self.make_zero_conv(ch))
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
)
|
||||
self.middle_block_out = self.make_zero_conv(ch)
|
||||
self._feature_size += ch
|
||||
|
||||
def make_zero_conv(self, channels):
|
||||
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
|
||||
|
||||
def forward(self, x, hint, text_info, timesteps, context, **kwargs):
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||
if self.use_fp16:
|
||||
t_emb = t_emb.half()
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
# guided_hint from text_info
|
||||
B, C, H, W = x.shape
|
||||
glyphs = torch.cat(text_info['glyphs'], dim=1).sum(dim=1, keepdim=True)
|
||||
positions = torch.cat(text_info['positions'], dim=1).sum(dim=1, keepdim=True)
|
||||
enc_glyph = self.glyph_block(glyphs, emb, context)
|
||||
enc_pos = self.position_block(positions, emb, context)
|
||||
guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info['masked_x']], dim=1))
|
||||
|
||||
outs = []
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
||||
if guided_hint is not None:
|
||||
h = module(h, emb, context)
|
||||
h += guided_hint
|
||||
guided_hint = None
|
||||
else:
|
||||
h = module(h, emb, context)
|
||||
outs.append(zero_conv(h, emb, context))
|
||||
|
||||
h = self.middle_block(h, emb, context)
|
||||
outs.append(self.middle_block_out(h, emb, context))
|
||||
|
||||
return outs
|
||||
|
||||
|
||||
class ControlLDM(LatentDiffusion):
|
||||
|
||||
def __init__(self, control_stage_config, control_key, glyph_key, position_key, only_mid_control, loss_alpha=0, loss_beta=0, with_step_weight=False, use_vae_upsample=False, latin_weight=1.0, embedding_manager_config=None, *args, **kwargs):
|
||||
self.use_fp16 = kwargs.pop('use_fp16', False)
|
||||
super().__init__(*args, **kwargs)
|
||||
self.control_model = instantiate_from_config(control_stage_config)
|
||||
self.control_key = control_key
|
||||
self.glyph_key = glyph_key
|
||||
self.position_key = position_key
|
||||
self.only_mid_control = only_mid_control
|
||||
self.control_scales = [1.0] * 13
|
||||
self.loss_alpha = loss_alpha
|
||||
self.loss_beta = loss_beta
|
||||
self.with_step_weight = with_step_weight
|
||||
self.use_vae_upsample = use_vae_upsample
|
||||
self.latin_weight = latin_weight
|
||||
|
||||
if embedding_manager_config is not None and embedding_manager_config.params.valid:
|
||||
self.embedding_manager = self.instantiate_embedding_manager(embedding_manager_config, self.cond_stage_model)
|
||||
for param in self.embedding_manager.embedding_parameters():
|
||||
param.requires_grad = True
|
||||
else:
|
||||
self.embedding_manager = None
|
||||
if self.loss_alpha > 0 or self.loss_beta > 0 or self.embedding_manager:
|
||||
if embedding_manager_config.params.emb_type == 'ocr':
|
||||
self.text_predictor = create_predictor().eval()
|
||||
args = edict()
|
||||
args.rec_image_shape = "3, 48, 320"
|
||||
args.rec_batch_num = 6
|
||||
args.rec_char_dict_path = str(CURRENT_DIR.parent / "ocr_recog" / "ppocr_keys_v1.txt")
|
||||
args.use_fp16 = self.use_fp16
|
||||
self.cn_recognizer = TextRecognizer(args, self.text_predictor)
|
||||
for param in self.text_predictor.parameters():
|
||||
param.requires_grad = False
|
||||
if self.embedding_manager:
|
||||
self.embedding_manager.recog = self.cn_recognizer
|
||||
|
||||
@torch.no_grad()
|
||||
def get_input(self, batch, k, bs=None, *args, **kwargs):
|
||||
if self.embedding_manager is None: # fill in full caption
|
||||
self.fill_caption(batch)
|
||||
x, c, mx = super().get_input(batch, self.first_stage_key, mask_k='masked_img', *args, **kwargs)
|
||||
control = batch[self.control_key] # for log_images and loss_alpha, not real control
|
||||
if bs is not None:
|
||||
control = control[:bs]
|
||||
control = control.to(self.device)
|
||||
control = einops.rearrange(control, 'b h w c -> b c h w')
|
||||
control = control.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
inv_mask = batch['inv_mask']
|
||||
if bs is not None:
|
||||
inv_mask = inv_mask[:bs]
|
||||
inv_mask = inv_mask.to(self.device)
|
||||
inv_mask = einops.rearrange(inv_mask, 'b h w c -> b c h w')
|
||||
inv_mask = inv_mask.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
glyphs = batch[self.glyph_key]
|
||||
gly_line = batch['gly_line']
|
||||
positions = batch[self.position_key]
|
||||
n_lines = batch['n_lines']
|
||||
language = batch['language']
|
||||
texts = batch['texts']
|
||||
assert len(glyphs) == len(positions)
|
||||
for i in range(len(glyphs)):
|
||||
if bs is not None:
|
||||
glyphs[i] = glyphs[i][:bs]
|
||||
gly_line[i] = gly_line[i][:bs]
|
||||
positions[i] = positions[i][:bs]
|
||||
n_lines = n_lines[:bs]
|
||||
glyphs[i] = glyphs[i].to(self.device)
|
||||
gly_line[i] = gly_line[i].to(self.device)
|
||||
positions[i] = positions[i].to(self.device)
|
||||
glyphs[i] = einops.rearrange(glyphs[i], 'b h w c -> b c h w')
|
||||
gly_line[i] = einops.rearrange(gly_line[i], 'b h w c -> b c h w')
|
||||
positions[i] = einops.rearrange(positions[i], 'b h w c -> b c h w')
|
||||
glyphs[i] = glyphs[i].to(memory_format=torch.contiguous_format).float()
|
||||
gly_line[i] = gly_line[i].to(memory_format=torch.contiguous_format).float()
|
||||
positions[i] = positions[i].to(memory_format=torch.contiguous_format).float()
|
||||
info = {}
|
||||
info['glyphs'] = glyphs
|
||||
info['positions'] = positions
|
||||
info['n_lines'] = n_lines
|
||||
info['language'] = language
|
||||
info['texts'] = texts
|
||||
info['img'] = batch['img'] # nhwc, (-1,1)
|
||||
info['masked_x'] = mx
|
||||
info['gly_line'] = gly_line
|
||||
info['inv_mask'] = inv_mask
|
||||
return x, dict(c_crossattn=[c], c_concat=[control], text_info=info)
|
||||
|
||||
def apply_model(self, x_noisy, t, cond, *args, **kwargs):
|
||||
assert isinstance(cond, dict)
|
||||
diffusion_model = self.model.diffusion_model
|
||||
_cond = torch.cat(cond['c_crossattn'], 1)
|
||||
_hint = torch.cat(cond['c_concat'], 1)
|
||||
if self.use_fp16:
|
||||
x_noisy = x_noisy.half()
|
||||
control = self.control_model(x=x_noisy, timesteps=t, context=_cond, hint=_hint, text_info=cond['text_info'])
|
||||
control = [c * scale for c, scale in zip(control, self.control_scales)]
|
||||
eps = diffusion_model(x=x_noisy, timesteps=t, context=_cond, control=control, only_mid_control=self.only_mid_control)
|
||||
|
||||
return eps
|
||||
|
||||
def instantiate_embedding_manager(self, config, embedder):
|
||||
model = instantiate_from_config(config, embedder=embedder)
|
||||
return model
|
||||
|
||||
@torch.no_grad()
|
||||
def get_unconditional_conditioning(self, N):
|
||||
return self.get_learned_conditioning(dict(c_crossattn=[[""] * N], text_info=None))
|
||||
|
||||
def get_learned_conditioning(self, c):
|
||||
if self.cond_stage_forward is None:
|
||||
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
|
||||
if self.embedding_manager is not None and c['text_info'] is not None:
|
||||
self.embedding_manager.encode_text(c['text_info'])
|
||||
if isinstance(c, dict):
|
||||
cond_txt = c['c_crossattn'][0]
|
||||
else:
|
||||
cond_txt = c
|
||||
if self.embedding_manager is not None:
|
||||
cond_txt = self.cond_stage_model.encode(cond_txt, embedding_manager=self.embedding_manager)
|
||||
else:
|
||||
cond_txt = self.cond_stage_model.encode(cond_txt)
|
||||
if isinstance(c, dict):
|
||||
c['c_crossattn'][0] = cond_txt
|
||||
else:
|
||||
c = cond_txt
|
||||
if isinstance(c, DiagonalGaussianDistribution):
|
||||
c = c.mode()
|
||||
else:
|
||||
c = self.cond_stage_model(c)
|
||||
else:
|
||||
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
|
||||
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
|
||||
return c
|
||||
|
||||
def fill_caption(self, batch, place_holder='*'):
|
||||
bs = len(batch['n_lines'])
|
||||
cond_list = copy.deepcopy(batch[self.cond_stage_key])
|
||||
for i in range(bs):
|
||||
n_lines = batch['n_lines'][i]
|
||||
if n_lines == 0:
|
||||
continue
|
||||
cur_cap = cond_list[i]
|
||||
for j in range(n_lines):
|
||||
r_txt = batch['texts'][j][i]
|
||||
cur_cap = cur_cap.replace(place_holder, f'"{r_txt}"', 1)
|
||||
cond_list[i] = cur_cap
|
||||
batch[self.cond_stage_key] = cond_list
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
|
||||
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
|
||||
plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
|
||||
use_ema_scope=True,
|
||||
**kwargs):
|
||||
use_ddim = ddim_steps is not None
|
||||
|
||||
log = dict()
|
||||
z, c = self.get_input(batch, self.first_stage_key, bs=N)
|
||||
if self.cond_stage_trainable:
|
||||
with torch.no_grad():
|
||||
c = self.get_learned_conditioning(c)
|
||||
c_crossattn = c["c_crossattn"][0][:N]
|
||||
c_cat = c["c_concat"][0][:N]
|
||||
text_info = c["text_info"]
|
||||
text_info['glyphs'] = [i[:N] for i in text_info['glyphs']]
|
||||
text_info['gly_line'] = [i[:N] for i in text_info['gly_line']]
|
||||
text_info['positions'] = [i[:N] for i in text_info['positions']]
|
||||
text_info['n_lines'] = text_info['n_lines'][:N]
|
||||
text_info['masked_x'] = text_info['masked_x'][:N]
|
||||
text_info['img'] = text_info['img'][:N]
|
||||
|
||||
N = min(z.shape[0], N)
|
||||
n_row = min(z.shape[0], n_row)
|
||||
log["reconstruction"] = self.decode_first_stage(z)
|
||||
log["masked_image"] = self.decode_first_stage(text_info['masked_x'])
|
||||
log["control"] = c_cat * 2.0 - 1.0
|
||||
log["img"] = text_info['img'].permute(0, 3, 1, 2) # log source image if needed
|
||||
# get glyph
|
||||
glyph_bs = torch.stack(text_info['glyphs'])
|
||||
glyph_bs = torch.sum(glyph_bs, dim=0) * 2.0 - 1.0
|
||||
log["glyph"] = torch.nn.functional.interpolate(glyph_bs, size=(512, 512), mode='bilinear', align_corners=True,)
|
||||
# fill caption
|
||||
if not self.embedding_manager:
|
||||
self.fill_caption(batch)
|
||||
captions = batch[self.cond_stage_key]
|
||||
log["conditioning"] = log_txt_as_img((512, 512), captions, size=16)
|
||||
|
||||
if plot_diffusion_rows:
|
||||
# get diffusion row
|
||||
diffusion_row = list()
|
||||
z_start = z[:n_row]
|
||||
for t in range(self.num_timesteps):
|
||||
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
||||
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
|
||||
t = t.to(self.device).long()
|
||||
noise = torch.randn_like(z_start)
|
||||
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
|
||||
diffusion_row.append(self.decode_first_stage(z_noisy))
|
||||
|
||||
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
|
||||
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
|
||||
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
|
||||
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
|
||||
log["diffusion_row"] = diffusion_grid
|
||||
|
||||
if sample:
|
||||
# get denoise row
|
||||
samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c], "text_info": text_info},
|
||||
batch_size=N, ddim=use_ddim,
|
||||
ddim_steps=ddim_steps, eta=ddim_eta)
|
||||
x_samples = self.decode_first_stage(samples)
|
||||
log["samples"] = x_samples
|
||||
if plot_denoise_rows:
|
||||
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
|
||||
log["denoise_row"] = denoise_grid
|
||||
|
||||
if unconditional_guidance_scale > 1.0:
|
||||
uc_cross = self.get_unconditional_conditioning(N)
|
||||
uc_cat = c_cat # torch.zeros_like(c_cat)
|
||||
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross['c_crossattn'][0]], "text_info": text_info}
|
||||
samples_cfg, tmps = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c_crossattn], "text_info": text_info},
|
||||
batch_size=N, ddim=use_ddim,
|
||||
ddim_steps=ddim_steps, eta=ddim_eta,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=uc_full,
|
||||
)
|
||||
x_samples_cfg = self.decode_first_stage(samples_cfg)
|
||||
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
|
||||
pred_x0 = False # wether log pred_x0
|
||||
if pred_x0:
|
||||
for idx in range(len(tmps['pred_x0'])):
|
||||
pred_x0 = self.decode_first_stage(tmps['pred_x0'][idx])
|
||||
log[f"pred_x0_{tmps['index'][idx]}"] = pred_x0
|
||||
|
||||
return log
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
|
||||
ddim_sampler = DDIMSampler(self)
|
||||
b, c, h, w = cond["c_concat"][0].shape
|
||||
shape = (self.channels, h // 8, w // 8)
|
||||
samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, log_every_t=5, **kwargs)
|
||||
return samples, intermediates
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.learning_rate
|
||||
params = list(self.control_model.parameters())
|
||||
if self.embedding_manager:
|
||||
params += list(self.embedding_manager.embedding_parameters())
|
||||
if not self.sd_locked:
|
||||
# params += list(self.model.diffusion_model.input_blocks.parameters())
|
||||
# params += list(self.model.diffusion_model.middle_block.parameters())
|
||||
params += list(self.model.diffusion_model.output_blocks.parameters())
|
||||
params += list(self.model.diffusion_model.out.parameters())
|
||||
if self.unlockKV:
|
||||
nCount = 0
|
||||
for name, param in self.model.diffusion_model.named_parameters():
|
||||
if 'attn2.to_k' in name or 'attn2.to_v' in name:
|
||||
params += [param]
|
||||
nCount += 1
|
||||
print(f'Cross attention is unlocked, and {nCount} Wk or Wv are added to potimizers!!!')
|
||||
|
||||
opt = torch.optim.AdamW(params, lr=lr)
|
||||
return opt
|
||||
|
||||
def low_vram_shift(self, is_diffusing):
|
||||
if is_diffusing:
|
||||
self.model = self.model.cuda()
|
||||
self.control_model = self.control_model.cuda()
|
||||
self.first_stage_model = self.first_stage_model.cpu()
|
||||
self.cond_stage_model = self.cond_stage_model.cpu()
|
||||
else:
|
||||
self.model = self.model.cpu()
|
||||
self.control_model = self.control_model.cpu()
|
||||
self.first_stage_model = self.first_stage_model.cuda()
|
||||
self.cond_stage_model = self.cond_stage_model.cuda()
|
||||
486
iopaint/model/anytext/cldm/ddim_hacked.py
Normal file
486
iopaint/model/anytext/cldm/ddim_hacked.py
Normal file
@@ -0,0 +1,486 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
|
||||
make_ddim_sampling_parameters,
|
||||
make_ddim_timesteps,
|
||||
noise_like,
|
||||
extract_into_tensor,
|
||||
)
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
def __init__(self, model, device, schedule="linear", **kwargs):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device(self.device):
|
||||
attr = attr.to(torch.device(self.device))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(
|
||||
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
|
||||
):
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||
verbose=verbose,
|
||||
)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert (
|
||||
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
||||
), "alphas have to be defined for each timestep"
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
|
||||
|
||||
self.register_buffer("betas", to_torch(self.model.betas))
|
||||
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
||||
self.register_buffer(
|
||||
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
|
||||
)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer(
|
||||
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_one_minus_alphas_cumprod",
|
||||
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_recipm1_alphas_cumprod",
|
||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
||||
)
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
||||
alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose,
|
||||
)
|
||||
self.register_buffer("ddim_sigmas", ddim_sigmas)
|
||||
self.register_buffer("ddim_alphas", ddim_alphas)
|
||||
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
|
||||
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev)
|
||||
/ (1 - self.alphas_cumprod)
|
||||
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||
)
|
||||
self.register_buffer(
|
||||
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.0,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
ucg_schedule=None,
|
||||
**kwargs,
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(
|
||||
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
|
||||
)
|
||||
|
||||
elif isinstance(conditioning, list):
|
||||
for ctmp in conditioning:
|
||||
if ctmp.shape[0] != batch_size:
|
||||
print(
|
||||
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
|
||||
)
|
||||
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(
|
||||
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
|
||||
)
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f"Data shape for DDIM sampling is {size}, eta {eta}")
|
||||
|
||||
samples, intermediates = self.ddim_sampling(
|
||||
conditioning,
|
||||
size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
ucg_schedule=ucg_schedule,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(
|
||||
self,
|
||||
cond,
|
||||
shape,
|
||||
x_T=None,
|
||||
ddim_use_original_steps=False,
|
||||
callback=None,
|
||||
timesteps=None,
|
||||
quantize_denoised=False,
|
||||
mask=None,
|
||||
x0=None,
|
||||
img_callback=None,
|
||||
log_every_t=100,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
dynamic_threshold=None,
|
||||
ucg_schedule=None,
|
||||
):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = (
|
||||
self.ddpm_num_timesteps
|
||||
if ddim_use_original_steps
|
||||
else self.ddim_timesteps
|
||||
)
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = (
|
||||
int(
|
||||
min(timesteps / self.ddim_timesteps.shape[0], 1)
|
||||
* self.ddim_timesteps.shape[0]
|
||||
)
|
||||
- 1
|
||||
)
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {"x_inter": [img], "pred_x0": [img]}
|
||||
time_range = (
|
||||
reversed(range(0, timesteps))
|
||||
if ddim_use_original_steps
|
||||
else np.flip(timesteps)
|
||||
)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(
|
||||
x0, ts
|
||||
) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1.0 - mask) * img
|
||||
|
||||
if ucg_schedule is not None:
|
||||
assert len(ucg_schedule) == len(time_range)
|
||||
unconditional_guidance_scale = ucg_schedule[i]
|
||||
|
||||
outs = self.p_sample_ddim(
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
)
|
||||
img, pred_x0 = outs
|
||||
if callback:
|
||||
callback(i)
|
||||
if img_callback:
|
||||
img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates["x_inter"].append(img)
|
||||
intermediates["pred_x0"].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(
|
||||
self,
|
||||
x,
|
||||
c,
|
||||
t,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
dynamic_threshold=None,
|
||||
):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
|
||||
model_output = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
model_t = self.model.apply_model(x, t, c)
|
||||
model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
|
||||
model_output = model_uncond + unconditional_guidance_scale * (
|
||||
model_t - model_uncond
|
||||
)
|
||||
|
||||
if self.model.parameterization == "v":
|
||||
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
|
||||
else:
|
||||
e_t = model_output
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps", "not implemented"
|
||||
e_t = score_corrector.modify_score(
|
||||
self.model, e_t, x, t, c, **corrector_kwargs
|
||||
)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = (
|
||||
self.model.alphas_cumprod_prev
|
||||
if use_original_steps
|
||||
else self.ddim_alphas_prev
|
||||
)
|
||||
sqrt_one_minus_alphas = (
|
||||
self.model.sqrt_one_minus_alphas_cumprod
|
||||
if use_original_steps
|
||||
else self.ddim_sqrt_one_minus_alphas
|
||||
)
|
||||
sigmas = (
|
||||
self.model.ddim_sigmas_for_original_num_steps
|
||||
if use_original_steps
|
||||
else self.ddim_sigmas
|
||||
)
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full(
|
||||
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
||||
)
|
||||
|
||||
# current prediction for x_0
|
||||
if self.model.parameterization != "v":
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
else:
|
||||
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
||||
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
|
||||
if dynamic_threshold is not None:
|
||||
raise NotImplementedError()
|
||||
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.0:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(
|
||||
self,
|
||||
x0,
|
||||
c,
|
||||
t_enc,
|
||||
use_original_steps=False,
|
||||
return_intermediates=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
callback=None,
|
||||
):
|
||||
timesteps = (
|
||||
np.arange(self.ddpm_num_timesteps)
|
||||
if use_original_steps
|
||||
else self.ddim_timesteps
|
||||
)
|
||||
num_reference_steps = timesteps.shape[0]
|
||||
|
||||
assert t_enc <= num_reference_steps
|
||||
num_steps = t_enc
|
||||
|
||||
if use_original_steps:
|
||||
alphas_next = self.alphas_cumprod[:num_steps]
|
||||
alphas = self.alphas_cumprod_prev[:num_steps]
|
||||
else:
|
||||
alphas_next = self.ddim_alphas[:num_steps]
|
||||
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
|
||||
|
||||
x_next = x0
|
||||
intermediates = []
|
||||
inter_steps = []
|
||||
for i in tqdm(range(num_steps), desc="Encoding Image"):
|
||||
t = torch.full(
|
||||
(x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long
|
||||
)
|
||||
if unconditional_guidance_scale == 1.0:
|
||||
noise_pred = self.model.apply_model(x_next, t, c)
|
||||
else:
|
||||
assert unconditional_conditioning is not None
|
||||
e_t_uncond, noise_pred = torch.chunk(
|
||||
self.model.apply_model(
|
||||
torch.cat((x_next, x_next)),
|
||||
torch.cat((t, t)),
|
||||
torch.cat((unconditional_conditioning, c)),
|
||||
),
|
||||
2,
|
||||
)
|
||||
noise_pred = e_t_uncond + unconditional_guidance_scale * (
|
||||
noise_pred - e_t_uncond
|
||||
)
|
||||
|
||||
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
|
||||
weighted_noise_pred = (
|
||||
alphas_next[i].sqrt()
|
||||
* ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
|
||||
* noise_pred
|
||||
)
|
||||
x_next = xt_weighted + weighted_noise_pred
|
||||
if (
|
||||
return_intermediates
|
||||
and i % (num_steps // return_intermediates) == 0
|
||||
and i < num_steps - 1
|
||||
):
|
||||
intermediates.append(x_next)
|
||||
inter_steps.append(i)
|
||||
elif return_intermediates and i >= num_steps - 2:
|
||||
intermediates.append(x_next)
|
||||
inter_steps.append(i)
|
||||
if callback:
|
||||
callback(i)
|
||||
|
||||
out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
|
||||
if return_intermediates:
|
||||
out.update({"intermediates": intermediates})
|
||||
return x_next, out
|
||||
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
# fast, but does not allow for exact reconstruction
|
||||
# t serves as an index to gather the correct alphas
|
||||
if use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
return (
|
||||
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
|
||||
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
self,
|
||||
x_latent,
|
||||
cond,
|
||||
t_start,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
use_original_steps=False,
|
||||
callback=None,
|
||||
):
|
||||
timesteps = (
|
||||
np.arange(self.ddpm_num_timesteps)
|
||||
if use_original_steps
|
||||
else self.ddim_timesteps
|
||||
)
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
|
||||
x_dec = x_latent
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full(
|
||||
(x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
|
||||
)
|
||||
x_dec, _ = self.p_sample_ddim(
|
||||
x_dec,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
if callback:
|
||||
callback(i)
|
||||
return x_dec
|
||||
165
iopaint/model/anytext/cldm/embedding_manager.py
Normal file
165
iopaint/model/anytext/cldm/embedding_manager.py
Normal file
@@ -0,0 +1,165 @@
|
||||
'''
|
||||
Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
from iopaint.model.anytext.ldm.modules.diffusionmodules.util import conv_nd, linear
|
||||
|
||||
|
||||
def get_clip_token_for_string(tokenizer, string):
|
||||
batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"]
|
||||
assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string"
|
||||
return tokens[0, 1]
|
||||
|
||||
|
||||
def get_bert_token_for_string(tokenizer, string):
|
||||
token = tokenizer(string)
|
||||
assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
|
||||
token = token[0, 1]
|
||||
return token
|
||||
|
||||
|
||||
def get_clip_vision_emb(encoder, processor, img):
|
||||
_img = img.repeat(1, 3, 1, 1)*255
|
||||
inputs = processor(images=_img, return_tensors="pt")
|
||||
inputs['pixel_values'] = inputs['pixel_values'].to(img.device)
|
||||
outputs = encoder(**inputs)
|
||||
emb = outputs.image_embeds
|
||||
return emb
|
||||
|
||||
|
||||
def get_recog_emb(encoder, img_list):
|
||||
_img_list = [(img.repeat(1, 3, 1, 1)*255)[0] for img in img_list]
|
||||
encoder.predictor.eval()
|
||||
_, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)
|
||||
return preds_neck
|
||||
|
||||
|
||||
def pad_H(x):
|
||||
_, _, H, W = x.shape
|
||||
p_top = (W - H) // 2
|
||||
p_bot = W - H - p_top
|
||||
return F.pad(x, (0, 0, p_top, p_bot))
|
||||
|
||||
|
||||
class EncodeNet(nn.Module):
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super(EncodeNet, self).__init__()
|
||||
chan = 16
|
||||
n_layer = 4 # downsample
|
||||
|
||||
self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1)
|
||||
self.conv_list = nn.ModuleList([])
|
||||
_c = chan
|
||||
for i in range(n_layer):
|
||||
self.conv_list.append(conv_nd(2, _c, _c*2, 3, padding=1, stride=2))
|
||||
_c *= 2
|
||||
self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1)
|
||||
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
||||
self.act = nn.SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.act(self.conv1(x))
|
||||
for layer in self.conv_list:
|
||||
x = self.act(layer(x))
|
||||
x = self.act(self.conv2(x))
|
||||
x = self.avgpool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
return x
|
||||
|
||||
|
||||
class EmbeddingManager(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedder,
|
||||
valid=True,
|
||||
glyph_channels=20,
|
||||
position_channels=1,
|
||||
placeholder_string='*',
|
||||
add_pos=False,
|
||||
emb_type='ocr',
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder
|
||||
get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
|
||||
token_dim = 768
|
||||
if hasattr(embedder, 'vit'):
|
||||
assert emb_type == 'vit'
|
||||
self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor)
|
||||
self.get_recog_emb = None
|
||||
else: # using LDM's BERT encoder
|
||||
get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
|
||||
token_dim = 1280
|
||||
self.token_dim = token_dim
|
||||
self.emb_type = emb_type
|
||||
|
||||
self.add_pos = add_pos
|
||||
if add_pos:
|
||||
self.position_encoder = EncodeNet(position_channels, token_dim)
|
||||
if emb_type == 'ocr':
|
||||
self.proj = linear(40*64, token_dim)
|
||||
if emb_type == 'conv':
|
||||
self.glyph_encoder = EncodeNet(glyph_channels, token_dim)
|
||||
|
||||
self.placeholder_token = get_token_for_string(placeholder_string)
|
||||
|
||||
def encode_text(self, text_info):
|
||||
if self.get_recog_emb is None and self.emb_type == 'ocr':
|
||||
self.get_recog_emb = partial(get_recog_emb, self.recog)
|
||||
|
||||
gline_list = []
|
||||
pos_list = []
|
||||
for i in range(len(text_info['n_lines'])): # sample index in a batch
|
||||
n_lines = text_info['n_lines'][i]
|
||||
for j in range(n_lines): # line
|
||||
gline_list += [text_info['gly_line'][j][i:i+1]]
|
||||
if self.add_pos:
|
||||
pos_list += [text_info['positions'][j][i:i+1]]
|
||||
|
||||
if len(gline_list) > 0:
|
||||
if self.emb_type == 'ocr':
|
||||
recog_emb = self.get_recog_emb(gline_list)
|
||||
enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
|
||||
elif self.emb_type == 'vit':
|
||||
enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
|
||||
elif self.emb_type == 'conv':
|
||||
enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
|
||||
if self.add_pos:
|
||||
enc_pos = self.position_encoder(torch.cat(gline_list, dim=0))
|
||||
enc_glyph = enc_glyph+enc_pos
|
||||
|
||||
self.text_embs_all = []
|
||||
n_idx = 0
|
||||
for i in range(len(text_info['n_lines'])): # sample index in a batch
|
||||
n_lines = text_info['n_lines'][i]
|
||||
text_embs = []
|
||||
for j in range(n_lines): # line
|
||||
text_embs += [enc_glyph[n_idx:n_idx+1]]
|
||||
n_idx += 1
|
||||
self.text_embs_all += [text_embs]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tokenized_text,
|
||||
embedded_text,
|
||||
):
|
||||
b, device = tokenized_text.shape[0], tokenized_text.device
|
||||
for i in range(b):
|
||||
idx = tokenized_text[i] == self.placeholder_token.to(device)
|
||||
if sum(idx) > 0:
|
||||
if i >= len(self.text_embs_all):
|
||||
print('truncation for log images...')
|
||||
break
|
||||
text_emb = torch.cat(self.text_embs_all[i], dim=0)
|
||||
if sum(idx) != len(text_emb):
|
||||
print('truncation for long caption...')
|
||||
embedded_text[i][idx] = text_emb[:sum(idx)]
|
||||
return embedded_text
|
||||
|
||||
def embedding_parameters(self):
|
||||
return self.parameters()
|
||||
111
iopaint/model/anytext/cldm/hack.py
Normal file
111
iopaint/model/anytext/cldm/hack.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import torch
|
||||
import einops
|
||||
|
||||
import iopaint.model.anytext.ldm.modules.encoders.modules
|
||||
import iopaint.model.anytext.ldm.modules.attention
|
||||
|
||||
from transformers import logging
|
||||
from iopaint.model.anytext.ldm.modules.attention import default
|
||||
|
||||
|
||||
def disable_verbosity():
|
||||
logging.set_verbosity_error()
|
||||
print('logging improved.')
|
||||
return
|
||||
|
||||
|
||||
def enable_sliced_attention():
|
||||
iopaint.model.anytext.ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
|
||||
print('Enabled sliced_attention.')
|
||||
return
|
||||
|
||||
|
||||
def hack_everything(clip_skip=0):
|
||||
disable_verbosity()
|
||||
iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
|
||||
iopaint.model.anytext.ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
|
||||
print('Enabled clip hacks.')
|
||||
return
|
||||
|
||||
|
||||
# Written by Lvmin
|
||||
def _hacked_clip_forward(self, text):
|
||||
PAD = self.tokenizer.pad_token_id
|
||||
EOS = self.tokenizer.eos_token_id
|
||||
BOS = self.tokenizer.bos_token_id
|
||||
|
||||
def tokenize(t):
|
||||
return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
|
||||
def transformer_encode(t):
|
||||
if self.clip_skip > 1:
|
||||
rt = self.transformer(input_ids=t, output_hidden_states=True)
|
||||
return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
|
||||
else:
|
||||
return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
|
||||
|
||||
def split(x):
|
||||
return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
|
||||
|
||||
def pad(x, p, i):
|
||||
return x[:i] if len(x) >= i else x + [p] * (i - len(x))
|
||||
|
||||
raw_tokens_list = tokenize(text)
|
||||
tokens_list = []
|
||||
|
||||
for raw_tokens in raw_tokens_list:
|
||||
raw_tokens_123 = split(raw_tokens)
|
||||
raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
|
||||
raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
|
||||
tokens_list.append(raw_tokens_123)
|
||||
|
||||
tokens_list = torch.IntTensor(tokens_list).to(self.device)
|
||||
|
||||
feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
|
||||
y = transformer_encode(feed)
|
||||
z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
|
||||
|
||||
return z
|
||||
|
||||
|
||||
# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
|
||||
def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
limit = k.shape[0]
|
||||
att_step = 1
|
||||
q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
|
||||
k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
|
||||
v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
|
||||
|
||||
q_chunks.reverse()
|
||||
k_chunks.reverse()
|
||||
v_chunks.reverse()
|
||||
sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||
del k, q, v
|
||||
for i in range(0, limit, att_step):
|
||||
q_buffer = q_chunks.pop()
|
||||
k_buffer = k_chunks.pop()
|
||||
v_buffer = v_chunks.pop()
|
||||
sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
|
||||
|
||||
del k_buffer, q_buffer
|
||||
# attention, what we cannot get enough of, by chunks
|
||||
|
||||
sim_buffer = sim_buffer.softmax(dim=-1)
|
||||
|
||||
sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
|
||||
del v_buffer
|
||||
sim[i:i + att_step, :, :] = sim_buffer
|
||||
|
||||
del sim_buffer
|
||||
sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(sim)
|
||||
40
iopaint/model/anytext/cldm/model.py
Normal file
40
iopaint/model/anytext/cldm/model.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import os
|
||||
import torch
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from iopaint.model.anytext.ldm.util import instantiate_from_config
|
||||
|
||||
|
||||
def get_state_dict(d):
|
||||
return d.get("state_dict", d)
|
||||
|
||||
|
||||
def load_state_dict(ckpt_path, location="cpu"):
|
||||
_, extension = os.path.splitext(ckpt_path)
|
||||
if extension.lower() == ".safetensors":
|
||||
import safetensors.torch
|
||||
|
||||
state_dict = safetensors.torch.load_file(ckpt_path, device=location)
|
||||
else:
|
||||
state_dict = get_state_dict(
|
||||
torch.load(ckpt_path, map_location=torch.device(location))
|
||||
)
|
||||
state_dict = get_state_dict(state_dict)
|
||||
print(f"Loaded state_dict from [{ckpt_path}]")
|
||||
return state_dict
|
||||
|
||||
|
||||
def create_model(config_path, device, cond_stage_path=None, use_fp16=False):
|
||||
config = OmegaConf.load(config_path)
|
||||
if cond_stage_path:
|
||||
config.model.params.cond_stage_config.params.version = (
|
||||
cond_stage_path # use pre-downloaded ckpts, in case blocked
|
||||
)
|
||||
config.model.params.cond_stage_config.params.device = device
|
||||
if use_fp16:
|
||||
config.model.params.use_fp16 = True
|
||||
config.model.params.control_stage_config.params.use_fp16 = True
|
||||
config.model.params.unet_config.params.use_fp16 = True
|
||||
model = instantiate_from_config(config.model).cpu()
|
||||
print(f"Loaded model config from [{config_path}]")
|
||||
return model
|
||||
300
iopaint/model/anytext/cldm/recognizer.py
Executable file
300
iopaint/model/anytext/cldm/recognizer.py
Executable file
@@ -0,0 +1,300 @@
|
||||
"""
|
||||
Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
"""
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import math
|
||||
import traceback
|
||||
from easydict import EasyDict as edict
|
||||
import time
|
||||
from iopaint.model.anytext.ocr_recog.RecModel import RecModel
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def min_bounding_rect(img):
|
||||
ret, thresh = cv2.threshold(img, 127, 255, 0)
|
||||
contours, hierarchy = cv2.findContours(
|
||||
thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
|
||||
)
|
||||
if len(contours) == 0:
|
||||
print("Bad contours, using fake bbox...")
|
||||
return np.array([[0, 0], [100, 0], [100, 100], [0, 100]])
|
||||
max_contour = max(contours, key=cv2.contourArea)
|
||||
rect = cv2.minAreaRect(max_contour)
|
||||
box = cv2.boxPoints(rect)
|
||||
box = np.int0(box)
|
||||
# sort
|
||||
x_sorted = sorted(box, key=lambda x: x[0])
|
||||
left = x_sorted[:2]
|
||||
right = x_sorted[2:]
|
||||
left = sorted(left, key=lambda x: x[1])
|
||||
(tl, bl) = left
|
||||
right = sorted(right, key=lambda x: x[1])
|
||||
(tr, br) = right
|
||||
if tl[1] > bl[1]:
|
||||
(tl, bl) = (bl, tl)
|
||||
if tr[1] > br[1]:
|
||||
(tr, br) = (br, tr)
|
||||
return np.array([tl, tr, br, bl])
|
||||
|
||||
|
||||
def create_predictor(model_dir=None, model_lang="ch", is_onnx=False):
|
||||
model_file_path = model_dir
|
||||
if model_file_path is not None and not os.path.exists(model_file_path):
|
||||
raise ValueError("not find model file path {}".format(model_file_path))
|
||||
|
||||
if is_onnx:
|
||||
import onnxruntime as ort
|
||||
|
||||
sess = ort.InferenceSession(
|
||||
model_file_path, providers=["CPUExecutionProvider"]
|
||||
) # 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
|
||||
return sess
|
||||
else:
|
||||
if model_lang == "ch":
|
||||
n_class = 6625
|
||||
elif model_lang == "en":
|
||||
n_class = 97
|
||||
else:
|
||||
raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}")
|
||||
rec_config = edict(
|
||||
in_channels=3,
|
||||
backbone=edict(
|
||||
type="MobileNetV1Enhance",
|
||||
scale=0.5,
|
||||
last_conv_stride=[1, 2],
|
||||
last_pool_type="avg",
|
||||
),
|
||||
neck=edict(
|
||||
type="SequenceEncoder",
|
||||
encoder_type="svtr",
|
||||
dims=64,
|
||||
depth=2,
|
||||
hidden_dims=120,
|
||||
use_guide=True,
|
||||
),
|
||||
head=edict(
|
||||
type="CTCHead",
|
||||
fc_decay=0.00001,
|
||||
out_channels=n_class,
|
||||
return_feats=True,
|
||||
),
|
||||
)
|
||||
|
||||
rec_model = RecModel(rec_config)
|
||||
if model_file_path is not None:
|
||||
rec_model.load_state_dict(torch.load(model_file_path, map_location="cpu"))
|
||||
rec_model.eval()
|
||||
return rec_model.eval()
|
||||
|
||||
|
||||
def _check_image_file(path):
|
||||
img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff"}
|
||||
return any([path.lower().endswith(e) for e in img_end])
|
||||
|
||||
|
||||
def get_image_file_list(img_file):
|
||||
imgs_lists = []
|
||||
if img_file is None or not os.path.exists(img_file):
|
||||
raise Exception("not found any img file in {}".format(img_file))
|
||||
if os.path.isfile(img_file) and _check_image_file(img_file):
|
||||
imgs_lists.append(img_file)
|
||||
elif os.path.isdir(img_file):
|
||||
for single_file in os.listdir(img_file):
|
||||
file_path = os.path.join(img_file, single_file)
|
||||
if os.path.isfile(file_path) and _check_image_file(file_path):
|
||||
imgs_lists.append(file_path)
|
||||
if len(imgs_lists) == 0:
|
||||
raise Exception("not found any img file in {}".format(img_file))
|
||||
imgs_lists = sorted(imgs_lists)
|
||||
return imgs_lists
|
||||
|
||||
|
||||
class TextRecognizer(object):
|
||||
def __init__(self, args, predictor):
|
||||
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
|
||||
self.rec_batch_num = args.rec_batch_num
|
||||
self.predictor = predictor
|
||||
self.chars = self.get_char_dict(args.rec_char_dict_path)
|
||||
self.char2id = {x: i for i, x in enumerate(self.chars)}
|
||||
self.is_onnx = not isinstance(self.predictor, torch.nn.Module)
|
||||
self.use_fp16 = args.use_fp16
|
||||
|
||||
# img: CHW
|
||||
def resize_norm_img(self, img, max_wh_ratio):
|
||||
imgC, imgH, imgW = self.rec_image_shape
|
||||
assert imgC == img.shape[0]
|
||||
imgW = int((imgH * max_wh_ratio))
|
||||
|
||||
h, w = img.shape[1:]
|
||||
ratio = w / float(h)
|
||||
if math.ceil(imgH * ratio) > imgW:
|
||||
resized_w = imgW
|
||||
else:
|
||||
resized_w = int(math.ceil(imgH * ratio))
|
||||
resized_image = torch.nn.functional.interpolate(
|
||||
img.unsqueeze(0),
|
||||
size=(imgH, resized_w),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
resized_image /= 255.0
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device)
|
||||
padding_im[:, :, 0:resized_w] = resized_image[0]
|
||||
return padding_im
|
||||
|
||||
# img_list: list of tensors with shape chw 0-255
|
||||
def pred_imglist(self, img_list, show_debug=False, is_ori=False):
|
||||
img_num = len(img_list)
|
||||
assert img_num > 0
|
||||
# Calculate the aspect ratio of all text bars
|
||||
width_list = []
|
||||
for img in img_list:
|
||||
width_list.append(img.shape[2] / float(img.shape[1]))
|
||||
# Sorting can speed up the recognition process
|
||||
indices = torch.from_numpy(np.argsort(np.array(width_list)))
|
||||
batch_num = self.rec_batch_num
|
||||
preds_all = [None] * img_num
|
||||
preds_neck_all = [None] * img_num
|
||||
for beg_img_no in range(0, img_num, batch_num):
|
||||
end_img_no = min(img_num, beg_img_no + batch_num)
|
||||
norm_img_batch = []
|
||||
|
||||
imgC, imgH, imgW = self.rec_image_shape[:3]
|
||||
max_wh_ratio = imgW / imgH
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
h, w = img_list[indices[ino]].shape[1:]
|
||||
if h > w * 1.2:
|
||||
img = img_list[indices[ino]]
|
||||
img = torch.transpose(img, 1, 2).flip(dims=[1])
|
||||
img_list[indices[ino]] = img
|
||||
h, w = img.shape[1:]
|
||||
# wh_ratio = w * 1.0 / h
|
||||
# max_wh_ratio = max(max_wh_ratio, wh_ratio) # comment to not use different ratio
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
|
||||
if self.use_fp16:
|
||||
norm_img = norm_img.half()
|
||||
norm_img = norm_img.unsqueeze(0)
|
||||
norm_img_batch.append(norm_img)
|
||||
norm_img_batch = torch.cat(norm_img_batch, dim=0)
|
||||
if show_debug:
|
||||
for i in range(len(norm_img_batch)):
|
||||
_img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy()
|
||||
_img = (_img + 0.5) * 255
|
||||
_img = _img[:, :, ::-1]
|
||||
file_name = f"{indices[beg_img_no + i]}"
|
||||
file_name = file_name + "_ori" if is_ori else file_name
|
||||
cv2.imwrite(file_name + ".jpg", _img)
|
||||
if self.is_onnx:
|
||||
input_dict = {}
|
||||
input_dict[self.predictor.get_inputs()[0].name] = (
|
||||
norm_img_batch.detach().cpu().numpy()
|
||||
)
|
||||
outputs = self.predictor.run(None, input_dict)
|
||||
preds = {}
|
||||
preds["ctc"] = torch.from_numpy(outputs[0])
|
||||
preds["ctc_neck"] = [torch.zeros(1)] * img_num
|
||||
else:
|
||||
preds = self.predictor(norm_img_batch)
|
||||
for rno in range(preds["ctc"].shape[0]):
|
||||
preds_all[indices[beg_img_no + rno]] = preds["ctc"][rno]
|
||||
preds_neck_all[indices[beg_img_no + rno]] = preds["ctc_neck"][rno]
|
||||
|
||||
return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0)
|
||||
|
||||
def get_char_dict(self, character_dict_path):
|
||||
character_str = []
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
line = line.decode("utf-8").strip("\n").strip("\r\n")
|
||||
character_str.append(line)
|
||||
dict_character = list(character_str)
|
||||
dict_character = ["sos"] + dict_character + [" "] # eos is space
|
||||
return dict_character
|
||||
|
||||
def get_text(self, order):
|
||||
char_list = [self.chars[text_id] for text_id in order]
|
||||
return "".join(char_list)
|
||||
|
||||
def decode(self, mat):
|
||||
text_index = mat.detach().cpu().numpy().argmax(axis=1)
|
||||
ignored_tokens = [0]
|
||||
selection = np.ones(len(text_index), dtype=bool)
|
||||
selection[1:] = text_index[1:] != text_index[:-1]
|
||||
for ignored_token in ignored_tokens:
|
||||
selection &= text_index != ignored_token
|
||||
return text_index[selection], np.where(selection)[0]
|
||||
|
||||
def get_ctcloss(self, preds, gt_text, weight):
|
||||
if not isinstance(weight, torch.Tensor):
|
||||
weight = torch.tensor(weight).to(preds.device)
|
||||
ctc_loss = torch.nn.CTCLoss(reduction="none")
|
||||
log_probs = preds.log_softmax(dim=2).permute(1, 0, 2) # NTC-->TNC
|
||||
targets = []
|
||||
target_lengths = []
|
||||
for t in gt_text:
|
||||
targets += [self.char2id.get(i, len(self.chars) - 1) for i in t]
|
||||
target_lengths += [len(t)]
|
||||
targets = torch.tensor(targets).to(preds.device)
|
||||
target_lengths = torch.tensor(target_lengths).to(preds.device)
|
||||
input_lengths = torch.tensor([log_probs.shape[0]] * (log_probs.shape[1])).to(
|
||||
preds.device
|
||||
)
|
||||
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
|
||||
loss = loss / input_lengths * weight
|
||||
return loss
|
||||
|
||||
|
||||
def main():
|
||||
rec_model_dir = "./ocr_weights/ppv3_rec.pth"
|
||||
predictor = create_predictor(rec_model_dir)
|
||||
args = edict()
|
||||
args.rec_image_shape = "3, 48, 320"
|
||||
args.rec_char_dict_path = "./ocr_weights/ppocr_keys_v1.txt"
|
||||
args.rec_batch_num = 6
|
||||
text_recognizer = TextRecognizer(args, predictor)
|
||||
image_dir = "./test_imgs_cn"
|
||||
gt_text = ["韩国小馆"] * 14
|
||||
|
||||
image_file_list = get_image_file_list(image_dir)
|
||||
valid_image_file_list = []
|
||||
img_list = []
|
||||
|
||||
for image_file in image_file_list:
|
||||
img = cv2.imread(image_file)
|
||||
if img is None:
|
||||
print("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
valid_image_file_list.append(image_file)
|
||||
img_list.append(torch.from_numpy(img).permute(2, 0, 1).float())
|
||||
try:
|
||||
tic = time.time()
|
||||
times = []
|
||||
for i in range(10):
|
||||
preds, _ = text_recognizer.pred_imglist(img_list) # get text
|
||||
preds_all = preds.softmax(dim=2)
|
||||
times += [(time.time() - tic) * 1000.0]
|
||||
tic = time.time()
|
||||
print(times)
|
||||
print(np.mean(times[1:]) / len(preds_all))
|
||||
weight = np.ones(len(gt_text))
|
||||
loss = text_recognizer.get_ctcloss(preds, gt_text, weight)
|
||||
for i in range(len(valid_image_file_list)):
|
||||
pred = preds_all[i]
|
||||
order, idx = text_recognizer.decode(pred)
|
||||
text = text_recognizer.get_text(order)
|
||||
print(
|
||||
f'{valid_image_file_list[i]}: pred/gt="{text}"/"{gt_text[i]}", loss={loss[i]:.2f}'
|
||||
)
|
||||
except Exception as E:
|
||||
print(traceback.format_exc(), E)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user