Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions examples/mnist/reservoir.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
torch.set_num_threads(os.cpu_count() - 1)
print("Running on Device = ", device)

# Create simple Torch NN
network = Network(dt=dt)
inpt = Input(784, shape=(1, 28, 28))
network.add_layer(inpt, name="I")
Expand All @@ -84,6 +85,7 @@
network.add_connection(C1, source="I", target="O")
network.add_connection(C2, source="O", target="O")

# Monitors for visualizing activity
spikes = {}
for l in network.layers:
spikes[l] = Monitor(network.layers[l], ["s"], time=time, device=device)
Expand All @@ -101,7 +103,7 @@
dataset = MNIST(
PoissonEncoder(time=time, dt=dt),
None,
root=os.path.join("..", "..", "data", "MNIST"),
root=os.path.join("..", "data", "MNIST"),
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
Expand All @@ -123,19 +125,26 @@
)

# Run training data on reservoir computer and store (spikes per neuron, label) per example.
# Note: Because this is a reservoir network, no adjustments of neuron parameters occurs in this phase.
n_iters = examples
training_pairs = []
pbar = tqdm(enumerate(dataloader))
for (i, dataPoint) in pbar:
if i > n_iters:
break

# Extract & resize the MNIST samples image data for training
# int(time / dt) -> length of spike train
# 28 x 28 -> size of sample
datum = dataPoint["encoded_image"].view(int(time / dt), 1, 1, 28, 28).to(device)
label = dataPoint["label"]
pbar.set_description_str("Train progress: (%d / %d)" % (i, n_iters))

# Run network on sample image
network.run(inputs={"I": datum}, time=time, input_time_dim=1)
training_pairs.append([spikes["O"].get("s").sum(0), label])

# Plot spiking activity using monitors
if plot:

inpt_axes, inpt_ims = plot_input(
Expand Down Expand Up @@ -165,6 +174,7 @@


# Define logistic regression model using PyTorch.
# These neurons will take the reservoirs output as its input, and be trained to classify the images.
class NN(nn.Module):
def __init__(self, input_size, num_classes):
super(NN, self).__init__()
Expand All @@ -189,14 +199,26 @@ def forward(self, x):
pbar = tqdm(enumerate(range(n_epochs)))
for epoch, _ in pbar:
avg_loss = 0

# Extract spike outputs from reservoir for a training sample
# i -> Loop index
# s -> Reservoir output spikes
# l -> Image label
for i, (s, l) in enumerate(training_pairs):
# Forward + Backward + Optimize

# Reset gradients to 0
optimizer.zero_grad()

# Run spikes through logistic regression model
outputs = model(s)

# Calculate MSE
label = torch.zeros(1, 1, 10).float().to(device)
label[0, 0, l] = 1.0
loss = criterion(outputs.view(1, 1, -1), label)
avg_loss += loss.data

# Optimize parameters
loss.backward()
optimizer.step()

Expand All @@ -205,17 +227,19 @@ def forward(self, x):
% (epoch + 1, n_epochs, avg_loss / len(training_pairs))
)

# Run same simulation on reservoir with testing data instead of training data
# (see training section for intuition)
n_iters = examples
test_pairs = []
pbar = tqdm(enumerate(dataloader))
for (i, dataPoint) in pbar:
if i > n_iters:
break
datum = dataPoint["encoded_image"].view(time, 1, 1, 28, 28).to(device)
datum = dataPoint["encoded_image"].view(int(time / dt), 1, 1, 28, 28).to(device)
label = dataPoint["label"]
pbar.set_description_str("Testing progress: (%d / %d)" % (i, n_iters))

network.run(inputs={"I": datum}, time=250, input_time_dim=1)
network.run(inputs={"I": datum}, time=time, input_time_dim=1)
test_pairs.append([spikes["O"].get("s").sum(0), label])

if plot:
Expand All @@ -227,12 +251,12 @@ def forward(self, x):
ims=inpt_ims,
)
spike_ims, spike_axes = plot_spikes(
{layer: spikes[layer].get("s").view(-1, 250) for layer in spikes},
{layer: spikes[layer].get("s").view(time, -1) for layer in spikes},
axes=spike_axes,
ims=spike_ims,
)
voltage_ims, voltage_axes = plot_voltages(
{layer: voltages[layer].get("v").view(-1, 250) for layer in voltages},
{layer: voltages[layer].get("v").view(time, -1) for layer in voltages},
ims=voltage_ims,
axes=voltage_axes,
)
Expand All @@ -244,7 +268,7 @@ def forward(self, x):
plt.pause(1e-8)
network.reset_state_variables()

# Test the Model
# Test model with previously trained logistic regression classifier
correct, total = 0, 0
for s, label in test_pairs:
outputs = model(s)
Expand Down