
Fix clone_module and MAML for RNN modules. (#140)
* Fix clone_module and MAML for RNN modules.

* Version bump.
seba-1511 authored Apr 24, 2020
1 parent a5c1ef2 commit 6ff649d
Showing 5 changed files with 75 additions and 2 deletions.
CHANGELOG.md: 11 additions & 0 deletions
@@ -10,6 +10,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+### Changed
+
+### Fixed
+
+
+## v0.1.1
+
+### Added
+
 * New tutorial: 'Feature Reuse with ANIL'. (@ewinapun)
 
 ### Changed
@@ -18,6 +27,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
+* `MAML()` and `clone_module` support for RNN modules.
+
 
 ## v0.1.0.1
 
learn2learn/_version.py: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '0.1.0.1'
+__version__ = '0.1.1'
learn2learn/algorithms/maml.py: 5 additions & 0 deletions
Expand Up @@ -61,6 +61,11 @@ def maml_update(model, lr, grads=None):
model._modules[module_key] = maml_update(model._modules[module_key],
lr=lr,
grads=None)

# Finally, rebuild the flattened parameters for RNNs
# See this issue for more details:
# https://github.com/learnables/learn2learn/issues/139
model._apply(lambda x: x)
return model


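Why the no-op `_apply` call fixes RNNs: PyTorch's `RNNBase` caches its parameters in a `_flat_weights` list that the fused RNN kernels read, and its `_apply` override rebuilds that cache. When `maml_update` swaps entries of `model._parameters` for updated tensors, the cache goes stale until `_apply` runs. Below is a minimal sketch of the problem, assuming the `_flat_weights` internals of the PyTorch releases contemporary with this commit; the subtraction stands in for the update `maml_update` actually performs.

```python
import torch

rnn = torch.nn.RNN(input_size=2, hidden_size=1)

# Swap a parameter for an updated tensor, as maml_update does internally.
old_weight = rnn._parameters['weight_ih_l0']
rnn._parameters['weight_ih_l0'] = old_weight - 1e-3 * torch.ones_like(old_weight)

# The fused RNN kernel would still read the stale tensor from the cache:
assert rnn._flat_weights[0] is old_weight

# A tensor-level no-op, but RNNBase._apply rebuilds _flat_weights as a side effect:
rnn._apply(lambda x: x)
assert rnn._flat_weights[0] is rnn._parameters['weight_ih_l0']
```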
learn2learn/utils.py: 6 additions & 1 deletion
@@ -85,7 +85,7 @@ def clone_module(module):
     # TODO: This function might require that module.forward()
     # was called in order to work properly, if forward() instantiates
     # new variables.
-    # TODO: deepcopy is expensive. We can probably get away with a shallowcopy.
+    # TODO: We can probably get away with a shallowcopy.
     # However, since shallow copy does not recurse, we need to write a
     # recursive version of shallow copy.
     # NOTE: This can probably be implemented more cleanly with
@@ -119,6 +119,11 @@ def clone_module(module):
     if hasattr(clone, '_modules'):
         for module_key in clone._modules:
             clone._modules[module_key] = clone_module(module._modules[module_key])
+
+    # Finally, rebuild the flattened parameters for RNNs
+    # See this issue for more details:
+    # https://github.com/learnables/learn2learn/issues/139
+    clone = clone._apply(lambda x: x)
     return clone


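With both call sites patched, cloning a recurrent module needs no per-forward workaround. A minimal usage sketch, assuming `clone_module` is re-exported at the package top level (as in this repo's `__init__.py`):

```python
import torch
import learn2learn as l2l

lstm = torch.nn.LSTM(input_size=2, hidden_size=1)
clone = l2l.clone_module(lstm)

# Cloned parameters keep their autograd history, so a loss computed on the
# clone backpropagates into the original module's parameters.
x = torch.randn(5, 4, 2)  # (seq_len, batch, input_size)
output, _ = clone(x)
output.norm(p=2).backward()
assert lstm.weight_ih_l0.grad is not None
```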
tests/unit/utils_test.py: 52 additions & 0 deletions
@@ -7,6 +7,13 @@
 
 
 def ref_clone_module(module):
+    """
+    Note: This implementation does not work for RNNs.
+    It requires calling learner.rnn._apply(lambda x: x) before
+    each forward call.
+    See this issue for more details:
+    https://github.com/learnables/learn2learn/issues/139
+    """
     # First, create a copy of the module.
     clone = copy.deepcopy(module)
 
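For contrast, the workaround this docstring describes would look roughly as follows: a hypothetical sketch using the deepcopy-based `ref_clone_module` defined above, with the refresh repeated before every forward pass.

```python
import torch

learner = ref_clone_module(torch.nn.GRU(input_size=2, hidden_size=1))
# ...parameters of `learner` get replaced during adaptation...
learner._apply(lambda x: x)  # refresh the stale flattened weights
output, hidden = learner(torch.randn(10, 4, 2))
```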
@@ -120,6 +127,51 @@ def test_clone_module_models(self):
         for ref_p, l2l_p in zip(ref_model.parameters(), l2l_model.parameters()):
             self.assertTrue(torch.equal(ref_p, l2l_p))
 
+    def test_rnn_clone(self):
+        # Tests: https://github.com/learnables/learn2learn/issues/139
+        # The test is mainly about whether we can clone and adapt RNNs.
+        # See issue for details.
+        N_STEPS = 3
+        for rnn_class in [
+            torch.nn.RNN,
+            torch.nn.LSTM,
+            torch.nn.GRU,
+        ]:
+            torch.manual_seed(1234)
+            model = rnn_class(2, 1)
+            maml = l2l.algorithms.MAML(model, lr=1e-3, allow_unused=False)
+            optim = torch.optim.SGD(maml.parameters(), lr=0.001)
+            data = torch.randn(30, 500, 2)
+
+            # Adapt and measure loss
+            learner = maml.clone()
+            for step in range(N_STEPS):
+                pred, hidden = learner(data)
+                loss = pred.norm(p=2)
+                learner.adapt(loss)
+            pred, _ = learner(data)
+            first_loss = pred.norm(p=2)
+
+            # Take an optimization step
+            optim.zero_grad()
+            first_loss.backward()
+            optim.step()
+            first_loss = first_loss.item()
+
+            # Adapt a second time
+            learner = maml.clone()
+            for step in range(N_STEPS):
+                pred, hidden = learner(data)
+                loss = pred.norm(p=2)
+                learner.adapt(loss)
+            pred, _ = learner(data)
+            second_loss = pred.norm(p=2)
+            second_loss = second_loss.item()
+
+            # Ensure we did better
+            self.assertTrue(first_loss > second_loss)
+
+
     def test_module_detach(self):
         original_output = self.model(self.input)
         original_loss = self.loss_func(original_output, torch.tensor([[0., 0.]]))
