1- """
2- Copyright (c) 2019-present NAVER Corp.
3-
4- Licensed under the Apache License, Version 2.0 (the "License");
5- you may not use this file except in compliance with the License.
6- You may obtain a copy of the License at
7-
8- http://www.apache.org/licenses/LICENSE-2.0
9-
10- Unless required by applicable law or agreed to in writing, software
11- distributed under the License is distributed on an "AS IS" BASIS,
12- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13- See the License for the specific language governing permissions and
14- limitations under the License.
15- """
16-
171import torch .nn as nn
182
193from .modules .feature_extraction import (
@@ -34,14 +18,12 @@ def __init__(self, opt):
3418 self .stages = {'Trans' : opt .Transformation , 'Feat' : opt .FeatureExtraction ,
3519 'Seq' : opt .SequenceModeling , 'Pred' : opt .Prediction }
3620
37- """ Transformation """
3821 if opt .Transformation == 'TPS' :
3922 self .Transformation = TPS_SpatialTransformerNetwork (
4023 F = opt .num_fiducial , I_size = (opt .imgH , opt .imgW ), I_r_size = (opt .imgH , opt .imgW ), I_channel_num = opt .input_channel )
4124 else :
4225 print ('No Transformation module specified' )
4326
44- """ FeatureExtraction """
4527 if opt .FeatureExtraction == 'VGG' :
4628 self .FeatureExtraction = VGG_FeatureExtractor (opt .input_channel , opt .output_channel )
4729 elif opt .FeatureExtraction == 'RCNN' :
@@ -53,7 +35,6 @@ def __init__(self, opt):
5335 self .FeatureExtraction_output = opt .output_channel # int(imgH/16-1) * 512
5436 self .AdaptiveAvgPool = nn .AdaptiveAvgPool2d ((None , 1 )) # Transform final (imgH/16-1) -> 1
5537
56- """ Sequence modeling"""
5738 if opt .SequenceModeling == 'BiLSTM' :
5839 self .SequenceModeling = nn .Sequential (
5940 BidirectionalLSTM (self .FeatureExtraction_output , opt .hidden_size , opt .hidden_size ),
@@ -63,7 +44,6 @@ def __init__(self, opt):
6344 print ('No SequenceModeling module specified' )
6445 self .SequenceModeling_output = self .FeatureExtraction_output
6546
66- """ Prediction """
6747 if opt .Prediction == 'CTC' :
6848 self .Prediction = nn .Linear (self .SequenceModeling_output , opt .num_class )
6949 elif opt .Prediction == 'Attn' :
@@ -72,22 +52,18 @@ def __init__(self, opt):
7252 raise Exception ('Prediction is neither CTC or Attn' )
7353
7454 def forward (self , input , text , is_train = True ):
75- """ Transformation stage """
7655 if not self .stages ['Trans' ] == "None" :
7756 input = self .Transformation (input )
7857
79- """ Feature extraction stage """
8058 visual_feature = self .FeatureExtraction (input )
8159 visual_feature = self .AdaptiveAvgPool (visual_feature .permute (0 , 3 , 1 , 2 )) # [b, c, h, w] -> [b, w, c, h]
8260 visual_feature = visual_feature .squeeze (3 )
8361
84- """ Sequence modeling stage """
8562 if self .stages ['Seq' ] == 'BiLSTM' :
8663 contextual_feature = self .SequenceModeling (visual_feature )
8764 else :
8865 contextual_feature = visual_feature # for convenience. this is NOT contextually modeled by BiLSTM
8966
90- """ Prediction stage """
9167 if self .stages ['Pred' ] == 'CTC' :
9268 prediction = self .Prediction (contextual_feature .contiguous ())
9369 else :
0 commit comments