Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .github/workflows/black.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
name: Black Formatter

on: [push, pull_request]

jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- uses: psf/black@stable
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ repos:
rev: master
hooks:
- id: black
language_version: python3.8
language_version: python3
2 changes: 1 addition & 1 deletion bindsnet/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def plot_assignments(

if classes is None:
cbar = plt.colorbar(im, cax=cax, ticks=list(range(-1, 11)))
cbar.ax.set_yticklabels(["none"] + list(range(10)))
cbar.ax.set_yticklabels(["none"] + list(range(11)))
else:
cbar = plt.colorbar(im, cax=cax, ticks=np.arange(-1, len(classes)))
cbar.ax.set_yticklabels(["none"] + list(classes))
Expand Down
2 changes: 1 addition & 1 deletion bindsnet/analysis/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def plot_weights_movie(ws: np.ndarray, sample_every: int = 1) -> None:
# language=rst
"""
Create and plot movie of weights.

:param ws: Array of shape ``[n_examples, source, target, time]``.
:param sample_every: Sub-sample using this parameter.
"""
Expand Down
4 changes: 2 additions & 2 deletions bindsnet/datasets/alov300.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class ALOV300(Dataset):
def __init__(self, root, transform, input_size, download=False):
"""
Class to read the ALOV dataset

:param root: Path to the ALOV folder that contains JPEGImages, Annotations, etc. folders.
:param input_size: The input size of network that is using this data, for rescaling
:param download: Specify whether to download the dataset if it is not present
Expand Down Expand Up @@ -247,7 +247,7 @@ def _download(self):
"""
Downloads the correct dataset based on the given parameters

Relies on self.tag to determine both the name of the folder created for the dataset and for finding the correct download URL.
Relies on self.tag to determine both the name of the folder created for the dataset and for finding the correct download URL.
"""

os.makedirs(self.root)
Expand Down
2 changes: 1 addition & 1 deletion bindsnet/encoding/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class NullEncoder(Encoder):
# language=rst
"""
Pass through of the datum that was input.

.. note::
This is not a real spike encoder. Be careful with the usage of this class.
"""
Expand Down
8 changes: 4 additions & 4 deletions bindsnet/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def assign_labels(
n_neurons = spikes.size(2)

if rates is None:
rates = torch.zeros(n_neurons, n_labels)
rates = torch.zeros((n_neurons, n_labels), device=spikes.device)

# Sum over time dimension (spike ordering doesn't matter).
spikes = spikes.sum(1)
Expand Down Expand Up @@ -112,7 +112,7 @@ def all_activity(
# Sum over time dimension (spike ordering doesn't matter).
spikes = spikes.sum(1)

rates = torch.zeros(n_samples, n_labels)
rates = torch.zeros((n_samples, n_labels), device=spikes.device)
for i in range(n_labels):
# Count the number of neurons with this label assignment.
n_assigns = torch.sum(assignments == i).float()
Expand Down Expand Up @@ -153,7 +153,7 @@ def proportion_weighting(
# Sum over time dimension (spike ordering doesn't matter).
spikes = spikes.sum(1)

rates = torch.zeros(n_samples, n_labels)
rates = torch.zeros((n_samples, n_labels), device=spikes.device)
for i in range(n_labels):
# Count the number of neurons with this label assignment.
n_assigns = torch.sum(assignments == i).float()
Expand Down Expand Up @@ -191,7 +191,7 @@ def ngram(
"""
predictions = []
for activity in spikes:
score = torch.zeros(n_labels)
score = torch.zeros(n_labels, device=spikes.device)

# Aggregate all of the firing neurons' indices
fire_order = []
Expand Down
4 changes: 3 additions & 1 deletion bindsnet/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,9 @@ def __init__(
w = w / w.max()
w = (w * self.max_inhib) + self.start_inhib
recurrent_output_conn = Connection(
source=self.layers["Y"], target=self.layers["Y"], w=w,
source=self.layers["Y"],
target=self.layers["Y"],
w=w,
)
self.add_connection(recurrent_output_conn, source="Y", target="Y")

Expand Down
6 changes: 3 additions & 3 deletions bindsnet/network/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def clone(self) -> "Network":
# language=rst
"""
Returns a cloned network object.

:return: A copy of this network.
"""
virtual_file = tempfile.SpooledTemporaryFile()
Expand Down Expand Up @@ -368,9 +368,9 @@ def run(
unclamp = unclamps.get(l, None)
if unclamp is not None:
if unclamp.ndimension() == 1:
self.layers[l].s[unclamp] = 0
self.layers[l].s[:, unclamp] = 0
else:
self.layers[l].s[unclamp[t]] = 0
self.layers[l].s[:, unclamp[t]] = 0

# Inject voltage to neurons.
inject_v = injects_v.get(l, None)
Expand Down
2 changes: 1 addition & 1 deletion bindsnet/network/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def forward(self, x: torch.Tensor) -> None:

# Integrate inputs.
x.masked_fill_(self.refrac_count > 0, 0.0)

# Decrement refractory counters.
self.refrac_count -= self.dt

Expand Down
8 changes: 4 additions & 4 deletions bindsnet/pipeline/environment_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def __init__(
if isinstance(layer, AbstractInput)
]

self.action = torch.tensor(-1)
self.last_action = torch.tensor(-1)
self.action = torch.tensor(-1, device=self.device)
self.last_action = torch.tensor(-1, device=self.device)
self.action_counter = 0
self.random_action_after = kwargs.get("random_action_after", self.time)

Expand Down Expand Up @@ -169,15 +169,15 @@ def env_step(self) -> Tuple[torch.Tensor, float, bool, Dict]:
self.last_action = self.action
if torch.rand(1) < self.percent_of_random_action:
self.action = torch.randint(
low=0, high=self.spike_record[self.output].shape[-1], size=(1,)
low=0, high=self.env.action_space.n, size=(1,)
)[0]
elif self.action_counter > self.random_action_after:
if self.last_action == 0: # last action was start b
self.action = 1 # next action will be fire b
tqdm.write(f"Fire -> too many times {self.last_action} ")
else:
self.action = torch.randint(
low=2, high=self.spike_record[self.output].shape[-1], size=(1,)
low=0, high=self.env.action_space.n, size=(1,)
)[0]
tqdm.write(f"too many times {self.last_action} ")
else:
Expand Down
69 changes: 47 additions & 22 deletions examples/mnist/SOM_LM-SNNs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
parser.add_argument("--n_neurons", type=int, default=100)
parser.add_argument("--n_epochs", type=int, default=1)
parser.add_argument("--n_test", type=int, default=10000)
parser.add_argument("--n_train", type=int, default=60000)
parser.add_argument("--n_workers", type=int, default=-1)
parser.add_argument("--theta_plus", type=float, default=0.05)
parser.add_argument("--time", type=int, default=250)
Expand All @@ -48,6 +49,7 @@
n_neurons = args.n_neurons
n_epochs = args.n_epochs
n_test = args.n_test
n_train = args.n_train
n_workers = args.n_workers
theta_plus = args.theta_plus
time = args.time
Expand All @@ -66,9 +68,9 @@
torch.cuda.manual_seed_all(seed)
else:
torch.manual_seed(seed)
device = "cpu"
if gpu:
gpu = False
device = 'cpu'

torch.set_num_threads(os.cpu_count() - 1)
print("Running on Device = ", device)
Expand Down Expand Up @@ -106,30 +108,34 @@
)

# Record spikes during the simulation.
spike_record = torch.zeros(update_interval, int(time/dt), n_neurons).cpu()
spike_record = torch.zeros((update_interval, int(time / dt), n_neurons), device=device)

# Neuron assignments and spike proportions.
n_classes = 10
assignments = -torch.ones(n_neurons).cpu()
proportions = torch.zeros(n_neurons, n_classes).cpu()
rates = torch.zeros(n_neurons, n_classes).cpu()
assignments = -torch.ones(n_neurons, device=device)
proportions = torch.zeros((n_neurons, n_classes), device=device)
rates = torch.zeros((n_neurons, n_classes), device=device)

# Sequence of accuracy estimates.
accuracy = {"all": [], "proportion": []}

# Voltage recording for excitatory and inhibitory layers.
som_voltage_monitor = Monitor(network.layers["Y"], ["v"], time=int(time/dt))
som_voltage_monitor = Monitor(network.layers["Y"], ["v"], time=int(time / dt))
network.add_monitor(som_voltage_monitor, name="som_voltage")

# Set up monitors for spikes and voltages
spikes = {}
for layer in set(network.layers):
spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=int(time/dt))
spikes[layer] = Monitor(
network.layers[layer], state_vars=["s"], time=int(time / dt)
)
network.add_monitor(spikes[layer], name="%s_spikes" % layer)

voltages = {}
for layer in set(network.layers) - {"X"}:
voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=int(time/dt))
voltages[layer] = Monitor(
network.layers[layer], state_vars=["v"], time=int(time / dt)
)
network.add_monitor(voltages[layer], name="%s_voltages" % layer)

