diff --git a/docs/source/Library.md b/docs/source/Library.md index 463726f8b3..da2504e027 100644 --- a/docs/source/Library.md +++ b/docs/source/Library.md @@ -60,7 +60,7 @@ model.to(device) # Specify the tgt word generator and loss computation module model.generator = nn.Sequential( nn.Linear(rnn_size, len(tgt_vocab)), - nn.LogSoftmax(dim=-1)) + nn.LogSoftmax(dim=-1)).to(device) loss = onmt.utils.loss.NMTLossCompute( criterion=nn.NLLLoss(ignore_index=tgt_padding, reduction="sum"),