Skip to content

Commit

Permalink
Implement ModuleList (facebookresearch#244)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: fairinternal/CrypTen#244

Pull Request resolved: facebookresearch#321

Added ModuleList implementation to crypten/nn/module.py to address crypten feature request from Github:

facebookresearch#318

Reviewed By: lvdmaaten

Differential Revision: D31546144

fbshipit-source-id: ce3fd7ca043cb3b805f73812abf0d5e71e52a27d
  • Loading branch information
knottb authored and facebook-github-bot committed Oct 13, 2021
1 parent 09f6368 commit 3f190af
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 0 deletions.
2 changes: 2 additions & 0 deletions crypten/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
Mean,
Module,
ModuleDict,
ModuleList,
Mul,
Parameter,
Pow,
Expand Down Expand Up @@ -123,6 +124,7 @@
"Mean",
"Module",
"ModuleDict",
"ModuleList",
"MSELoss",
"Mul",
"Parameter",
Expand Down
63 changes: 63 additions & 0 deletions crypten/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,69 @@ def __init__(self, *module_list, input_names=None):
self.output_names = [module_name]


class ModuleList(Module):
r"""Holds submodules in a list.
:class:`~crypten.nn.ModuleList` can be indexed like a regular Python list, but
modules it contains are properly registered, and will be visible by all
:class:`~crypten.nn.Module` methods.
Args:
modules (iterable, optional): an iterable of modules to add
Example::
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
for i, l in enumerate(self.linears):
x = self.linears[i // 2](x) + l(x)
return x
"""

def __init__(self, modules=None):
super(ModuleList, self).__init__()
if modules is not None:
self += modules

def __dir__(self):
keys = super(ModuleList, self).__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys

def __delitem__(self, idx):
if isinstance(idx, slice):
for k in range(len(self._modules))[idx]:
del self._modules[str(k)]
else:
del self._modules[self._get_abs_string_index(idx)]
# To preserve numbering, self._modules is being reconstructed with modules after deletion
str_indices = [str(i) for i in range(len(self._modules))]
self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))


__module_list_func_names = [
"_get_abs_string_index",
"__getitem__",
"__setitem__",
"__len__",
"__iter__",
"__iadd__",
"insert",
"append",
"extend",
]


for func_name in __module_list_func_names:
func = getattr(torch.nn.ModuleList, func_name)
setattr(ModuleList, func_name, func)


class ModuleDict(Module):
r"""Holds submodules in a dictionary.
Expand Down
53 changes: 53 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1625,6 +1625,59 @@ def test_module_dict(self):
module_dict.clear()
self.assertEqual(len(module_dict), 0, "ModuleDict clear failed")

def test_module_list(self):
"""Test ModuleDict module"""
module_list = crypten.nn.ModuleList()
self.assertEqual(len(module_list), 0, "ModuleList initialized incorrect size")

# Test initialization
module_list = crypten.nn.ModuleList(
[crypten.nn.Conv2d(10, 10, 3), crypten.nn.MaxPool2d(3)]
)
self.assertEqual(len(module_list), 2, "ModuleList initialized incorrect size")
self.assertTrue(
isinstance(module_list[0], crypten.nn.Conv2d),
"ModuleList init failed",
)
self.assertTrue(
isinstance(module_list[1], crypten.nn.MaxPool2d),
"ModuleList init failed",
)

# Test append
module_list.append(crypten.nn.ReLU())
self.assertEqual(len(module_list), 3, "ModuleList append failed")
self.assertTrue(
isinstance(module_list[2], crypten.nn.ReLU),
"ModuleList append failed",
)

# Test extend
module_list.extend([crypten.nn.Linear(10, 5), crypten.nn.ReLU()])
msg = "ModuleList append failed"
self.assertEqual(len(module_list), 5, msg)
self.assertTrue(isinstance(module_list[3], crypten.nn.Linear), msg)
self.assertTrue(isinstance(module_list[4], crypten.nn.ReLU), msg)

# Test insert
module_list.insert(1, crypten.nn.Sigmoid())
msg = "ModuleList append failed"
self.assertEqual(len(module_list), 6, msg)
self.assertTrue(isinstance(module_list[1], crypten.nn.Sigmoid), msg)

# Test __delitem__
del module_list[1]
msg = "ModuleList delitem failed"
self.assertEqual(len(module_list), 5, msg)
self.assertTrue(isinstance(module_list[1], crypten.nn.MaxPool2d), msg)

# Test __delitem__ with slice
del module_list[1:3]
msg = "ModuleList delitem failed with slice input"
self.assertEqual(len(module_list), 3, msg)
self.assertTrue(isinstance(module_list[0], crypten.nn.Conv2d), msg)
self.assertTrue(isinstance(module_list[1], crypten.nn.Linear), msg)

def test_parameter_initializations(self):
"""Test crypten.nn.init initializations"""
sizes = [
Expand Down

0 comments on commit 3f190af

Please sign in to comment.