
Commit 016d0cc

zou3519 authored and chsasank committed
Fix softmax warnings (pytorch#177)
1 parent 4330f79 commit 016d0cc

7 files changed (+113 -113 lines)
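The warnings in question: newer PyTorch versions emit a UserWarning when softmax, log_softmax, or nn.LogSoftmax are used without an explicit dim argument, because the implicit dimension choice is deprecated. This commit passes dim explicitly at every call site in the tutorials. A minimal sketch of the pattern (hypothetical tensors, not taken from the tutorial files):

    import torch
    import torch.nn.functional as F

    # A single vector of scores: normalize over its only dimension.
    scores = torch.randn(5)
    probs = F.softmax(scores, dim=0)
    print(probs.sum())  # ~1.0, a proper probability distribution

    # A batch of scores shaped (batch, classes): normalize over the class dimension.
    batch_scores = torch.randn(3, 10)
    log_probs = F.log_softmax(batch_scores, dim=1)

Passing dim silences the warning and makes the reduction axis explicit, which matters once inputs gain a batch dimension.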

beginner_source/nlp/deep_learning_tutorial.py (+4 -4)

@@ -123,9 +123,9 @@
 # Softmax is also in torch.nn.functional
 data = autograd.Variable(torch.randn(5))
 print(data)
-print(F.softmax(data))
-print(F.softmax(data).sum())  # Sums to 1 because it is a distribution!
-print(F.log_softmax(data))  # theres also log_softmax
+print(F.softmax(data, dim=0))
+print(F.softmax(data, dim=0).sum())  # Sums to 1 because it is a distribution!
+print(F.log_softmax(data, dim=0))  # theres also log_softmax
 
 
 ######################################################################
@@ -277,7 +277,7 @@ def forward(self, bow_vec):
         # Pass the input through the linear layer,
         # then pass that through log_softmax.
         # Many non-linearities and other functions are in torch.nn.functional
-        return F.log_softmax(self.linear(bow_vec))
+        return F.log_softmax(self.linear(bow_vec), dim=1)
 
 
 def make_bow_vector(sentence, word_to_ix):

beginner_source/nlp/sequence_models_tutorial.py (+1 -1)

@@ -180,7 +180,7 @@ def forward(self, sentence):
         lstm_out, self.hidden = self.lstm(
             embeds.view(len(sentence), 1, -1), self.hidden)
         tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
-        tag_scores = F.log_softmax(tag_space)
+        tag_scores = F.log_softmax(tag_space, dim=1)
         return tag_scores
 
 ######################################################################

beginner_source/nlp/word_embeddings_tutorial.py (+1 -1)

@@ -230,7 +230,7 @@ def forward(self, inputs):
         embeds = self.embeddings(inputs).view((1, -1))
         out = F.relu(self.linear1(embeds))
         out = self.linear2(out)
-        log_probs = F.log_softmax(out)
+        log_probs = F.log_softmax(out, dim=1)
         return log_probs
 
 
intermediate_source/char_rnn_classification_tutorial.py (+52 -52)
(Besides the nn.LogSoftmax change, the paired -/+ lines in this file appear to differ only in trailing whitespace.)
@@ -111,28 +111,28 @@ def readLines(filename):
 # (language) to a list of lines (names). We also kept track of
 # ``all_categories`` (just a list of languages) and ``n_categories`` for
 # later reference.
-# 
+#
 
 print(category_lines['Italian'][:5])
 
 
 ######################################################################
 # Turning Names into Tensors
 # --------------------------
-# 
+#
 # Now that we have all the names organized, we need to turn them into
 # Tensors to make any use of them.
-# 
+#
 # To represent a single letter, we use a "one-hot vector" of size
 # ``<1 x n_letters>``. A one-hot vector is filled with 0s except for a 1
 # at index of the current letter, e.g. ``"b" = <0 1 0 0 0 ...>``.
-# 
+#
 # To make a word we join a bunch of those into a 2D matrix
 # ``<line_length x 1 x n_letters>``.
-# 
+#
 # That extra 1 dimension is because PyTorch assumes everything is in
 # batches - we're just using a batch size of 1 here.
-# 
+#
 
 import torch
 
@@ -162,36 +162,36 @@ def lineToTensor(line):
 ######################################################################
 # Creating the Network
 # ====================
-# 
+#
 # Before autograd, creating a recurrent neural network in Torch involved
 # cloning the parameters of a layer over several timesteps. The layers
 # held hidden state and gradients which are now entirely handled by the
 # graph itself. This means you can implement a RNN in a very "pure" way,
 # as regular feed-forward layers.
-# 
+#
 # This RNN module (mostly copied from `the PyTorch for Torch users
 # tutorial <https://github.com/pytorch/tutorials/blob/master/Introduction%20to%20PyTorch%20for%20former%20Torchies.ipynb>`__)
 # is just 2 linear layers which operate on an input and hidden state, with
 # a LogSoftmax layer after the output.
-# 
+#
 # .. figure:: https://i.imgur.com/Z2xbySO.png
-#    :alt: 
-# 
-# 
+#    :alt:
+#
+#
 
 import torch.nn as nn
 from torch.autograd import Variable
 
 class RNN(nn.Module):
     def __init__(self, input_size, hidden_size, output_size):
         super(RNN, self).__init__()
-        
+
         self.hidden_size = hidden_size
-        
+
         self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
         self.i2o = nn.Linear(input_size + hidden_size, output_size)
-        self.softmax = nn.LogSoftmax()
-        
+        self.softmax = nn.LogSoftmax(dim=1)
+
     def forward(self, input, hidden):
         combined = torch.cat((input, hidden), 1)
         hidden = self.i2h(combined)
@@ -212,10 +212,10 @@ def initHidden(self):
 # initialize as zeros at first). We'll get back the output (probability of
 # each language) and a next hidden state (which we keep for the next
 # step).
-# 
+#
 # Remember that PyTorch modules operate on Variables rather than straight
 # up Tensors.
-# 
+#
 
 input = Variable(letterToTensor('A'))
 hidden = Variable(torch.zeros(1, n_hidden))
@@ -228,7 +228,7 @@ def initHidden(self):
 # every step, so we will use ``lineToTensor`` instead of
 # ``letterToTensor`` and use slices. This could be further optimized by
 # pre-computing batches of Tensors.
-# 
+#
 
 input = Variable(lineToTensor('Albert'))
 hidden = Variable(torch.zeros(1, n_hidden))
@@ -240,21 +240,21 @@ def initHidden(self):
 ######################################################################
 # As you can see the output is a ``<1 x n_categories>`` Tensor, where
 # every item is the likelihood of that category (higher is more likely).
-# 
+#
 
 
 ######################################################################
-# 
+#
 # Training
 # ========
 # Preparing for Training
 # ----------------------
-# 
+#
 # Before going into training we should make a few helper functions. The
 # first is to interpret the output of the network, which we know to be a
 # likelihood of each category. We can use ``Tensor.topk`` to get the index
 # of the greatest value:
-# 
+#
 
 def categoryFromOutput(output):
     top_n, top_i = output.data.topk(1) # Tensor out of Variable with .data
@@ -267,7 +267,7 @@ def categoryFromOutput(output):
 ######################################################################
 # We will also want a quick way to get a training example (a name and its
 # language):
-# 
+#
 
 import random
 
@@ -289,30 +289,30 @@ def randomTrainingExample():
 ######################################################################
 # Training the Network
 # --------------------
-# 
+#
 # Now all it takes to train this network is show it a bunch of examples,
 # have it make guesses, and tell it if it's wrong.
-# 
+#
 # For the loss function ``nn.NLLLoss`` is appropriate, since the last
 # layer of the RNN is ``nn.LogSoftmax``.
-# 
+#
 
 criterion = nn.NLLLoss()
 
 
 ######################################################################
 # Each loop of training will:
-# 
+#
 # -  Create input and target tensors
 # -  Create a zeroed initial hidden state
 # -  Read each letter in and
-# 
+#
 #    -  Keep hidden state for next letter
-# 
+#
 # -  Compare final output to target
 # -  Back-propagate
 # -  Return the output and loss
-# 
+#
 
 learning_rate = 0.005 # If you set this too high, it might explode. If too low, it might not learn
 
@@ -340,7 +340,7 @@ def train(category_tensor, line_tensor):
 # guesses and also keep track of loss for plotting. Since there are 1000s
 # of examples we print only every ``print_every`` examples, and take an
 # average of the loss.
-# 
+#
 
 import time
 import math
@@ -384,10 +384,10 @@ def timeSince(since):
 ######################################################################
 # Plotting the Results
 # --------------------
-# 
+#
 # Plotting the historical loss from ``all_losses`` shows the network
 # learning:
-# 
+#
 
 import matplotlib.pyplot as plt
 import matplotlib.ticker as ticker
@@ -399,13 +399,13 @@ def timeSince(since):
 ######################################################################
 # Evaluating the Results
 # ======================
-# 
+#
 # To see how well the network performs on different categories, we will
 # create a confusion matrix, indicating for every actual language (rows)
 # which language the network guesses (columns). To calculate the confusion
 # matrix a bunch of samples are run through the network with
 # ``evaluate()``, which is the same as ``train()`` minus the backprop.
-# 
+#
 
 # Keep track of correct guesses in a confusion matrix
 confusion = torch.zeros(n_categories, n_categories)
@@ -414,10 +414,10 @@ def timeSince(since):
 # Just return an output given a line
 def evaluate(line_tensor):
     hidden = rnn.initHidden()
-    
+
     for i in range(line_tensor.size()[0]):
         output, hidden = rnn(line_tensor[i], hidden)
-    
+
     return output
 
 # Go through a bunch of examples and record which are correctly guessed
@@ -455,13 +455,13 @@ def evaluate(line_tensor):
 # languages it guesses incorrectly, e.g. Chinese for Korean, and Spanish
 # for Italian. It seems to do very well with Greek, and very poorly with
 # English (perhaps because of overlap with other languages).
-# 
+#
 
 
 ######################################################################
 # Running on User Input
 # ---------------------
-# 
+#
 
 def predict(input_line, n_predictions=3):
     print('\n> %s' % input_line)
@@ -486,43 +486,43 @@ def predict(input_line, n_predictions=3):
 # The final versions of the scripts `in the Practical PyTorch
 # repo <https://github.com/spro/practical-pytorch/tree/master/char-rnn-classification>`__
 # split the above code into a few files:
-# 
+#
 # -  ``data.py`` (loads files)
 # -  ``model.py`` (defines the RNN)
 # -  ``train.py`` (runs training)
 # -  ``predict.py`` (runs ``predict()`` with command line arguments)
 # -  ``server.py`` (serve prediction as a JSON API with bottle.py)
-# 
+#
 # Run ``train.py`` to train and save the network.
-# 
+#
 # Run ``predict.py`` with a name to view predictions:
-# 
+#
 # ::
-# 
+#
 #     $ python predict.py Hazaki
 #     (-0.42) Japanese
 #     (-1.39) Polish
 #     (-3.51) Czech
-# 
+#
 # Run ``server.py`` and visit http://localhost:5533/Yourname to get JSON
 # output of predictions.
-# 
+#
 
 
 ######################################################################
 # Exercises
 # =========
-# 
+#
 # -  Try with a different dataset of line -> category, for example:
-# 
+#
 #    -  Any word -> language
 #    -  First name -> gender
 #    -  Character name -> writer
 #    -  Page title -> blog or subreddit
-# 
+#
 # -  Get better results with a bigger and/or better shaped network
-# 
+#
 #    -  Add more linear layers
 #    -  Try the ``nn.LSTM`` and ``nn.GRU`` layers
 #    -  Combine multiple of these RNNs as a higher level network
-# 
+#
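For reference, a minimal sketch (hypothetical sizes, not part of the commit) of the module-form fix used in the RNN above: nn.LogSoftmax also accepts an explicit dim, and dim=1 normalizes the <1 x n_categories> output across categories.

    import torch
    import torch.nn as nn

    n_categories = 18                       # hypothetical number of classes
    logits = torch.randn(1, n_categories)   # stand-in for the RNN's i2o output
    log_softmax = nn.LogSoftmax(dim=1)      # explicit dim, matching the fix above
    output = log_softmax(logits)
    print(output.exp().sum(dim=1))          # ~1.0 per row: normalized over categories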
