Skip to content

Commit

Permalink
no_grad instead of requires_grad False
Browse files Browse the repository at this point in the history
  • Loading branch information
Baek JeongHun committed Aug 3, 2019
1 parent d62daf3 commit 7498f48
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 52 deletions.
44 changes: 22 additions & 22 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,38 +44,38 @@ def demo(opt):

# predict
model.eval()
for image_tensors, image_path_list in demo_loader:
batch_size = image_tensors.size(0)
with torch.no_grad():
with torch.no_grad():
for image_tensors, image_path_list in demo_loader:
batch_size = image_tensors.size(0)
image = image_tensors.cuda()
# For max length prediction
length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * batch_size)
text_for_pred = torch.cuda.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0)

if 'CTC' in opt.Prediction:
preds = model(image, text_for_pred).log_softmax(2)
if 'CTC' in opt.Prediction:
preds = model(image, text_for_pred).log_softmax(2)

# Select max probabilty (greedy decoding) then decode index to character
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
_, preds_index = preds.permute(1, 0, 2).max(2)
preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
preds_str = converter.decode(preds_index.data, preds_size.data)
# Select max probabilty (greedy decoding) then decode index to character
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
_, preds_index = preds.permute(1, 0, 2).max(2)
preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
preds_str = converter.decode(preds_index.data, preds_size.data)

else:
preds = model(image, text_for_pred, is_train=False)
else:
preds = model(image, text_for_pred, is_train=False)

# select max probabilty (greedy decoding) then decode index to character
_, preds_index = preds.max(2)
preds_str = converter.decode(preds_index, length_for_pred)
# select max probabilty (greedy decoding) then decode index to character
_, preds_index = preds.max(2)
preds_str = converter.decode(preds_index, length_for_pred)

print('-' * 80)
print('image_path\tpredicted_labels')
print('-' * 80)
for img_name, pred in zip(image_path_list, preds_str):
if 'Attn' in opt.Prediction:
pred = pred[:pred.find('[s]')] # prune after "end of sentence" token ([s])
print('-' * 80)
print('image_path\tpredicted_labels')
print('-' * 80)
for img_name, pred in zip(image_path_list, preds_str):
if 'Attn' in opt.Prediction:
pred = pred[:pred.find('[s]')] # prune after "end of sentence" token ([s])

print(f'{img_name}\t{pred}')
print(f'{img_name}\t{pred}')


if __name__ == '__main__':
Expand Down
47 changes: 22 additions & 25 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,6 @@ def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=Fa

def validation(model, criterion, evaluation_loader, converter, opt):
""" validation or evaluation """
for p in model.parameters():
p.requires_grad = False

n_correct = 0
norm_ED = 0
length_of_data = 0
Expand All @@ -79,13 +76,12 @@ def validation(model, criterion, evaluation_loader, converter, opt):
for i, (image_tensors, labels) in enumerate(evaluation_loader):
batch_size = image_tensors.size(0)
length_of_data = length_of_data + batch_size
with torch.no_grad():
image = image_tensors.cuda()
# For max length prediction
length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * batch_size)
text_for_pred = torch.cuda.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0)
image = image_tensors.cuda()
# For max length prediction
length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * batch_size)
text_for_pred = torch.cuda.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0)

text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)
text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)

start_time = time.time()
if 'CTC' in opt.Prediction:
Expand Down Expand Up @@ -170,22 +166,23 @@ def test(opt):

""" evaluation """
model.eval()
if opt.benchmark_all_eval: # evaluation with 10 benchmark evaluation datasets
benchmark_all_eval(model, criterion, converter, opt)
else:
AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
eval_data = hierarchical_dataset(root=opt.eval_data, opt=opt)
evaluation_loader = torch.utils.data.DataLoader(
eval_data, batch_size=opt.batch_size,
shuffle=False,
num_workers=int(opt.workers),
collate_fn=AlignCollate_evaluation, pin_memory=True)
_, accuracy_by_best_model, _, _, _, _, _ = validation(
model, criterion, evaluation_loader, converter, opt)

print(accuracy_by_best_model)
with open('./result/{0}/log_evaluation.txt'.format(opt.experiment_name), 'a') as log:
log.write(str(accuracy_by_best_model) + '\n')
with torch.no_grad():
if opt.benchmark_all_eval: # evaluation with 10 benchmark evaluation datasets
benchmark_all_eval(model, criterion, converter, opt)
else:
AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
eval_data = hierarchical_dataset(root=opt.eval_data, opt=opt)
evaluation_loader = torch.utils.data.DataLoader(
eval_data, batch_size=opt.batch_size,
shuffle=False,
num_workers=int(opt.workers),
collate_fn=AlignCollate_evaluation, pin_memory=True)
_, accuracy_by_best_model, _, _, _, _, _ = validation(
model, criterion, evaluation_loader, converter, opt)

print(accuracy_by_best_model)
with open('./result/{0}/log_evaluation.txt'.format(opt.experiment_name), 'a') as log:
log.write(str(accuracy_by_best_model) + '\n')


if __name__ == '__main__':
Expand Down
8 changes: 3 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,6 @@ def train(opt):

while(True):
# train part
for p in model.parameters():
p.requires_grad = True

image_tensors, labels = train_dataset.get_batch()
image = image_tensors.cuda()
text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
Expand Down Expand Up @@ -156,8 +153,9 @@ def train(opt):
loss_avg.reset()

model.eval()
valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
model, criterion, valid_loader, converter, opt)
with torch.no_grad():
valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
model, criterion, valid_loader, converter, opt)
model.train()

for pred, gt in zip(preds[:5], labels[:5]):
Expand Down

0 comments on commit 7498f48

Please sign in to comment.