
Commit

modify transformeroptim, resize
Topdu committed Aug 24, 2021
1 parent 73058cc commit 2bf8ad9
Showing 7 changed files with 34 additions and 50 deletions.
8 changes: 5 additions & 3 deletions configs/rec/rec_mtb_nrtr.yml
@@ -43,7 +43,7 @@ Architecture:
name: MTB
cnn_num: 2
Head:
name: TransformerOptim
name: Transformer
d_model: 512
num_encoder_layers: 6
beam_size: 10 # When beam_size is greater than 0, beam search is used during evaluation.
@@ -69,8 +69,9 @@ Train:
img_mode: BGR
channel_first: False
- NRTRLabelEncode: # Class handling label
- PILResize:
- NRTRRecResizeImg:
image_shape: [100, 32]
resize_type: PIL # PIL or OpenCV
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
@@ -88,8 +89,9 @@ Eval:
img_mode: BGR
channel_first: False
- NRTRLabelEncode: # Class handling label
- PILResize:
- NRTRRecResizeImg:
image_shape: [100, 32]
resize_type: PIL # PIL or OpenCV
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
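For reference, the new NRTRRecResizeImg entry above maps one-to-one onto the op's constructor arguments; a minimal sketch of the equivalent direct construction in Python (assuming the op is importable as in the __init__.py change below, and that the data pipeline passes the YAML keys through as keyword arguments):

import numpy as np
from ppocr.data.imaug.rec_img_aug import NRTRRecResizeImg

# Mirrors the YAML entry: image_shape [100, 32], resize_type 'PIL' (or 'OpenCV')
resize_op = NRTRRecResizeImg(image_shape=[100, 32], resize_type='PIL')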
2 changes: 1 addition & 1 deletion ppocr/data/imaug/__init__.py
@@ -21,7 +21,7 @@
from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop

from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, PILResize, CVResize
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg
from .randaugment import RandAugment
from .copy_paste import CopyPaste
from .operators import *
29 changes: 10 additions & 19 deletions ppocr/data/imaug/rec_img_aug.py
@@ -42,30 +42,21 @@ def __call__(self, data):
data['image'] = norm_img
return data

class PILResize(object):
def __init__(self, image_shape, **kwargs):
self.image_shape = image_shape

def __call__(self, data):
img = data['image']
image_pil = Image.fromarray(np.uint8(img))
norm_img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
norm_img = np.array(norm_img)
norm_img = np.expand_dims(norm_img, -1)
norm_img = norm_img.transpose((2, 0, 1))
data['image'] = norm_img.astype(np.float32) / 128. - 1.
return data


class CVResize(object):
def __init__(self, image_shape, **kwargs):
class NRTRRecResizeImg(object):
def __init__(self, image_shape, resize_type, **kwargs):
self.image_shape = image_shape
self.resize_type = resize_type

def __call__(self, data):
img = data['image']
#print(img)
norm_img = cv2.resize(img,self.image_shape)
norm_img = np.expand_dims(norm_img, -1)
if self.resize_type == 'PIL':
image_pil = Image.fromarray(np.uint8(img))
img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
img = np.array(img)
if self.resize_type == 'OpenCV':
img = cv2.resize(img, self.image_shape)
norm_img = np.expand_dims(img, -1)
norm_img = norm_img.transpose((2, 0, 1))
data['image'] = norm_img.astype(np.float32) / 128. - 1.
return data
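The merged NRTRRecResizeImg op keeps both resize backends behind the resize_type flag and normalizes to roughly [-1, 1) either way. A minimal usage sketch, assuming a single-channel (grayscale) crop so the expand_dims/transpose above produces a 1 x H x W tensor:

import numpy as np
from ppocr.data.imaug.rec_img_aug import NRTRRecResizeImg

op = NRTRRecResizeImg(image_shape=[100, 32], resize_type='PIL')   # 'OpenCV' uses cv2.resize instead
data = {'image': np.random.randint(0, 255, (48, 160), dtype=np.uint8)}  # dummy grayscale crop
out = op(data)
print(out['image'].shape)   # (1, 32, 100), float32, scaled by /128. - 1.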
22 changes: 7 additions & 15 deletions ppocr/losses/rec_nrtr_loss.py
@@ -3,34 +3,26 @@
import paddle.nn.functional as F


def cal_performance(pred, tgt):

pred = pred.max(1)[1]
tgt = tgt.contiguous().view(-1)
non_pad_mask = tgt.ne(0)
n_correct = pred.eq(tgt)
n_correct = n_correct.masked_select(non_pad_mask).sum().item()
return n_correct


class NRTRLoss(nn.Layer):
def __init__(self,smoothing=True, **kwargs):
def __init__(self, smoothing=True, **kwargs):
super(NRTRLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(reduction='mean',ignore_index=0)
self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
self.smoothing = smoothing

def forward(self, pred, batch):
pred = pred.reshape([-1, pred.shape[2]])
max_len = batch[2].max()
tgt = batch[1][:,1:2+max_len]
tgt = tgt.reshape([-1] )
tgt = batch[1][:, 1:2 + max_len]
tgt = tgt.reshape([-1])
if self.smoothing:
eps = 0.1
n_class = pred.shape[1]
one_hot = F.one_hot(tgt, pred.shape[1])
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, axis=1)
non_pad_mask = paddle.not_equal(tgt, paddle.zeros(tgt.shape,dtype='int64'))
non_pad_mask = paddle.not_equal(
tgt, paddle.zeros(
tgt.shape, dtype='int64'))
loss = -(one_hot * log_prb).sum(axis=1)
loss = loss.masked_select(non_pad_mask).mean()
else:
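In the smoothing branch, each one-hot target row is redistributed as (1 - eps) on the true class and eps / (n_class - 1) on every other class, so every row still sums to 1; padding positions (index 0) are masked out before averaging. A small numeric check of that redistribution, assuming eps = 0.1 and 5 classes:

import numpy as np

eps, n_class = 0.1, 5
one_hot = np.eye(n_class)[[2]]                    # true class 2 -> [0, 0, 1, 0, 0]
smoothed = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
print(smoothed)        # [[0.025 0.025 0.9 0.025 0.025]]
print(smoothed.sum())  # 1.0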
4 changes: 2 additions & 2 deletions ppocr/modeling/heads/__init__.py
@@ -26,13 +26,13 @@ def build_head(config):
from .rec_ctc_head import CTCHead
from .rec_att_head import AttentionHead
from .rec_srn_head import SRNHead
from .rec_nrtr_optim_head import TransformerOptim
from .rec_nrtr_head import Transformer

# cls head
from .cls_head import ClsHead
support_dict = [
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
'SRNHead', 'PGHead', 'TransformerOptim', 'TableAttentionHead'
'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead'
]

#table head
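With the rename, configs that previously requested TransformerOptim now resolve 'Transformer' through support_dict. A rough sketch of selecting the head from a config dict (keys taken from the rec_mtb_nrtr.yml snippet above; in_channels is normally injected by the framework from the MTB backbone, so treat the exact keyword set as an assumption):

from ppocr.modeling.heads import build_head

head_config = {
    'name': 'Transformer',   # was 'TransformerOptim' before this commit
    'in_channels': 512,      # assumed: usually filled in from the backbone output
    'd_model': 512,
    'num_encoder_layers': 6,
    'beam_size': 10,
}
head = build_head(head_config)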
4 changes: 2 additions & 2 deletions ppocr/modeling/heads/multiheadAttention.py
@@ -24,7 +24,7 @@
ones_ = constant_(value=1.)


class MultiheadAttentionOptim(nn.Layer):
class MultiheadAttention(nn.Layer):
"""Allows the model to jointly attend to information
from different representation subspaces.
See reference: Attention Is All You Need
@@ -46,7 +46,7 @@ def __init__(self,
bias=True,
add_bias_kv=False,
add_zero_attn=False):
super(MultiheadAttentionOptim, self).__init__()
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
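The renamed MultiheadAttention layer implements the multi-head attention of the cited paper; the core per-head computation is scaled dot-product attention. A minimal NumPy sketch of that computation for a single head (not this class's API, just the underlying formula):

import numpy as np

def scaled_dot_product_attention(q, k, v):
    # q, k, v: (seq_len, d_k); weights = softmax(q @ k.T / sqrt(d_k))
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

q = k = v = np.random.rand(4, 8)                       # 4 positions, d_k = 8
print(scaled_dot_product_attention(q, k, v).shape)     # (4, 8)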
ppocr/modeling/heads/rec_nrtr_optim_head.py → ppocr/modeling/heads/rec_nrtr_head.py
@@ -21,15 +21,15 @@
from paddle.nn.initializer import XavierNormal as xavier_uniform_
from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
import numpy as np
from ppocr.modeling.heads.multiheadAttention import MultiheadAttentionOptim
from ppocr.modeling.heads.multiheadAttention import MultiheadAttention
from paddle.nn.initializer import Constant as constant_
from paddle.nn.initializer import XavierNormal as xavier_normal_

zeros_ = constant_(value=0.)
ones_ = constant_(value=1.)


class TransformerOptim(nn.Layer):
class Transformer(nn.Layer):
"""A transformer model. User is able to modify the attributes as needed. The architechture
is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
@@ -63,7 +63,7 @@ def __init__(self,
out_channels=0,
dst_vocab_size=99,
scale_embedding=True):
super(TransformerOptim, self).__init__()
super(Transformer, self).__init__()
self.embedding = Embeddings(
d_model=d_model,
vocab=dst_vocab_size,
@@ -215,8 +215,7 @@ def collect_active_part(beamed_tensor, curr_active_inst_idx,
n_curr_active_inst = len(curr_active_inst_idx)
new_shape = (n_curr_active_inst * n_bm, *d_hs)

beamed_tensor = beamed_tensor.reshape(
[n_prev_active_inst, -1])
beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])
beamed_tensor = beamed_tensor.index_select(
paddle.to_tensor(curr_active_inst_idx), axis=0)
beamed_tensor = beamed_tensor.reshape([*new_shape])
@@ -486,7 +485,7 @@ def __init__(self,
attention_dropout_rate=0.0,
residual_dropout_rate=0.1):
super(TransformerEncoderLayer, self).__init__()
self.self_attn = MultiheadAttentionOptim(
self.self_attn = MultiheadAttention(
d_model, nhead, dropout=attention_dropout_rate)

self.conv1 = Conv2D(
@@ -555,9 +554,9 @@ def __init__(self,
attention_dropout_rate=0.0,
residual_dropout_rate=0.1):
super(TransformerDecoderLayer, self).__init__()
self.self_attn = MultiheadAttentionOptim(
self.self_attn = MultiheadAttention(
d_model, nhead, dropout=attention_dropout_rate)
self.multihead_attn = MultiheadAttentionOptim(
self.multihead_attn = MultiheadAttention(
d_model, nhead, dropout=attention_dropout_rate)

self.conv1 = Conv2D(
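The collect_active_part change above only joins the reshape onto one line; the function itself gathers the beams of the sentences that are still active during beam-search decoding. A NumPy sketch of that gather pattern, with made-up sizes (n_bm is the beam width):

import numpy as np

n_prev_active_inst, n_bm, d_model = 4, 3, 5
beamed = np.arange(n_prev_active_inst * n_bm * d_model).reshape(
    n_prev_active_inst * n_bm, d_model)        # one row per (sentence, beam)
curr_active_inst_idx = [0, 2]                  # sentences whose decoding continues

flat = beamed.reshape(n_prev_active_inst, -1)  # (4, 15): a sentence's beams side by side
kept = flat[curr_active_inst_idx]              # keep only the active sentences
kept = kept.reshape(len(curr_active_inst_idx) * n_bm, d_model)   # back to (6, 5)
print(kept.shape)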
