Skip to content

Commit

Permalink
Fix super call in Container.modules and Container.parameters (pytorch…
Browse files Browse the repository at this point in the history
  • Loading branch information
colesbury authored Oct 19, 2016
1 parent fee67c2 commit 98f67e9
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
19 changes: 16 additions & 3 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,12 +285,25 @@ def __init__(self):
l1=l,
l2=l
)
self.param = Variable(torch.Tensor(3, 5))
l = nn.Linear(10, 20)
n = Net()
s = nn.Sequential(l, l, l, l)
s = nn.Sequential(n, n, n, n)
self.assertEqual(num_params(l), 2)
self.assertEqual(num_params(n), 2)
self.assertEqual(num_params(s), 2)
self.assertEqual(num_params(n), 3)
self.assertEqual(num_params(s), 3)

def test_modules(self):
class Net(nn.Container):
def __init__(self):
super(Net, self).__init__()
self.l1 = l
self.l2 = l
self.param = Variable(torch.Tensor(3, 5))
l = nn.Linear(10, 20)
n = Net()
s = nn.Sequential(n, n, n, n)
self.assertEqual(list(s.modules()), [s, n, l])

def test_Sequential_getitem(self):
l1 = nn.Linear(10, 20)
Expand Down
6 changes: 4 additions & 2 deletions torch/nn/modules/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def load_parameter_dict(self, param_dict):
def parameters(self, memo=None):
if memo is None:
memo = set()
super(Container, self).parameters(memo)
for p in super(Container, self).parameters(memo):
yield p
for module in self.children():
for p in module.parameters(memo):
yield p
Expand All @@ -125,7 +126,8 @@ def modules(self, memo=None):
if memo is None:
memo = set()
if self not in memo:
super(Container, self).modules(memo)
for m in super(Container, self).modules(memo):
yield m
for module in self.children():
for m in module.modules(memo):
yield m
Expand Down

0 comments on commit 98f67e9

Please sign in to comment.