Skip to content

Commit 555cab0

Browse files
Merge pull request #482 from BindsNET/accuracy_fix
fixes accuracies
2 parents ead5521 + a168592 commit 555cab0

File tree

3 files changed

+6
-12
lines changed

3 files changed

+6
-12
lines changed

bindsnet/network/topology.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
191191

192192
def compute_window(self, s: torch.Tensor) -> torch.Tensor:
193193
# language=rst
194-
""""""
194+
""" """
195195

196196
if self.s_w == None:
197197
# Construct a matrix of shape batch size * window size * dimension of layer

examples/mnist/SOM_LM-SNNs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@
341341

342342
pbar = tqdm(total=n_test)
343343
for step, batch in enumerate(test_dataset):
344-
if step > n_test:
344+
if step >= n_test:
345345
break
346346
# Get next input sample.
347347
inputs = {"X": batch["encoded_image"].view(int(time / dt), 1, 1, 28, 28)}

examples/mnist/eth_mnist.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,7 @@
1414
from bindsnet.models import DiehlAndCook2015
1515
from bindsnet.network.monitors import Monitor
1616
from bindsnet.utils import get_square_weights, get_square_assignments
17-
from bindsnet.evaluation import (
18-
all_activity,
19-
proportion_weighting,
20-
assign_labels,
21-
)
17+
from bindsnet.evaluation import all_activity, proportion_weighting, assign_labels
2218
from bindsnet.analysis.plotting import (
2319
plot_input,
2420
plot_spikes,
@@ -168,8 +164,8 @@
168164
# Train the network.
169165
print("\nBegin training.\n")
170166
start = t()
171-
labels = []
172167
for epoch in range(n_epochs):
168+
labels = []
173169

174170
if epoch % progress_interval == 0:
175171
print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start))
@@ -194,9 +190,7 @@
194190

195191
# Get network predictions.
196192
all_activity_pred = all_activity(
197-
spikes=spike_record,
198-
assignments=assignments,
199-
n_labels=n_classes,
193+
spikes=spike_record, assignments=assignments, n_labels=n_classes
200194
)
201195
proportion_pred = proportion_weighting(
202196
spikes=spike_record,
@@ -312,7 +306,7 @@
312306

313307
pbar = tqdm(total=n_test)
314308
for step, batch in enumerate(test_dataset):
315-
if step > n_test:
309+
if step >= n_test:
316310
break
317311
# Get next input sample.
318312
inputs = {"X": batch["encoded_image"].view(int(time / dt), 1, 1, 28, 28)}

0 commit comments

Comments
 (0)