-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from synsense/layer_conversion
Adding helper scripts that enable automatic conversion between Sinabs and EXODUS backends
- Loading branch information
Showing
9 changed files
with
129 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,5 @@ | |
|
||
from . import _version | ||
__version__ = _version.get_versions()['version'] | ||
|
||
from . import conversion |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import sinabs.layers as sl | ||
import sinabs.exodus.layers as el | ||
import torch | ||
import sinabs | ||
|
||
module_map = { | ||
sl.IAF: el.IAF, | ||
sl.IAFSqueeze: el.IAFSqueeze, | ||
sl.LIF: el.LIF, | ||
sl.LIFSqueeze: el.LIFSqueeze, | ||
sl.ExpLeak: el.ExpLeak, | ||
sl.ExpLeakSqueeze: el.ExpLeakSqueeze, | ||
} | ||
|
||
|
||
def exodus_to_sinabs(model: torch.nn.Module): | ||
""" | ||
Replace all EXODUS layers with the Sinabs equivalent if available. | ||
This can be useful if for example you want to convert your model to a | ||
DynapcnnNetwork or you want to work on a machine without GPU. | ||
All layer attributes will be copied over. | ||
Parameters: | ||
model: The model that contains EXODUS layers. | ||
""" | ||
|
||
mapping_list = [ | ||
( | ||
exodus_class, | ||
lambda module, replacement=sinabs_class: replacement(**module.arg_dict), | ||
) | ||
for sinabs_class, exodus_class in module_map.items() | ||
] | ||
for class_to_replace, mapper_fn in mapping_list: | ||
model = sinabs.conversion.replace_module( | ||
model, class_to_replace, mapper_fn=mapper_fn | ||
) | ||
return model | ||
|
||
|
||
def sinabs_to_exodus(model: torch.nn.Module): | ||
""" | ||
Replace all Sinabs layers with EXODUS equivalents if available. | ||
This will typically speed up training by a factor of 2-5. | ||
All layer attributes will be copied over. | ||
Parameters: | ||
model: The model that contains Sinabs layers. | ||
""" | ||
|
||
mapping_list = [ | ||
( | ||
sinabs_class, | ||
lambda module, replacement=exodus_class: replacement(**module.arg_dict), | ||
) | ||
for sinabs_class, exodus_class in module_map.items() | ||
] | ||
for class_to_replace, mapper_fn in mapping_list: | ||
model = sinabs.conversion.replace_module( | ||
model, class_to_replace, mapper_fn=mapper_fn | ||
) | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import sinabs.exodus.layers as el | ||
import sinabs.layers as sl | ||
import torch.nn as nn | ||
from sinabs.exodus import conversion | ||
|
||
|
||
def test_sinabs_to_exodus_layer_replacement(): | ||
batch_size = 12 | ||
sinabs_model = nn.Sequential( | ||
nn.Conv2d(2, 8, 5, 1), | ||
sl.IAFSqueeze(batch_size=batch_size, min_v_mem=-1), | ||
sl.SumPool2d(2, 2), | ||
nn.Conv2d(8, 16, 3, 1), | ||
sl.IAFSqueeze(batch_size=batch_size, min_v_mem=-2), | ||
sl.SumPool2d(2, 2), | ||
nn.Flatten(), | ||
nn.Linear(64, 10), | ||
) | ||
|
||
exodus_model = conversion.sinabs_to_exodus(sinabs_model) | ||
|
||
assert type(sinabs_model[1]) == sl.IAFSqueeze | ||
assert type(exodus_model[1]) == el.IAFSqueeze | ||
assert len(sinabs_model) == len(exodus_model) | ||
assert exodus_model[1].min_v_mem == sinabs_model[1].min_v_mem | ||
assert exodus_model[4].min_v_mem == sinabs_model[4].min_v_mem | ||
assert exodus_model[1].batch_size == sinabs_model[1].batch_size | ||
|
||
|
||
def test_exodus_to_sinabs_layer_replacement(): | ||
batch_size = 12 | ||
exodus_model = nn.Sequential( | ||
nn.Conv2d(2, 8, 5, 1), | ||
el.IAFSqueeze(batch_size=batch_size, min_v_mem=-1), | ||
sl.SumPool2d(2, 2), | ||
nn.Conv2d(8, 16, 3, 1), | ||
el.IAFSqueeze(batch_size=batch_size, min_v_mem=-2), | ||
sl.SumPool2d(2, 2), | ||
nn.Flatten(), | ||
nn.Linear(64, 10), | ||
) | ||
|
||
sinabs_model = conversion.exodus_to_sinabs(exodus_model) | ||
|
||
assert type(sinabs_model[1]) == sl.IAFSqueeze | ||
assert type(exodus_model[1]) == el.IAFSqueeze | ||
assert len(exodus_model) == len(sinabs_model) | ||
assert exodus_model[1].min_v_mem == sinabs_model[1].min_v_mem | ||
assert exodus_model[4].min_v_mem == sinabs_model[4].min_v_mem | ||
assert exodus_model[1].batch_size == sinabs_model[1].batch_size |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters