anytext init
iopaint/model/anytext/cldm/embedding_manager.py · 165 lines · new file
@@ -0,0 +1,165 @@
'''
Copyright (c) Alibaba, Inc. and its affiliates.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from iopaint.model.anytext.ldm.modules.diffusionmodules.util import conv_nd, linear


def get_clip_token_for_string(tokenizer, string):
    batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
                               return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
    tokens = batch_encoding["input_ids"]
    # With max_length padding, every slot except BOS (49406) and the word itself holds the
    # EOS/pad id 49407, so exactly two entries differ from 49407 iff the string is one token.
    assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string"
    return tokens[0, 1]

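# Hedged usage sketch (not part of the original file): calling the helper above with the
# Hugging Face CLIPTokenizer that Stable Diffusion's frozen CLIP embedder exposes as
# `.tokenizer`. The checkpoint name is an assumption for illustration.
def _demo_get_clip_token():
    from transformers import CLIPTokenizer
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    token = get_clip_token_for_string(tokenizer, "*")  # '*' is the default placeholder below
    print(int(token))  # the single vocabulary id the placeholder maps to
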
def get_bert_token_for_string(tokenizer, string):
    token = tokenizer(string)
    # BERT-style tokenization yields [CLS] word [SEP]; three non-zero ids mean a single-token string.
    assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
    token = token[0, 1]
    return token

def get_clip_vision_emb(encoder, processor, img):
    _img = img.repeat(1, 3, 1, 1)*255  # grayscale -> 3-channel, 0..255 range for the processor
    inputs = processor(images=_img, return_tensors="pt")
    inputs['pixel_values'] = inputs['pixel_values'].to(img.device)
    outputs = encoder(**inputs)
    emb = outputs.image_embeds
    return emb

def get_recog_emb(encoder, img_list):
    _img_list = [(img.repeat(1, 3, 1, 1)*255)[0] for img in img_list]
    encoder.predictor.eval()
    # the OCR recognizer's neck features serve as the glyph embedding for each text line
    _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)
    return preds_neck

def pad_H(x):
    # pad the height of a (B, C, H, W) tensor so it matches W (assumes W >= H)
    _, _, H, W = x.shape
    p_top = (W - H) // 2
    p_bot = W - H - p_top
    return F.pad(x, (0, 0, p_top, p_bot))

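# Minimal sketch (illustration only): pad_H centers a wide glyph line vertically into a square.
def _demo_pad_H():
    x = torch.zeros(1, 1, 32, 128)
    print(pad_H(x).shape)  # torch.Size([1, 1, 128, 128])
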
class EncodeNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EncodeNet, self).__init__()
        chan = 16
        n_layer = 4  # downsample

        self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1)
        self.conv_list = nn.ModuleList([])
        _c = chan
        for i in range(n_layer):
            self.conv_list.append(conv_nd(2, _c, _c*2, 3, padding=1, stride=2))
            _c *= 2
        self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.act = nn.SiLU()

    def forward(self, x):
        x = self.act(self.conv1(x))
        for layer in self.conv_list:
            x = self.act(layer(x))
        x = self.act(self.conv2(x))
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # (B, out_channels)
        return x

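# Sketch (illustration only): EncodeNet squeezes a spatial map into a per-image vector via
# four stride-2 convolutions and global average pooling, e.g. a 1-channel position mask -> token_dim.
def _demo_EncodeNet():
    net = EncodeNet(in_channels=1, out_channels=768)
    pos = torch.randn(2, 1, 64, 64)
    print(net(pos).shape)  # torch.Size([2, 768])
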
class EmbeddingManager(nn.Module):
    def __init__(
            self,
            embedder,
            valid=True,
            glyph_channels=20,
            position_channels=1,
            placeholder_string='*',
            add_pos=False,
            emb_type='ocr',
            **kwargs
    ):
        super().__init__()
        if hasattr(embedder, 'tokenizer'):  # using Stable Diffusion's CLIP encoder
            get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
            token_dim = 768
            if hasattr(embedder, 'vit'):
                assert emb_type == 'vit'
                self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor)
            self.get_recog_emb = None
        else:  # using LDM's BERT encoder
            get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
            token_dim = 1280
        self.token_dim = token_dim
        self.emb_type = emb_type

        self.add_pos = add_pos
        if add_pos:
            self.position_encoder = EncodeNet(position_channels, token_dim)
        if emb_type == 'ocr':
            self.proj = linear(40*64, token_dim)  # flattened OCR neck features -> token_dim
        if emb_type == 'conv':
            self.glyph_encoder = EncodeNet(glyph_channels, token_dim)

        self.placeholder_token = get_token_for_string(placeholder_string)
        # NOTE: for emb_type == 'ocr', the OCR recognizer is attached externally as
        # `self.recog` before encode_text() is first called.

    def encode_text(self, text_info):
        if self.get_recog_emb is None and self.emb_type == 'ocr':
            self.get_recog_emb = partial(get_recog_emb, self.recog)  # self.recog attached externally

        # flatten all glyph lines across the batch into one list
        gline_list = []
        pos_list = []
        for i in range(len(text_info['n_lines'])):  # sample index in a batch
            n_lines = text_info['n_lines'][i]
            for j in range(n_lines):  # line
                gline_list += [text_info['gly_line'][j][i:i+1]]
                if self.add_pos:
                    pos_list += [text_info['positions'][j][i:i+1]]

        if len(gline_list) > 0:
            if self.emb_type == 'ocr':
                recog_emb = self.get_recog_emb(gline_list)
                enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
            elif self.emb_type == 'vit':
                enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
            elif self.emb_type == 'conv':
                enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
            if self.add_pos:
                # note: pos_list is collected above, but the position encoder is fed gline_list here
                enc_pos = self.position_encoder(torch.cat(gline_list, dim=0))
                enc_glyph = enc_glyph+enc_pos

        # regroup the flat per-line embeddings back into one list per batch sample
        self.text_embs_all = []
        n_idx = 0
        for i in range(len(text_info['n_lines'])):  # sample index in a batch
            n_lines = text_info['n_lines'][i]
            text_embs = []
            for j in range(n_lines):  # line
                text_embs += [enc_glyph[n_idx:n_idx+1]]
                n_idx += 1
            self.text_embs_all += [text_embs]

    def forward(
            self,
            tokenized_text,
            embedded_text,
    ):
        b, device = tokenized_text.shape[0], tokenized_text.device
        for i in range(b):
            idx = tokenized_text[i] == self.placeholder_token.to(device)
            if sum(idx) > 0:
                if i >= len(self.text_embs_all):
                    print('truncation for log images...')
                    break
                text_emb = torch.cat(self.text_embs_all[i], dim=0)
                if sum(idx) != len(text_emb):
                    print('truncation for long caption...')
                # overwrite the placeholder positions with the cached glyph embeddings
                embedded_text[i][idx] = text_emb[:sum(idx)]
        return embedded_text

    def embedding_parameters(self):
        return self.parameters()
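
# Hedged end-to-end sketch (not part of the original file): the order in which the
# surrounding model is expected to drive EmbeddingManager. The stub embedder and the
# checkpoint name are assumptions for illustration; in AnyText the frozen CLIP embedder
# and the OCR recognizer are wired up by the hosting model, not constructed here.
def _demo_manager_wiring():
    from transformers import CLIPTokenizer

    class _StubEmbedder:  # stands in for the frozen CLIP embedder; only .tokenizer is probed
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    manager = EmbeddingManager(_StubEmbedder(), emb_type='ocr')
    # manager.recog = <OCR recognizer>          # 1) attach the recognizer externally
    # manager.encode_text(text_info)            # 2) cache per-line glyph embeddings
    # out = manager(tokenized_text, embedded)   # 3) splice them in at '*' positions
    print(int(manager.placeholder_token))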