Skip to content

Commit

Permalink
ctc fix..
Browse files Browse the repository at this point in the history
  • Loading branch information
Baek JeongHun committed Oct 22, 2019
1 parent 8749902 commit a19dc0f
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,13 @@ def train(opt):
if 'CTC' in opt.Prediction:
preds = model(image, text).log_softmax(2)
preds_size = torch.IntTensor([preds.size(1)] * batch_size)
# permute 'preds' to use CTCloss format
cost = criterion(preds.permute(1, 0, 2), text.to(device), preds_size.to(device), length.to(device)) # For PyTorch 1.3.0
preds = preds.permute(1, 0, 2)

# (ctc_a) For PyTorch 1.2.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
# (ctc_a) For PyTorch 1.2.0 and 1.3.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
# https://github.com/jpuigcerver/PyLaia/issues/16
# torch.backends.cudnn.enabled = False
# cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
# torch.backends.cudnn.enabled = True
torch.backends.cudnn.enabled = False
cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
torch.backends.cudnn.enabled = True

# # (ctc_b) To reproduce our pretrained model / paper, use our previous code (below code) instead of (ctc_a).
# # With PyTorch 1.2.0, the below code occurs NAN, so you may use PyTorch 1.1.0.
Expand Down

0 comments on commit a19dc0f

Please sign in to comment.