@@ -111,28 +111,28 @@ def readLines(filename):
# (language) to a list of lines (names). We also kept track of
# ``all_categories`` (just a list of languages) and ``n_categories`` for
# later reference.
-#
+#

print(category_lines['Italian'][:5])


######################################################################
# Turning Names into Tensors
# --------------------------
-#
+#
# Now that we have all the names organized, we need to turn them into
# Tensors to make any use of them.
-#
+#
# To represent a single letter, we use a "one-hot vector" of size
# ``<1 x n_letters>``. A one-hot vector is filled with 0s except for a 1
# at index of the current letter, e.g. ``"b" = <0 1 0 0 0 ...>``.
-#
+#
# To make a word we join a bunch of those into a 2D matrix
# ``<line_length x 1 x n_letters>``.
-#
+#
# That extra 1 dimension is because PyTorch assumes everything is in
# batches - we're just using a batch size of 1 here.
-#
+#

import torch
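For reference, a minimal sketch of the one-hot helpers this comment block describes (the next hunk's context refers to ``lineToTensor``); it assumes the ``all_letters`` and ``n_letters`` definitions from the data-loading section earlier in the tutorial:

    # Find letter index from all_letters, e.g. "a" = 0
    def letterToIndex(letter):
        return all_letters.find(letter)

    # Turn a letter into a <1 x n_letters> one-hot tensor
    def letterToTensor(letter):
        tensor = torch.zeros(1, n_letters)
        tensor[0][letterToIndex(letter)] = 1
        return tensor

    # Turn a line into a <line_length x 1 x n_letters> tensor,
    # one one-hot row per letter
    def lineToTensor(line):
        tensor = torch.zeros(len(line), 1, n_letters)
        for li, letter in enumerate(line):
            tensor[li][0][letterToIndex(letter)] = 1
        return tensor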
@@ -162,36 +162,36 @@ def lineToTensor(line):
######################################################################
# Creating the Network
# ====================
-#
+#
# Before autograd, creating a recurrent neural network in Torch involved
# cloning the parameters of a layer over several timesteps. The layers
# held hidden state and gradients which are now entirely handled by the
# graph itself. This means you can implement a RNN in a very "pure" way,
# as regular feed-forward layers.
-#
+#
# This RNN module (mostly copied from `the PyTorch for Torch users
# tutorial <https://github.com/pytorch/tutorials/blob/master/Introduction%20to%20PyTorch%20for%20former%20Torchies.ipynb>`__)
# is just 2 linear layers which operate on an input and hidden state, with
# a LogSoftmax layer after the output.
-#
+#
# .. figure:: https://i.imgur.com/Z2xbySO.png
-#    :alt:
-#
-#
+#    :alt:
+#
+#

import torch.nn as nn
from torch.autograd import Variable

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
-
+
        self.hidden_size = hidden_size
-
+
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
-        self.softmax = nn.LogSoftmax()
-
+        self.softmax = nn.LogSoftmax(dim=1)
+
    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
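The one functional change in this hunk is ``nn.LogSoftmax(dim=1)``. The ``i2o`` output has shape ``<1 x n_categories>`` (a batch of one), so the log-softmax should normalize over dimension 1, and recent PyTorch releases warn when the dimension is left implicit. A quick sanity check of that behaviour, assuming nothing beyond ``torch``:

    import torch
    import torch.nn as nn

    scores = torch.randn(1, 18)               # stand-in <1 x n_categories> output
    log_probs = nn.LogSoftmax(dim=1)(scores)
    print(log_probs.exp().sum())               # probabilities along dim 1 sum to 1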
@@ -212,10 +212,10 @@ def initHidden(self):
# initialize as zeros at first). We'll get back the output (probability of
# each language) and a next hidden state (which we keep for the next
# step).
-#
+#
# Remember that PyTorch modules operate on Variables rather than straight
# up Tensors.
-#
+#

input = Variable(letterToTensor('A'))
hidden = Variable(torch.zeros(1, n_hidden))
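A single forward step as those comments describe, presumably the next (elided) line of the tutorial:

    output, next_hidden = rnn(input, hidden)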
@@ -228,7 +228,7 @@ def initHidden(self):
# every step, so we will use ``lineToTensor`` instead of
# ``letterToTensor`` and use slices. This could be further optimized by
# pre-computing batches of Tensors.
-#
+#

input = Variable(lineToTensor('Albert'))
hidden = Variable(torch.zeros(1, n_hidden))
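Here the first letter of the line is fed in as a slice of the ``<line_length x 1 x n_letters>`` tensor; a sketch of the likely follow-up call:

    output, next_hidden = rnn(input[0], hidden)
    print(output)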
@@ -240,21 +240,21 @@ def initHidden(self):
######################################################################
# As you can see the output is a ``<1 x n_categories>`` Tensor, where
# every item is the likelihood of that category (higher is more likely).
-#
+#


######################################################################
-#
+#
# Training
# ========
# Preparing for Training
# ----------------------
-#
+#
# Before going into training we should make a few helper functions. The
# first is to interpret the output of the network, which we know to be a
# likelihood of each category. We can use ``Tensor.topk`` to get the index
# of the greatest value:
-#
+#

def categoryFromOutput(output):
    top_n, top_i = output.data.topk(1) # Tensor out of Variable with .data
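The hunk cuts ``categoryFromOutput`` off after the ``topk`` call; a plausible completion, assuming ``all_categories`` from the data-loading step, maps the winning index back to a language name:

    def categoryFromOutput(output):
        top_n, top_i = output.data.topk(1)  # Tensor out of Variable with .data
        category_i = top_i[0][0]            # index of the highest-scoring category
        return all_categories[category_i], category_i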
@@ -267,7 +267,7 @@ def categoryFromOutput(output):
######################################################################
# We will also want a quick way to get a training example (a name and its
# language):
-#
+#

import random
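A sketch of the helper that comment describes, assuming ``all_categories``, ``category_lines`` and ``lineToTensor`` from the earlier sections:

    def randomChoice(l):
        return l[random.randint(0, len(l) - 1)]

    def randomTrainingExample():
        category = randomChoice(all_categories)
        line = randomChoice(category_lines[category])
        category_tensor = Variable(torch.LongTensor([all_categories.index(category)]))
        line_tensor = Variable(lineToTensor(line))
        return category, line, category_tensor, line_tensor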
@@ -289,30 +289,30 @@ def randomTrainingExample():
######################################################################
# Training the Network
# --------------------
-#
+#
# Now all it takes to train this network is show it a bunch of examples,
# have it make guesses, and tell it if it's wrong.
-#
+#
# For the loss function ``nn.NLLLoss`` is appropriate, since the last
# layer of the RNN is ``nn.LogSoftmax``.
-#
+#

criterion = nn.NLLLoss()


######################################################################
# Each loop of training will:
-#
+#
# - Create input and target tensors
# - Create a zeroed initial hidden state
# - Read each letter in and
-#
+#
#    - Keep hidden state for next letter
-#
+#
# - Compare final output to target
# - Back-propagate
# - Return the output and loss
-#
+#

learning_rate = 0.005 # If you set this too high, it might explode. If too low, it might not learn
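The ``train()`` function itself falls outside this hunk; a sketch that follows the steps listed above, using manual SGD with the ``learning_rate`` defined here and the Variable-era conventions the rest of the file uses:

    def train(category_tensor, line_tensor):
        hidden = rnn.initHidden()
        rnn.zero_grad()

        # Read each letter in, keeping the hidden state for the next letter
        for i in range(line_tensor.size()[0]):
            output, hidden = rnn(line_tensor[i], hidden)

        # Compare the final output to the target and back-propagate
        loss = criterion(output, category_tensor)
        loss.backward()

        # Add each parameter's gradient, scaled by -learning_rate, to its value
        for p in rnn.parameters():
            p.data.add_(-learning_rate, p.grad.data)

        return output, loss.data[0]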
@@ -340,7 +340,7 @@ def train(category_tensor, line_tensor):
# guesses and also keep track of loss for plotting. Since there are 1000s
# of examples we print only every ``print_every`` examples, and take an
# average of the loss.
-#
+#

import time
import math
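The next hunk's context line mentions ``timeSince(since)``; a likely shape for that helper, built on the ``time`` and ``math`` imports above:

    def timeSince(since):
        s = time.time() - since
        m = math.floor(s / 60)
        s -= m * 60
        return '%dm %ds' % (m, s)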
@@ -384,10 +384,10 @@ def timeSince(since):
######################################################################
# Plotting the Results
# --------------------
-#
+#
# Plotting the historical loss from ``all_losses`` shows the network
# learning:
-#
+#

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
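Plotting ``all_losses`` is then only a couple of lines (a minimal sketch):

    plt.figure()
    plt.plot(all_losses)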
@@ -399,13 +399,13 @@ def timeSince(since):
######################################################################
# Evaluating the Results
# ======================
-#
+#
# To see how well the network performs on different categories, we will
# create a confusion matrix, indicating for every actual language (rows)
# which language the network guesses (columns). To calculate the confusion
# matrix a bunch of samples are run through the network with
# ``evaluate()``, which is the same as ``train()`` minus the backprop.
-#
+#

# Keep track of correct guesses in a confusion matrix
confusion = torch.zeros(n_categories, n_categories)
@@ -414,10 +414,10 @@ def timeSince(since):
# Just return an output given a line
def evaluate(line_tensor):
    hidden = rnn.initHidden()
-
+
    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)
-
+
    return output

# Go through a bunch of examples and record which are correctly guessed
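A sketch of the sampling loop those comments describe, plus a row normalization so each row reads as "given this true language, how often was each language guessed"; ``n_confusion`` is an illustrative sample count:

    n_confusion = 10000

    for i in range(n_confusion):
        category, line, category_tensor, line_tensor = randomTrainingExample()
        output = evaluate(line_tensor)
        guess, guess_i = categoryFromOutput(output)
        category_i = all_categories.index(category)
        confusion[category_i][guess_i] += 1

    # Normalize by dividing every row by its sum
    for i in range(n_categories):
        confusion[i] = confusion[i] / confusion[i].sum()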
@@ -455,13 +455,13 @@ def evaluate(line_tensor):
# languages it guesses incorrectly, e.g. Chinese for Korean, and Spanish
# for Italian. It seems to do very well with Greek, and very poorly with
# English (perhaps because of overlap with other languages).
-#
+#


######################################################################
# Running on User Input
# ---------------------
-#
+#

def predict(input_line, n_predictions=3):
    print('\n> %s' % input_line)
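Only the first line of ``predict()`` falls inside the hunk; a sketch of how it plausibly continues, reusing ``evaluate()``, ``lineToTensor()`` and ``all_categories``:

    def predict(input_line, n_predictions=3):
        print('\n> %s' % input_line)
        output = evaluate(Variable(lineToTensor(input_line)))

        # Top N categories with their log-probabilities
        topv, topi = output.data.topk(n_predictions, 1, True)
        for i in range(n_predictions):
            value = topv[0][i]
            category_index = topi[0][i]
            print('(%.2f) %s' % (value, all_categories[category_index]))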
@@ -486,43 +486,43 @@ def predict(input_line, n_predictions=3):
# The final versions of the scripts `in the Practical PyTorch
# repo <https://github.com/spro/practical-pytorch/tree/master/char-rnn-classification>`__
# split the above code into a few files:
-#
+#
# - ``data.py`` (loads files)
# - ``model.py`` (defines the RNN)
# - ``train.py`` (runs training)
# - ``predict.py`` (runs ``predict()`` with command line arguments)
# - ``server.py`` (serve prediction as a JSON API with bottle.py)
-#
+#
# Run ``train.py`` to train and save the network.
-#
+#
# Run ``predict.py`` with a name to view predictions:
-#
+#
# ::
-#
+#
#    $ python predict.py Hazaki
#    (-0.42) Japanese
#    (-1.39) Polish
#    (-3.51) Czech
-#
+#
# Run ``server.py`` and visit http://localhost:5533/Yourname to get JSON
# output of predictions.
-#
+#


######################################################################
# Exercises
# =========
-#
+#
# - Try with a different dataset of line -> category, for example:
-#
+#
#    - Any word -> language
#    - First name -> gender
#    - Character name -> writer
#    - Page title -> blog or subreddit
-#
+#
# - Get better results with a bigger and/or better shaped network
-#
+#
#    - Add more linear layers
#    - Try the ``nn.LSTM`` and ``nn.GRU`` layers
#    - Combine multiple of these RNNs as a higher level network
-#
+#
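For the ``server.py`` mentioned above, a hypothetical bottle.py endpoint matching the URL pattern http://localhost:5533/Yourname could look like the following; the route shape and port come from the text, everything else is an assumption, and ``predict()`` would need to return its predictions rather than print them:

    from bottle import route, run
    import json

    @route('/<input_line>')
    def serve_prediction(input_line):
        # assumes a predict() variant that returns [(score, language), ...]
        return json.dumps(predict(input_line))

    run(host='localhost', port=5533)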