diff --git a/crypten/nn/__init__.py b/crypten/nn/__init__.py index 6cf51873..51d053dd 100644 --- a/crypten/nn/__init__.py +++ b/crypten/nn/__init__.py @@ -50,6 +50,7 @@ Mean, Module, ModuleDict, + ModuleList, Mul, Parameter, Pow, @@ -123,6 +124,7 @@ "Mean", "Module", "ModuleDict", + "ModuleList", "MSELoss", "Mul", "Parameter", diff --git a/crypten/nn/module.py b/crypten/nn/module.py index 4c1beb07..e8e78b75 100644 --- a/crypten/nn/module.py +++ b/crypten/nn/module.py @@ -778,6 +778,69 @@ def __init__(self, *module_list, input_names=None): self.output_names = [module_name] +class ModuleList(Module): + r"""Holds submodules in a list. + + :class:`~crypten.nn.ModuleList` can be indexed like a regular Python list, but + modules it contains are properly registered, and will be visible by all + :class:`~crypten.nn.Module` methods. + + Args: + modules (iterable, optional): an iterable of modules to add + + Example:: + + class MyModule(nn.Module): + def __init__(self): + super(MyModule, self).__init__() + self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)]) + + def forward(self, x): + # ModuleList can act as an iterable, or be indexed using ints + for i, l in enumerate(self.linears): + x = self.linears[i // 2](x) + l(x) + return x + """ + + def __init__(self, modules=None): + super(ModuleList, self).__init__() + if modules is not None: + self += modules + + def __dir__(self): + keys = super(ModuleList, self).__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def __delitem__(self, idx): + if isinstance(idx, slice): + for k in range(len(self._modules))[idx]: + del self._modules[str(k)] + else: + del self._modules[self._get_abs_string_index(idx)] + # To preserve numbering, self._modules is being reconstructed with modules after deletion + str_indices = [str(i) for i in range(len(self._modules))] + self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) + + +__module_list_func_names = [ + "_get_abs_string_index", + "__getitem__", + "__setitem__", + "__len__", + "__iter__", + "__iadd__", + "insert", + "append", + "extend", +] + + +for func_name in __module_list_func_names: + func = getattr(torch.nn.ModuleList, func_name) + setattr(ModuleList, func_name, func) + + class ModuleDict(Module): r"""Holds submodules in a dictionary. diff --git a/test/test_nn.py b/test/test_nn.py index e656fdba..e77fb2c5 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1625,6 +1625,59 @@ def test_module_dict(self): module_dict.clear() self.assertEqual(len(module_dict), 0, "ModuleDict clear failed") + def test_module_list(self): + """Test ModuleDict module""" + module_list = crypten.nn.ModuleList() + self.assertEqual(len(module_list), 0, "ModuleList initialized incorrect size") + + # Test initialization + module_list = crypten.nn.ModuleList( + [crypten.nn.Conv2d(10, 10, 3), crypten.nn.MaxPool2d(3)] + ) + self.assertEqual(len(module_list), 2, "ModuleList initialized incorrect size") + self.assertTrue( + isinstance(module_list[0], crypten.nn.Conv2d), + "ModuleList init failed", + ) + self.assertTrue( + isinstance(module_list[1], crypten.nn.MaxPool2d), + "ModuleList init failed", + ) + + # Test append + module_list.append(crypten.nn.ReLU()) + self.assertEqual(len(module_list), 3, "ModuleList append failed") + self.assertTrue( + isinstance(module_list[2], crypten.nn.ReLU), + "ModuleList append failed", + ) + + # Test extend + module_list.extend([crypten.nn.Linear(10, 5), crypten.nn.ReLU()]) + msg = "ModuleList append failed" + self.assertEqual(len(module_list), 5, msg) + self.assertTrue(isinstance(module_list[3], crypten.nn.Linear), msg) + self.assertTrue(isinstance(module_list[4], crypten.nn.ReLU), msg) + + # Test insert + module_list.insert(1, crypten.nn.Sigmoid()) + msg = "ModuleList append failed" + self.assertEqual(len(module_list), 6, msg) + self.assertTrue(isinstance(module_list[1], crypten.nn.Sigmoid), msg) + + # Test __delitem__ + del module_list[1] + msg = "ModuleList delitem failed" + self.assertEqual(len(module_list), 5, msg) + self.assertTrue(isinstance(module_list[1], crypten.nn.MaxPool2d), msg) + + # Test __delitem__ with slice + del module_list[1:3] + msg = "ModuleList delitem failed with slice input" + self.assertEqual(len(module_list), 3, msg) + self.assertTrue(isinstance(module_list[0], crypten.nn.Conv2d), msg) + self.assertTrue(isinstance(module_list[1], crypten.nn.Linear), msg) + def test_parameter_initializations(self): """Test crypten.nn.init initializations""" sizes = [