@@ -251,7 +251,7 @@ differentiation. This is something the PyTorch team is working on, but it is
 not available yet. As such, we have to also implement the backward pass of our
 LLTM, which computes the derivative of the loss with respect to each input of
 the forward pass. Ultimately, we will plop both the forward and backward
-function into a :class:`torch.nn.Function` to create a nice Python binding. The
+function into a :class:`torch.autograd.Function` to create a nice Python binding. The
 backward function is slightly more involved, so we'll not dig deeper into the
 code (if you are interested, `Alex Graves' thesis
 <http://www.cs.toronto.edu/~graves/phd.pdf>`_ is a good read for more
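(For orientation, the pattern this hunk refers to pairs a forward pass with a
hand-written backward inside a single :class:`torch.autograd.Function`. A
minimal, self-contained sketch of that pattern (a toy ``MySquare`` function of
our own invention, not the LLTM code itself) might look like this::

    import torch

    class MySquare(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # Save whatever backward will need to compute gradients.
            ctx.save_for_backward(x)
            return x * x

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            # Chain rule: d(x**2)/dx = 2 * x.
            return grad_output * 2 * x

    x = torch.randn(3, requires_grad=True)
    MySquare.apply(x).sum().backward()
    print(x.grad)  # equals 2 * x

Calling ``MySquare.apply`` rather than ``MySquare.forward`` is what lets
autograd record the operation and route gradients through ``backward``.)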
@@ -415,7 +415,7 @@ matches our C++ code::
 LLTM forward

 Since we are now able to call our C++ functions from Python, we can wrap them
-with :class:`torch.nn.Function` and :class:`torch.nn.Module` to make them first
+with :class:`torch.autograd.Function` and :class:`torch.nn.Module` to make them first
 class citizens of PyTorch::

     import math
@@ -424,7 +424,7 @@ class citizens of PyTorch::
     # Our module!
     import lltm

-    class LLTMFunction(torch.nn.Function):
+    class LLTMFunction(torch.autograd.Function):
         @staticmethod
         def forward(ctx, input, weights, bias, old_h, old_cell):
             outputs = lltm.forward(input, weights, bias, old_h, old_cell)
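(The hunk cuts off inside ``LLTMFunction.forward``. For completeness, the
:class:`torch.nn.Module` wrapper that makes the function a first-class citizen
typically looks something like the sketch below; the parameter shapes and the
``reset_parameters`` scheme are illustrative assumptions, not necessarily the
tutorial's exact code::

    import math
    import torch

    class LLTM(torch.nn.Module):
        def __init__(self, input_features, state_size):
            super().__init__()
            self.state_size = state_size
            # Assumed layout: one fused weight matrix covering the gates,
            # mirroring typical LSTM variants.
            self.weights = torch.nn.Parameter(
                torch.empty(3 * state_size, input_features + state_size))
            self.bias = torch.nn.Parameter(torch.empty(3 * state_size))
            self.reset_parameters()

        def reset_parameters(self):
            stdv = 1.0 / math.sqrt(self.state_size)
            for weight in self.parameters():
                weight.data.uniform_(-stdv, +stdv)

        def forward(self, input, state):
            old_h, old_cell = state
            # Go through .apply() of the LLTMFunction defined in the diff
            # above, so autograd records the custom op.
            return LLTMFunction.apply(
                input, self.weights, self.bias, old_h, old_cell)
)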