From 98f67e90d57cb03e367bd6c9b66f08f8d83a237c Mon Sep 17 00:00:00 2001 From: Sam Gross Date: Wed, 19 Oct 2016 13:21:03 -0400 Subject: [PATCH] Fix super call in Container.modules and Container.parameters (#142) --- test/test_nn.py | 19 ++++++++++++++++--- torch/nn/modules/container.py | 6 ++++-- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index f2eed1e16e2c77..45c4b395abe762 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -285,12 +285,25 @@ def __init__(self): l1=l, l2=l ) + self.param = Variable(torch.Tensor(3, 5)) l = nn.Linear(10, 20) n = Net() - s = nn.Sequential(l, l, l, l) + s = nn.Sequential(n, n, n, n) self.assertEqual(num_params(l), 2) - self.assertEqual(num_params(n), 2) - self.assertEqual(num_params(s), 2) + self.assertEqual(num_params(n), 3) + self.assertEqual(num_params(s), 3) + + def test_modules(self): + class Net(nn.Container): + def __init__(self): + super(Net, self).__init__() + self.l1 = l + self.l2 = l + self.param = Variable(torch.Tensor(3, 5)) + l = nn.Linear(10, 20) + n = Net() + s = nn.Sequential(n, n, n, n) + self.assertEqual(list(s.modules()), [s, n, l]) def test_Sequential_getitem(self): l1 = nn.Linear(10, 20) diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index 0f3067d21de918..4c8c990bdd83aa 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -109,7 +109,8 @@ def load_parameter_dict(self, param_dict): def parameters(self, memo=None): if memo is None: memo = set() - super(Container, self).parameters(memo) + for p in super(Container, self).parameters(memo): + yield p for module in self.children(): for p in module.parameters(memo): yield p @@ -125,7 +126,8 @@ def modules(self, memo=None): if memo is None: memo = set() if self not in memo: - super(Container, self).modules(memo) + for m in super(Container, self).modules(memo): + yield m for module in self.children(): for m in module.modules(memo): yield m