Skip to content

Commit

Permalink
ctc update
Browse files Browse the repository at this point in the history
  • Loading branch information
Baek JeongHun committed Oct 22, 2019
1 parent ec53192 commit 8749902
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
4 changes: 2 additions & 2 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def demo(opt):

# Select max probabilty (greedy decoding) then decode index to character
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
_, preds_index = preds.permute(1, 0, 2).max(2)
preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
_, preds_index = preds.max(2)
preds_index = preds_index.view(-1)
preds_str = converter.decode(preds_index.data, preds_size.data)

else:
Expand Down
2 changes: 1 addition & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def validation(model, criterion, evaluation_loader, converter, opt):

# Select max probabilty (greedy decoding) then decode index to character
_, preds_index = preds.max(2)
preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
preds_index = preds_index.view(-1)
preds_str = converter.decode(preds_index.data, preds_size.data)

else:
Expand Down
11 changes: 6 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,14 @@ def train(opt):
if 'CTC' in opt.Prediction:
preds = model(image, text).log_softmax(2)
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
preds = preds.permute(1, 0, 2) # to use CTCLoss format
# permute 'preds' to use CTCloss format
cost = criterion(preds.permute(1, 0, 2), text.to(device), preds_size.to(device), length.to(device)) # For PyTorch 1.3.0

# (ctc_a) To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
# (ctc_a) For PyTorch 1.2.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
# https://github.com/jpuigcerver/PyLaia/issues/16
torch.backends.cudnn.enabled = False
cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
torch.backends.cudnn.enabled = True
# torch.backends.cudnn.enabled = False
# cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
# torch.backends.cudnn.enabled = True

# # (ctc_b) To reproduce our pretrained model / paper, use our previous code (below code) instead of (ctc_a).
# # With PyTorch 1.2.0, the below code occurs NAN, so you may use PyTorch 1.1.0.
Expand Down

0 comments on commit 8749902

Please sign in to comment.