Refactor attribute names in autograd
apaszke authored and soumith committed May 1, 2017
1 parent 2197e4c commit 2ca787f
Showing 33 changed files with 645 additions and 593 deletions.
4 changes: 3 additions & 1 deletion setup.py
@@ -272,14 +272,16 @@ def run(self):
"torch/csrc/autograd/engine.cpp",
"torch/csrc/autograd/function.cpp",
"torch/csrc/autograd/variable.cpp",
"torch/csrc/autograd/grad_buffer.cpp",
"torch/csrc/autograd/input_buffer.cpp",
"torch/csrc/autograd/python_function.cpp",
"torch/csrc/autograd/python_cpp_function.cpp",
"torch/csrc/autograd/python_variable.cpp",
"torch/csrc/autograd/python_engine.cpp",
"torch/csrc/autograd/python_hook.cpp",
"torch/csrc/autograd/functions/batch_normalization.cpp",
"torch/csrc/autograd/functions/convolution.cpp",
"torch/csrc/autograd/functions/basic_ops.cpp",
"torch/csrc/autograd/functions/utils.cpp",
"torch/csrc/autograd/functions/init.cpp",
"torch/csrc/nn/THNN_generic.cpp",
]
2 changes: 1 addition & 1 deletion test/common.py
@@ -77,7 +77,7 @@ def to_gpu(obj, type_map={}):
elif torch.is_storage(obj):
return obj.new().resize_(obj.size()).copy_(obj)
elif isinstance(obj, Variable):
-assert obj.creator is None
+assert obj.is_leaf
t = type_map.get(type(obj.data), get_gpu_type(type(obj.data)))
return Variable(obj.data.clone().type(t), requires_grad=obj.requires_grad)
elif isinstance(obj, list):
36 changes: 18 additions & 18 deletions test/test_autograd.py
@@ -270,15 +270,15 @@ def test_volatile(self):
z = x ** 2
self.assertFalse(z.volatile)
self.assertTrue(z.requires_grad)
-self.assertIsNotNone(z.creator)
+self.assertIsNotNone(z.grad_fn)
z.backward(torch.ones(5, 5))
self.assertEqual(x.grad.data, torch.ones(5, 5) * 2)

w = z + y
self.assertTrue(w.volatile)
self.assertFalse(w.requires_grad)
self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
-self.assertIsNone(w.creator)
+self.assertIsNone(w.grad_fn)

def test_indexing(self):
x = torch.arange(1, 17).resize_(4, 4)
@@ -376,23 +376,23 @@ def test_backward_no_grad(self):
with self.assertRaises(RuntimeError):
torch.autograd.backward([b], [None])

-def test_previous_functions(self):
+def test_next_functions(self):
x = Variable(torch.randn(5, 5), requires_grad=True)
y = Variable(torch.randn(5, 5), requires_grad=True)

a = x + y
-self.assertIsNotNone(a.creator)
-previous_functions = a.creator.previous_functions
-self.assertEqual(len(previous_functions), 2)
-self.assertIs(previous_functions[0][0], x)
-self.assertEqual(previous_functions[0][1], 0)
-self.assertIs(previous_functions[1][0], y)
-self.assertEqual(previous_functions[1][1], 0)
+self.assertIsNotNone(a.grad_fn)
+next_functions = a.grad_fn.next_functions
+self.assertEqual(len(next_functions), 2)
+self.assertIs(next_functions[0][0], x)
+self.assertEqual(next_functions[0][1], 0)
+self.assertIs(next_functions[1][0], y)
+self.assertEqual(next_functions[1][1], 0)

b = a + 5
-previous_functions = b.creator.previous_functions
-self.assertEqual(len(previous_functions), 1)
-self.assertIs(previous_functions[0][0], a.creator)
+next_functions = b.grad_fn.next_functions
+self.assertEqual(len(next_functions), 1)
+self.assertIs(next_functions[0][0], a.grad_fn)

def test_inplace(self):
x = Variable(torch.ones(5, 5), requires_grad=True)
@@ -543,7 +543,7 @@ def __del__(self):
gc.collect()

for i in range(10):
-Variable(torch.randn(10, 10), creator=CollectOnDelete())
+Variable(torch.randn(10, 10), grad_fn=CollectOnDelete())

@unittest.skipIf(not torch.cuda.is_available() or torch.cuda.device_count() < 2,
"CUDA not available or <2 GPUs detected")
@@ -567,7 +567,7 @@ def test_detach(self):
y = x * 2
y = y.detach()
self.assertFalse(y.requires_grad)
-self.assertIsNone(y.creator)
+self.assertIsNone(y.grad_fn)
z = x + y
z.sum().backward()
# This is an incorrect gradient, but we assume that's what the user
@@ -669,7 +669,7 @@ def backward(self, grad_a, grad_b):
fn = Inplace(True)
q, p = fn(x, y)
self.assertIs(q, x)
-self.assertIs(q.creator, fn)
+self.assertIs(q.grad_fn, fn)
self.assertTrue(q.requires_grad)
q.sum().backward()
self.assertEqual(y.grad.data, torch.ones(5, 5))
@@ -682,7 +682,7 @@ def test_leaf_assignment(self):
x[0] = y
x[1] = 2 * z
self.assertTrue(x.requires_grad)
-self.assertIsNot(x.creator, None)
+self.assertIsNot(x.grad_fn, None)
x.sum().backward()
self.assertEqual(y.grad.data, torch.ones(5))
self.assertEqual(z.grad.data, torch.ones(5) * 2)
@@ -1293,7 +1293,7 @@ def unpack_variables(args):
def do_test(self, cls=cls, constructor_args=new_constructor_args,
call_args=call_args, test_name=test_name):
input = create_input(call_args)
-self.assertEqual(gradcheck(cls(*constructor_args), input, eps=1e-6, atol=PRECISION), True)
+self.assertTrue(gradcheck(lambda *input: cls(*constructor_args)(*input), input, eps=1e-6, atol=PRECISION))

if test_name not in ignore_inplace and issubclass(cls, InplaceFunction):
output = cls(*constructor_args)(*input)
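Taken together, the renamed tests exercise the new graph-introspection names: `creator` becomes `grad_fn`, and `previous_functions` becomes `next_functions` (the edges gradients flow along during backward). A minimal sketch of the renamed API, assuming a build of PyTorch that includes this commit:

    import torch
    from torch.autograd import Variable

    x = Variable(torch.randn(5, 5), requires_grad=True)
    y = Variable(torch.randn(5, 5), requires_grad=True)

    a = x + y
    assert a.grad_fn is not None              # formerly a.creator
    assert x.is_leaf and x.grad_fn is None    # user-created Variables are leaves

    # next_functions (formerly previous_functions) holds (node, output_nr)
    # pairs, one per input of the function that produced `a`.
    for node, output_nr in a.grad_fn.next_functions:
        print(node, output_nr)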
5 changes: 4 additions & 1 deletion torch/autograd/__init__.py
@@ -38,7 +38,10 @@ def backward(variables, grad_variables, retain_variables=False):
specify ``True`` if you want to differentiate some subgraph multiple
times.
"""
+grad_variables = tuple(var if isinstance(var, Variable) or var is None
+                       else Variable(var, volatile=True)
+                       for var in grad_variables)
Variable._execution_engine.run_backward(
-    tuple(variables), tuple(grad_variables), retain_variables)
+    tuple(variables), grad_variables, retain_variables)

assert torch._C._autograd_init()
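The added wrapping means `torch.autograd.backward` now accepts raw tensors (and `None`) in `grad_variables`, converting tensors to volatile Variables internally. A usage sketch under that assumption:

    import torch
    from torch.autograd import Variable

    x = Variable(torch.randn(5, 5), requires_grad=True)
    y = x * 2
    # A plain Tensor is accepted here; it is wrapped as Variable(t, volatile=True).
    torch.autograd.backward([y], [torch.ones(5, 5)])
    print(x.grad.data)  # filled with 2s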
85 changes: 0 additions & 85 deletions torch/autograd/engine.py

This file was deleted.

4 changes: 0 additions & 4 deletions torch/autograd/function.py
@@ -36,10 +36,6 @@ class Function(_C._FunctionBase):
num_outputs: Number of tensors returned by :func:`forward`.
requires_grad: Boolean indicating whether the :func:`backward` will
ever need to be called.
-previous_functions: Tuple of (int, Function) pairs of length
-    :attr:`num_inputs`. Each entry contains a reference to a
-    :class:`Function` that created corresponding input, and an index
-    of the previous function output that's been used.
"""
__call__ = _C._FunctionBase._do_forward

2 changes: 1 addition & 1 deletion torch/autograd/gradcheck.py
@@ -140,8 +140,8 @@ def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3):
def fn(input):
return _as_tuple(func(*input))[i].data

-numerical = get_numerical_jacobian(fn, inputs, inputs, eps)
analytical = get_analytical_jacobian(_as_tuple(inputs), o)
+numerical = get_numerical_jacobian(fn, inputs, inputs, eps)

for a, n in zip(analytical, numerical):
if not ((a - n).abs() <= (atol + rtol * n.abs())).all():
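Note the reordering above: the analytical Jacobian is now computed before the numerical one. The companion change in test_autograd.py also passes a fresh callable into `gradcheck` instead of a pre-built `Function` instance, presumably because legacy Function objects carry per-call state and are not safely reusable. A hedged sketch of calling it directly:

    import torch
    from torch.autograd import Variable
    from torch.autograd.gradcheck import gradcheck

    inputs = (Variable(torch.randn(4, 4).double(), requires_grad=True),)
    # The callable rebuilds the graph on every evaluation.
    assert gradcheck(lambda x: x * 2, inputs, eps=1e-6, atol=1e-4)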
29 changes: 17 additions & 12 deletions torch/autograd/variable.py
@@ -13,7 +13,7 @@ class Variable(_C._VariableBase):
Variable is a thin wrapper around a Tensor object, that also holds
the gradient w.r.t. to it, and a reference to a function that created it.
This reference allows retracing the whole chain of operations that
-created the data. If the Variable has been created by the user, its creator
+created the data. If the Variable has been created by the user, its grad_fn
will be ``None`` and we call such objects *leaf* Variables.
Since autograd only supports scalar valued function differentiation, grad
@@ -33,8 +33,9 @@ class Variable(_C._VariableBase):
inference mode, i.e. don't save the history. See
:ref:`excluding-subgraphs` for more details.
Can be changed only on leaf Variables.
-creator: Function of which the variable was an output. For leaf
-    (user created) variables it's ``None``. Read-only attribute.
+is_leaf: Boolean indicating if the Variable is a graph leaf (i.e
+    if it was created by the user).
+grad_fn: Gradient function graph trace.
Parameters:
data (any tensor class): Tensor to wrap.
@@ -82,7 +83,7 @@ def __setitem__(self, key, value):
return SetItem(key, value)(self)

def __deepcopy__(self, memo):
-if self.creator is not None:
+if not self.is_leaf:
raise RuntimeError("Only Variables created explicitly by the user "
"(graph leaves) support the deepcopy protocol at the moment")
result = type(self)(self.data.clone())
@@ -106,7 +107,7 @@ def __setstate__(self, state):
# legacy serialization of Variable
self.data = state[0]
state = (state[3], state[4], state[2])
-if self.creator is not None:
+if not self.is_leaf:
raise RuntimeError('__setstate__ can be only called on leaf variables')
self.requires_grad, self.volatile, self._backward_hooks = state

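With `is_leaf` replacing the `creator is None` test, both `__deepcopy__` and the legacy `__setstate__` path keep their leaf-only restriction. The observable behavior, assuming this commit's API:

    import copy
    import torch
    from torch.autograd import Variable

    leaf = Variable(torch.ones(3), requires_grad=True)
    copy.deepcopy(leaf)          # allowed: user-created (leaf) Variable

    non_leaf = leaf * 2
    try:
        copy.deepcopy(non_leaf)  # raises: only graph leaves support deepcopy
    except RuntimeError as e:
        print(e)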
@@ -143,6 +144,10 @@ def backward(self, gradient=None, retain_variables=False):
'backward should be called only on a scalar (i.e. 1-element tensor) '
'or with gradient w.r.t. the variable')
gradient = self.data.new().resize_as_(self.data).fill_(1)
+if not isinstance(gradient, Variable):
+    if gradient is not None and not torch.is_tensor(gradient):
+        raise TypeError("gradient has to be a Tensor, Variable or None")
+    gradient = Variable(gradient, volatile=True)
self._execution_engine.run_backward((self,), (gradient,), retain_variables)

def register_hook(self, hook):
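The lines added to `backward` normalize its `gradient` argument: a raw Tensor is wrapped in a volatile Variable, and anything other than a Tensor, Variable, or None is rejected. For example, under that assumption:

    import torch
    from torch.autograd import Variable

    x = Variable(torch.randn(2, 3), requires_grad=True)
    y = x * 3
    y.backward(torch.ones(2, 3))  # Tensor gradient, wrapped internally
    print(x.grad.data)            # filled with 3s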
@@ -177,8 +182,8 @@ def register_hook(self, hook):
"doesn't require gradient")
if self._backward_hooks is None:
self._backward_hooks = OrderedDict()
-if self.creator is not None:
-    self.creator._register_hook_dict(self)
+if self.grad_fn is not None:
+    self.grad_fn._register_hook_dict(self)
handle = hooks.RemovableHandle(self._backward_hooks)
self._backward_hooks[handle.id] = hook
return handle
@@ -194,10 +199,10 @@ def reinforce(self, reward):
reward(Tensor): Tensor with per-element rewards. It has to match
the device location and shape of Variable's data.
"""
-if not isinstance(self.creator, StochasticFunction):
+if not isinstance(self.grad_fn, StochasticFunction):
    raise RuntimeError("reinforce() can be only called on outputs "
                       "of stochastic functions")
-self.creator._reinforce(reward)
+self.grad_fn._reinforce(reward)

def detach(self):
"""Returns a new Variable, detached from the current graph.
@@ -212,12 +217,12 @@ def detach(self):
errors in correctness checks.
"""
result = NoGrad()(self) # this is needed, because it merges version counters
-result._creator = None
+result._grad_fn = None
return result

def detach_(self):
"""Detaches the Variable from the graph that created it, making it a leaf."""
-self._creator = None
+self._grad_fn = None
self.requires_grad = False

def contiguous(self):
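Both detach variants now clear `_grad_fn` rather than `_creator`, so the result behaves as a fresh leaf. A short sketch, assuming this commit's API:

    import torch
    from torch.autograd import Variable

    x = Variable(torch.ones(4), requires_grad=True)
    y = (x * 2).detach()
    assert y.grad_fn is None and not y.requires_grad

    z = x * 2
    z.detach_()          # in-place variant: z becomes a leaf
    assert z.is_leaf and not z.requires_grad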
@@ -895,5 +900,5 @@ def addr(cls, *args):
setattr(Variable._torch, method, as_static)


-from .engine import ImperativeEngine
+from torch._C import _ImperativeEngine as ImperativeEngine
Variable._execution_engine = ImperativeEngine()