Merge pull request #2 from synsense/layer_conversion
Add helper functions that enable automatic conversion between the Sinabs and EXODUS backends
bauerfe authored Nov 1, 2022
2 parents cc42ba4 + b9a7614 commit cc7a56b
Showing 9 changed files with 129 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -63,6 +63,6 @@ def run(self):
         )
     ],
     cmdclass=cmdclass,
-    install_requires=["torch", f"sinabs == {version_major}.*"],
+    install_requires=["torch", f"sinabs == {version_major}.*, >= 1.1.1"],
 )
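The tightened requirement keeps sinabs pinned to the same major version as before while additionally demanding at least 1.1.1, presumably the first release that ships the sinabs.conversion.replace_module helper the new conversion module relies on.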

2 changes: 2 additions & 0 deletions sinabs/exodus/__init__.py
@@ -5,3 +5,5 @@

 from . import _version
 __version__ = _version.get_versions()['version']
+
+from . import conversion
62 changes: 62 additions & 0 deletions sinabs/exodus/conversion.py
@@ -0,0 +1,62 @@
import sinabs.layers as sl
import sinabs.exodus.layers as el
import torch
import sinabs

module_map = {
    sl.IAF: el.IAF,
    sl.IAFSqueeze: el.IAFSqueeze,
    sl.LIF: el.LIF,
    sl.LIFSqueeze: el.LIFSqueeze,
    sl.ExpLeak: el.ExpLeak,
    sl.ExpLeakSqueeze: el.ExpLeakSqueeze,
}


def exodus_to_sinabs(model: torch.nn.Module):
    """
    Replace all EXODUS layers with their Sinabs equivalents, where available.
    This can be useful if, for example, you want to convert your model to a
    DynapcnnNetwork or you want to work on a machine without a GPU.
    All layer attributes will be copied over.

    Parameters:
        model: The model that contains EXODUS layers.
    """

    # Bind the replacement class through a default argument so that each
    # lambda captures its own class rather than the comprehension's final value.
    mapping_list = [
        (
            exodus_class,
            lambda module, replacement=sinabs_class: replacement(**module.arg_dict),
        )
        for sinabs_class, exodus_class in module_map.items()
    ]
    for class_to_replace, mapper_fn in mapping_list:
        model = sinabs.conversion.replace_module(
            model, class_to_replace, mapper_fn=mapper_fn
        )
    return model


def sinabs_to_exodus(model: torch.nn.Module):
    """
    Replace all Sinabs layers with their EXODUS equivalents, where available.
    This will typically speed up training by a factor of 2-5.
    All layer attributes will be copied over.

    Parameters:
        model: The model that contains Sinabs layers.
    """

    # As above, default arguments pin each class at lambda-creation time.
    mapping_list = [
        (
            sinabs_class,
            lambda module, replacement=exodus_class: replacement(**module.arg_dict),
        )
        for sinabs_class, exodus_class in module_map.items()
    ]
    for class_to_replace, mapper_fn in mapping_list:
        model = sinabs.conversion.replace_module(
            model, class_to_replace, mapper_fn=mapper_fn
        )
    return model
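For reference, a minimal usage sketch of the two helpers (the layer arguments and sizes are illustrative, mirroring the tests below; any layer type listed in module_map converts the same way):

import torch.nn as nn
import sinabs.layers as sl
from sinabs.exodus import conversion

# An ordinary Sinabs model ...
model = nn.Sequential(
    nn.Conv2d(2, 8, 5, 1),
    sl.IAFSqueeze(batch_size=12, min_v_mem=-1.0),
    nn.Flatten(),
    nn.Linear(8, 10),  # illustrative sizes
)

# ... moved to the EXODUS backend for faster GPU training ...
exodus_model = conversion.sinabs_to_exodus(model)

# ... and converted back, e.g. before deployment on a machine without a GPU.
sinabs_model = conversion.exodus_to_sinabs(exodus_model)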
6 changes: 3 additions & 3 deletions sinabs/exodus/layers/iaf.py
@@ -77,12 +77,12 @@ def __init__(
             record_states=record_states,
             decay_early=decay_early,
         )
-        # deactivate tau_mem being learned
-        self.tau_mem.requires_grad = False
+        # IAF does not have time constants
+        self.tau_mem = None

     @property
     def alpha_mem_calculated(self):
-        return torch.tensor(1.).to(self.tau_mem.device)
+        return torch.tensor(1.).to(self.v_mem.device)

     @property
     def _param_dict(self) -> dict:
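Note on the hunk above: IAF neurons integrate without a leak, so tau_mem is now set to None outright instead of being kept as a frozen parameter, and alpha_mem_calculated takes its device from v_mem, which always exists, rather than from the now-absent tau_mem (which would raise an AttributeError).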
2 changes: 2 additions & 0 deletions sinabs/exodus/layers/lif.py
@@ -266,6 +266,8 @@ def forward(self, input_data: torch.Tensor):

         return output_full

+    def __repr__(self):
+        return "EXODUS " + super().__repr__()

 class LIFSqueeze(LIF, SqueezeMixin):
     """
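The new __repr__ makes the backend visible whenever a layer or model is printed; a hypothetical session (the exact fields shown depend on the parent class's repr):

import sinabs.exodus.layers as el

layer = el.LIF(tau_mem=20.0)
print(layer)  # e.g. "EXODUS LIF(tau_mem=20.0, ...)", prefixed by the new __repr__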
50 changes: 50 additions & 0 deletions tests/test_conversion.py
@@ -0,0 +1,50 @@
import sinabs.exodus.layers as el
import sinabs.layers as sl
import torch.nn as nn
from sinabs.exodus import conversion


def test_sinabs_to_exodus_layer_replacement():
    batch_size = 12
    sinabs_model = nn.Sequential(
        nn.Conv2d(2, 8, 5, 1),
        sl.IAFSqueeze(batch_size=batch_size, min_v_mem=-1),
        sl.SumPool2d(2, 2),
        nn.Conv2d(8, 16, 3, 1),
        sl.IAFSqueeze(batch_size=batch_size, min_v_mem=-2),
        sl.SumPool2d(2, 2),
        nn.Flatten(),
        nn.Linear(64, 10),
    )

    exodus_model = conversion.sinabs_to_exodus(sinabs_model)

    assert type(sinabs_model[1]) == sl.IAFSqueeze
    assert type(exodus_model[1]) == el.IAFSqueeze
    assert len(sinabs_model) == len(exodus_model)
    assert exodus_model[1].min_v_mem == sinabs_model[1].min_v_mem
    assert exodus_model[4].min_v_mem == sinabs_model[4].min_v_mem
    assert exodus_model[1].batch_size == sinabs_model[1].batch_size


def test_exodus_to_sinabs_layer_replacement():
    batch_size = 12
    exodus_model = nn.Sequential(
        nn.Conv2d(2, 8, 5, 1),
        el.IAFSqueeze(batch_size=batch_size, min_v_mem=-1),
        sl.SumPool2d(2, 2),
        nn.Conv2d(8, 16, 3, 1),
        el.IAFSqueeze(batch_size=batch_size, min_v_mem=-2),
        sl.SumPool2d(2, 2),
        nn.Flatten(),
        nn.Linear(64, 10),
    )

    sinabs_model = conversion.exodus_to_sinabs(exodus_model)

    assert type(sinabs_model[1]) == sl.IAFSqueeze
    assert type(exodus_model[1]) == el.IAFSqueeze
    assert len(exodus_model) == len(sinabs_model)
    assert exodus_model[1].min_v_mem == sinabs_model[1].min_v_mem
    assert exodus_model[4].min_v_mem == sinabs_model[4].min_v_mem
    assert exodus_model[1].batch_size == sinabs_model[1].batch_size
1 change: 1 addition & 0 deletions tests/test_exp_leak.py
@@ -17,6 +17,7 @@ def test_leaky_basic():
     assert input_current.shape == membrane_output.shape
     assert torch.isnan(membrane_output).sum() == 0
     assert membrane_output.sum() > 0
+    assert "EXODUS" in layer.__repr__()


 def test_leaky_basic_early_decay():
8 changes: 7 additions & 1 deletion tests/test_iaf.py
@@ -21,6 +21,7 @@ def test_iaf_basic():
     assert input_current.shape == spike_output.shape
     assert torch.isnan(spike_output).sum() == 0
     assert spike_output.sum() > 0
+    assert "EXODUS" in layer.__repr__()


 def test_iaf_squeeze():
@@ -119,6 +120,7 @@ def test_sinabs_model():
backend="sinabs",
n_input_channels=n_input_channels,
n_output_classes=n_output_classes,
batch_size=batch_size,
).cuda()
input_data = torch.rand((batch_size, time_steps, n_input_channels)).cuda() * 1e5
spike_output = model(input_data)
Expand All @@ -135,6 +137,7 @@ def test_exodus_model():
backend="exodus",
n_input_channels=n_input_channels,
n_output_classes=n_output_classes,
batch_size=batch_size,
).cuda()
input_data = torch.rand((batch_size, time_steps, n_input_channels)).cuda() * 1e5
spike_output = model(input_data)
Expand All @@ -151,11 +154,13 @@ def test_exodus_sinabs_model_equal_output():
backend="sinabs",
n_input_channels=n_input_channels,
n_output_classes=n_output_classes,
batch_size=batch_size,
).cuda()
exodus_model = SNN(
backend="exodus",
n_input_channels=n_input_channels,
n_output_classes=n_output_classes,
batch_size=batch_size,
).cuda()
# make sure the weights for linear layers are the same
for (sinabs_layer, exodus_layer) in zip(
Expand Down Expand Up @@ -258,6 +263,7 @@ class SNN(nn.Module):
     def __init__(
         self,
         backend,
+        batch_size,
         n_input_channels=16,
         n_output_classes=10,
         threshold=1.0,
@@ -268,7 +274,7 @@ def __init__(
             n_input_channels=n_input_channels, n_output_classes=n_output_classes
         )
         self.network = from_model(
-            ann, backend=backend, spike_threshold=threshold, min_v_mem=min_v_mem
+            ann, backend=backend, spike_threshold=threshold, min_v_mem=min_v_mem, batch_size=batch_size
         ).spiking_model

     def reset_states(self):
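These test changes thread batch_size through the SNN helper into from_model; the squeeze-style layers produced there presumably need an explicit batch size to separate the flattened (batch x time) dimension, so the tests now pass it through rather than relying on a default.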
1 change: 1 addition & 0 deletions tests/test_lif.py
@@ -24,6 +24,7 @@ def test_lif_basic():
     assert input_current.shape == spike_output.shape
     assert torch.isnan(spike_output).sum() == 0
     assert spike_output.sum() > 0
+    assert "EXODUS" in layer.__repr__()


 def test_lif_squeeze():