inpt_ims, inpt_axes = None, None
Expand Down Expand Up @@ -166,9 +172,15 @@
dataset, batch_size=1, shuffle=True, num_workers=n_workers, pin_memory=gpu
)

for step, batch in enumerate(tqdm(dataloader)):
pbar = tqdm(total=n_train)
for step, batch in enumerate(dataloader):
if step == n_train:
break

# Get next input sample.
inputs = {"X": batch["encoded_image"].view(int(time/dt), 1, 1, 28, 28).to(device)}
inputs = {
"X": batch["encoded_image"].view(int(time / dt), 1, 1, 28, 28).to(device)
}

if step > 0:
if step % update_inhibation_weights == 0:
Expand All @@ -181,7 +193,7 @@

if step % update_interval == 0:
# Convert the array of labels into a tensor
label_tensor = torch.tensor(labels).cpu()
label_tensor = torch.tensor(labels, device=device)

# Get network predictions.
all_activity_pred = all_activity(
Expand Down Expand Up @@ -215,7 +227,8 @@
)
)
tqdm.write(
"Proportion weighting accuracy: %.2f (last), %.2f (average), %.2f (best)\n"
"Proportion weighting accuracy: %.2f (last), %.2f (average), %.2f"
" (best)\n"
% (
accuracy["proportion"][-1],
np.mean(accuracy["proportion"]),
Expand Down Expand Up @@ -247,10 +260,12 @@
if temp_spikes.sum().sum() < 2:
inputs["X"] *= (
poisson(
datum=factor * batch["image"].clamp(min=0), dt=dt, time=int(time/dt)
datum=factor * batch["image"].clamp(min=0),
dt=dt,
time=int(time / dt),
)
.to(device)
.view(int(time/dt), 1, 1, 28, 28)
.view(int(time / dt), 1, 1, 28, 28)
)
factor *= factor
else:
Expand Down Expand Up @@ -292,6 +307,8 @@
plt.pause(1e-8)

network.reset_state_variables() # Reset state variables.
pbar.set_description_str("Train progress: ")
pbar.update()

print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Training complete.\n")
Expand All @@ -313,16 +330,19 @@
accuracy = {"all": 0, "proportion": 0}

# Record spikes during the simulation.
spike_record = torch.zeros(1, int(time/dt), n_neurons)
spike_record = torch.zeros(1, int(time / dt), n_neurons)

# Train the network.
print("\nBegin testing\n")
network.train(mode=False)
start = t()

for step, batch in enumerate(tqdm(test_dataset)):
pbar = tqdm(total=n_test)
for step, batch in enumerate(test_dataset):
if step > n_test:
break
# Get next input sample.
inputs = {"X": batch["encoded_image"].view(int(time/dt), 1, 1, 28, 28)}
inputs = {"X": batch["encoded_image"].view(int(time / dt), 1, 1, 28, 28)}
if gpu:
inputs = {k: v.cuda() for k, v in inputs.items()}

Expand All @@ -333,7 +353,7 @@
spike_record[0] = spikes["Y"].get("s").squeeze()

# Convert the array of labels into a tensor
label_tensor = torch.tensor(batch["label"])
label_tensor = torch.tensor(batch["label"], device=device)

# Get network predictions.
all_activity_pred = all_activity(
Expand All @@ -348,13 +368,18 @@

# Compute network accuracy according to available classification strategies.
accuracy["all"] += float(torch.sum(label_tensor.long() == all_activity_pred).item())
accuracy["proportion"] += float(torch.sum(label_tensor.long() == proportion_pred).item())
accuracy["proportion"] += float(
torch.sum(label_tensor.long() == proportion_pred).item()
)

network.reset_state_variables() # Reset state variables.
pbar.set_description_str("Test progress: ")
pbar.update()


print("\nAll activity accuracy: %.2f" % (accuracy["all"] / test_dataset.test_labels.shape[0]))
print("Proportion weighting accuracy: %.2f \n" % ( accuracy["proportion"] / test_dataset.test_labels.shape[0]))
print("\nAll activity accuracy: %.2f" % (accuracy["all"] / n_test))
print("Proportion weighting accuracy: %.2f \n" % (accuracy["proportion"] / n_test))


print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Testing complete.\n")
print("Testing complete.\n")
Loading