
Commit 8e534fc

Update with new Function API
1 parent 00e0058 commit 8e534fc

2 files changed: +23 -16 lines

advanced_source/numpy_extensions_tutorial.py (+11 -7)
@@ -86,17 +86,21 @@ def incorrect_fft(input):
 
 
 class ScipyConv2dFunction(Function):
-
-    def forward(self, input, filter):
+    @staticmethod
+    def forward(ctx, input, filter):
         result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
-        self.save_for_backward(input, filter)
+        ctx.save_for_backward(input, filter)
         return torch.FloatTensor(result)
 
-    def backward(self, grad_output):
-        input, filter = self.saved_tensors
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, filter = ctx.saved_tensors
+        grad_output = grad_output.data
         grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full')
         grad_filter = convolve2d(input.numpy(), grad_output.numpy(), mode='valid')
-        return torch.FloatTensor(grad_input), torch.FloatTensor(grad_filter)
+
+        return Variable(torch.FloatTensor(grad_input)), \
+               Variable(torch.FloatTensor(grad_filter))
 
 
 class ScipyConv2d(Module):
@@ -106,7 +110,7 @@ def __init__(self, kh, kw):
         self.filter = Parameter(torch.randn(kh, kw))
 
     def forward(self, input):
-        return ScipyConv2dFunction()(input, self.filter)
+        return ScipyConv2dFunction.apply(input, self.filter)
 
 ###############################################################
 # **Example usage:**
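For reference, the migrated ScipyConv2dFunction is no longer instantiated; the ScipyConv2d module now routes calls through ScipyConv2dFunction.apply. A minimal usage sketch, assuming the tutorial file's existing imports (torch, Variable, Parameter, Module, scipy.signal) and the classes above; the 10x10 input and 3x3 kernel sizes are illustrative choices, not dictated by this diff:

module = ScipyConv2d(3, 3)
print(list(module.parameters()))
input = Variable(torch.randn(10, 10), requires_grad=True)
output = module(input)               # forward() calls ScipyConv2dFunction.apply(input, self.filter)
print(output)
output.backward(torch.randn(8, 8))   # 'valid' correlation of a 10x10 input with a 3x3 filter gives an 8x8 output
print(input.grad)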

beginner_source/examples_autograd/two_layer_net_custom_function.py (+12 -9)
@@ -23,22 +23,25 @@ class MyReLU(torch.autograd.Function):
     which operate on Tensors.
     """
 
-    def forward(self, input):
+    @staticmethod
+    def forward(ctx, input):
         """
-        In the forward pass we receive a Tensor containing the input and return a
-        Tensor containing the output. You can cache arbitrary Tensors for use in the
-        backward pass using the save_for_backward method.
+        In the forward pass we receive a Tensor containing the input and return
+        a Tensor containing the output. ctx is a context object that can be used
+        to stash information for backward computation. You can cache arbitrary
+        objects for use in the backward pass using the ctx.save_for_backward method.
         """
-        self.save_for_backward(input)
+        ctx.save_for_backward(input)
         return input.clamp(min=0)
 
-    def backward(self, grad_output):
+    @staticmethod
+    def backward(ctx, grad_output):
         """
         In the backward pass we receive a Tensor containing the gradient of the loss
         with respect to the output, and we need to compute the gradient of the loss
         with respect to the input.
         """
-        input, = self.saved_tensors
+        input, = ctx.saved_tensors
         grad_input = grad_output.clone()
         grad_input[input < 0] = 0
         return grad_input
@@ -61,8 +64,8 @@ def backward(self, grad_output):
 
 learning_rate = 1e-6
 for t in range(500):
-    # Construct an instance of our MyReLU class to use in our network
-    relu = MyReLU()
+    # To apply our Function, we use Function.apply method. We alias this as 'relu'.
+    relu = MyReLU.apply
 
     # Forward pass: compute predicted y using operations on Variables; we compute
     # ReLU using our custom autograd operation.
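Because forward and backward are now static methods, MyReLU.apply is the callable used inside the training loop. A minimal sketch of the new call convention, assuming x, y, w1 and w2 are the Variables defined earlier in this file:

relu = MyReLU.apply                  # no MyReLU() instance is constructed
y_pred = relu(x.mm(w1)).mm(w2)       # forward pass through the custom autograd op
loss = (y_pred - y).pow(2).sum()     # same squared-error loss used in the tutorial
loss.backward()                      # autograd calls MyReLU.backward(ctx, grad_output)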
