anytext init
This commit is contained in:
48
iopaint/model/anytext/ocr_recog/RecCTCHead.py
Executable file
48
iopaint/model/anytext/ocr_recog/RecCTCHead.py
Executable file
@@ -0,0 +1,48 @@
|
||||
from torch import nn
|
||||
|
||||
|
||||
class CTCHead(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels=6625,
|
||||
fc_decay=0.0004,
|
||||
mid_channels=None,
|
||||
return_feats=False,
|
||||
**kwargs):
|
||||
super(CTCHead, self).__init__()
|
||||
if mid_channels is None:
|
||||
self.fc = nn.Linear(
|
||||
in_channels,
|
||||
out_channels,
|
||||
bias=True,)
|
||||
else:
|
||||
self.fc1 = nn.Linear(
|
||||
in_channels,
|
||||
mid_channels,
|
||||
bias=True,
|
||||
)
|
||||
self.fc2 = nn.Linear(
|
||||
mid_channels,
|
||||
out_channels,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.out_channels = out_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.return_feats = return_feats
|
||||
|
||||
def forward(self, x, labels=None):
|
||||
if self.mid_channels is None:
|
||||
predicts = self.fc(x)
|
||||
else:
|
||||
x = self.fc1(x)
|
||||
predicts = self.fc2(x)
|
||||
|
||||
if self.return_feats:
|
||||
result = dict()
|
||||
result['ctc'] = predicts
|
||||
result['ctc_neck'] = x
|
||||
else:
|
||||
result = predicts
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user