Skip to content

Commit

Permalink
[nn] add insert method to sequential class (pytorch#81402)
Browse files Browse the repository at this point in the history
Follows pytorch#71329

cc @kshitij12345
Pull Request resolved: pytorch#81402
Approved by: https://github.com/albanD
  • Loading branch information
khushi-411 authored and pytorchmergebot committed Jul 20, 2022
1 parent 596bb41 commit dced803
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
31 changes: 31 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1600,6 +1600,37 @@ def test_Sequential_append(self):
self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4))
self.assertEqual(nn.Sequential(l1).append(l2).append(l4), nn.Sequential(l1, l2, l4))

def test_Sequential_insert(self):
l1 = nn.Linear(1, 2)
l2 = nn.Linear(2, 3)
l3 = nn.Linear(3, 4)

n1 = nn.Sequential(l1, l2, l3)
module_1 = nn.Linear(4, 5)
n2 = nn.Sequential(l1, module_1, l2, l3)
self.assertEqual(n1.insert(1, module_1), n2)

# test for negative support
n3 = nn.Sequential(l1, l2, l3)
module_2 = nn.Linear(5, 6)
n4 = nn.Sequential(l1, module_2, l2, l3)
self.assertEqual(n3.insert(-2, module_2), n4)

def test_Sequential_insert_fail_case(self):
l1 = nn.Linear(1, 2)
l2 = nn.Linear(2, 3)
l3 = nn.Linear(3, 4)

module = nn.Linear(5, 6)

# test for error case
n1 = nn.Sequential(l1, l2, l3)
with self.assertRaises(IndexError):
n1.insert(-5, module)

with self.assertRaises(AssertionError):
n1.insert(1, [nn.Linear(6, 7)])

def test_Sequential_extend(self):
l1 = nn.Linear(10, 20)
l2 = nn.Linear(20, 30)
Expand Down
15 changes: 15 additions & 0 deletions torch/nn/modules/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,21 @@ def append(self, module: Module) -> 'Sequential':
self.add_module(str(len(self)), module)
return self

def insert(self, index: int, module: Module) -> 'Sequential':
if not isinstance(module, Module):
raise AssertionError(
'module should be of type: {}'.format(Module))
n = len(self._modules)
if not (-n <= index <= n):
raise IndexError(
'Index out of range: {}'.format(index))
if index < 0:
index += n
for i in range(n, index, -1):
self._modules[str(i)] = self._modules[str(i - 1)]
self._modules[str(index)] = module
return self

def extend(self, sequential) -> 'Sequential':
for layer in sequential:
self.append(layer)
Expand Down

0 comments on commit dced803

Please sign in to comment.