Skip to content

Commit

Permalink
comment and multi-gpu setting update
Browse files Browse the repository at this point in the history
  • Loading branch information
Baek JeongHun committed Apr 16, 2019
1 parent eb5570f commit c2e28f5
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 11 deletions.
3 changes: 2 additions & 1 deletion model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ class Model(nn.Module):

def __init__(self, opt):
super(Model, self).__init__()
self.opt = opt
self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}

""" Transformation """
if opt.Transformation == 'TPS':
self.Transformation = TPS_SpatialTransformerNetwork(
F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), batch_size=int(opt.batch_size/opt.num_gpu), I_channel_num=opt.input_channel)
F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
else:
print('No Transformation module specified')

Expand Down
3 changes: 3 additions & 0 deletions modules/feature_extraction.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@


class VGG_FeatureExtractor(nn.Module):
""" FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """

def __init__(self, input_channel, output_channel=512):
super(VGG_FeatureExtractor, self).__init__()
Expand All @@ -28,6 +29,7 @@ def forward(self, input):


class RCNN_FeatureExtractor(nn.Module):
""" FeatureExtractor of GRCNN (https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """

def __init__(self, input_channel, output_channel=512):
super(RCNN_FeatureExtractor, self).__init__()
Expand All @@ -50,6 +52,7 @@ def forward(self, input):


class ResNet_FeatureExtractor(nn.Module):
""" FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """

def __init__(self, input_channel, output_channel=512):
super(ResNet_FeatureExtractor, self).__init__()
Expand Down
6 changes: 3 additions & 3 deletions modules/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class TPS_SpatialTransformerNetwork(nn.Module):
""" Rectification Network of RARE, namely TPS based STN """

def __init__(self, F, I_size, I_r_size, batch_size, I_channel_num=1):
def __init__(self, F, I_size, I_r_size, I_channel_num=1):
""" Based on RARE TPS
input:
batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
Expand All @@ -23,7 +23,7 @@ def __init__(self, F, I_size, I_r_size, batch_size, I_channel_num=1):
self.I_r_size = I_r_size # = (I_r_height, I_r_width)
self.I_channel_num = I_channel_num
self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num)
self.GridGenerator = GridGenerator(self.F, self.I_r_size, batch_size)
self.GridGenerator = GridGenerator(self.F, self.I_r_size)

def forward(self, batch_I):
batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
Expand Down Expand Up @@ -81,7 +81,7 @@ def forward(self, batch_I):
class GridGenerator(nn.Module):
""" Grid Generator of RARE, which produces P_prime by multiplying T with P """

def __init__(self, F, I_r_size, batch_size):
def __init__(self, F, I_r_size):
""" Generate P_hat and inv_delta_C for later """
super(GridGenerator, self).__init__()
self.eps = 1e-6
Expand Down
17 changes: 10 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def train(opt):
torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
print(best_model_log)
log.write(best_model_log+'\n')
log.write(best_model_log + '\n')

# save model per 1e+5 iter.
if (i + 1) % 1e+5 == 0:
Expand Down Expand Up @@ -259,14 +259,17 @@ def train(opt):
opt.num_gpu = torch.cuda.device_count()
# print('device count', opt.num_gpu)
if opt.num_gpu > 1:
opt.num_iter = int(opt.num_iter / opt.num_gpu)
opt.batch_size = opt.batch_size * opt.num_gpu
opt.workers = opt.workers * opt.num_gpu
print('------ Use multi-GPU setting ------')
print('To equalize the number of epochs to 1-GPU setting, num_iter is divided with num_gpu by default.')
# If you don't care about it, just comment out these lines.
print(f'The batch_size is multiplied with num_gpu and multiplied batch_size is {opt.batch_size}')
print('if you stuck too long time with multi-GPU setting, try to set --workers 0')
# check multi-GPU issue https://github.com/clovaai/deep-text-recognition-benchmark/issues/1
opt.workers = opt.workers * opt.num_gpu

""" previous version
print('To equlize batch stats to 1-GPU setting, the batch_size is multiplied with num_gpu and multiplied batch_size is ', opt.batch_size)
opt.batch_size = opt.batch_size * opt.num_gpu
print('To equalize the number of epochs to 1-GPU setting, num_iter is divided with num_gpu by default.')
If you don't care about it, just comment out these lines.
opt.num_iter = int(opt.num_iter / opt.num_gpu)
"""

train(opt)

0 comments on commit c2e28f5

Please sign in to comment.