
Spiking RNN bug #101

Closed
1y9y9l4 opened this issue Aug 25, 2021 · 6 comments
Assignees
Labels
bug Something isn't working

Comments

@1y9y9l4
1y9y9l4 commented Aug 25, 2021

There is a bug in 'spikingjelly/clock_driven/rnn.py' at lines 471 and 482.
The RNN cell in PyTorch will return the 'output' and hidden state 'h_n', see:
https://pytorch.org/docs/stable/generated/torch.nn.GRU.html?highlight=gru#torch.nn.GRU

I think lines 471 and 482 should be modified to:
new_states_list[0], _ = self.cells[0](x[t], states_list[0])
new_states_list[i], _ = self.cells[i](y, states_list[i])
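The unpacking proposed above matters because of the cell's return type. A minimal, torch-free sketch (the class names here are hypothetical stand-ins, not the real SpikingJelly classes) of what goes wrong when a cell that returns a tuple is assigned directly:

```python
class VanillaStyleCell:
    """Stand-in for a cell whose forward returns a single hidden state."""
    def __call__(self, x, h):
        return h + x  # placeholder for the real state update

class GRUStyleCell:
    """Stand-in for a cell that, like torch.nn.GRU, returns (output, hidden)."""
    def __call__(self, x, h):
        new_h = h + x
        return new_h, new_h  # output and hidden coincide for this sketch

x, h = 1, 10

# For the single-state cell, plain assignment stores the new state:
new_h = VanillaStyleCell()(x, h)        # 11, a state

# For the tuple-returning cell, the same assignment silently stores
# a tuple where a state is expected; unpacking fixes it:
wrong = GRUStyleCell()(x, h)            # (11, 11), a tuple, not a state
new_h, _ = GRUStyleCell()(x, h)         # 11, the hidden state
```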

@fangwei123456 fangwei123456 self-assigned this Aug 26, 2021
@fangwei123456 fangwei123456 added the bug Something isn't working label Aug 26, 2021
@fangwei123456
Owner

Thanks for your debugging!

states_num() returns the number of hidden states. For example, it returns 1 for SpikingVanillaRNN:

Its forward function therefore returns a single tensor rather than a tuple:

return self.surrogate_function(self.linear_ih(x) + self.linear_hh(h))

So line 471,

new_states_list[0] = self.cells[0](x[t], states_list[0])

is guarded so that the direct assignment only happens when the cell returns a single tensor:

if self.states_num() == 1:
    new_states_list[0] = self.cells[0](x[t], states_list[0])

@fangwei123456
Owner

When states_num() > 1, the forward function returns a tuple instead.

The first dimension (shape[0]) of states_list is states_num():

states_list = torch.zeros(size=[self.states_num(), self.num_layers, batch_size, self.hidden_size]).to(

Then the tuple is concatenated by torch.stack and converted back into a tensor:

new_states_list[:, i] = torch.stack(self.cells[i](y, states_list[:, i]))
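The two branches described above can be sketched without torch (these are hypothetical minimal classes, not the real code in spikingjelly/clock_driven/rnn.py; a Python list stands in for torch.stack):

```python
class OneStateCell:
    """states_num() == 1: forward returns a single state."""
    def states_num(self):
        return 1
    def __call__(self, x, state):
        return state + x  # placeholder for the real update

class TwoStateCell:
    """states_num() == 2 (e.g. LSTM-style h and c): forward returns a tuple."""
    def states_num(self):
        return 2
    def __call__(self, x, states):
        h, c = states
        return h + x, c + x

def step(cell, x, states):
    # Mirrors the dispatch around line 471: a single state is assigned
    # directly; multiple states are gathered back into one container
    # (torch.stack in the real code, a plain list in this sketch).
    if cell.states_num() == 1:
        return cell(x, states)
    return list(cell(x, states))

print(step(OneStateCell(), 1, 10))        # 11
print(step(TwoStateCell(), 1, [10, 20]))  # [11, 21]
```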

@1y9y9l4
Author

1y9y9l4 commented Aug 26, 2021

I understand what you mean, but this bug occurs in the GRU layer. Using 'pytorch==1.6.0' and 'spikingjelly==0.0.0.4', the following code reports an error:
import torch
from spikingjelly.clock_driven import rnn
rsnn = rnn.SpikingGRU(4, 64, 1)
x = torch.zeros(10, 32, 4)
y = rsnn(x)

Another thing that puzzles me is that, in any case, the RNN always generates LSTM cells (lines 306 & 308):

cells.append(SpikingLSTMCell(self.input_size, self.hidden_size, self.bias, *args, **kwargs))

Should the correct code be self.base_cell() (line 297) instead of SpikingLSTMCell?
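A torch-free sketch of the suggested fix (all class names here are hypothetical stand-ins for illustration): the multi-layer container should instantiate whichever cell class the subclass declares via base_cell(), not a hard-coded LSTM cell.

```python
class LSTMCell:
    pass

class GRUCell:
    pass

class RNNBase:
    def __init__(self, num_layers):
        # Bug pattern: cells.append(LSTMCell(...)) regardless of subclass.
        # Fix: ask the subclass which cell class to construct.
        self.cells = [self.base_cell()() for _ in range(num_layers)]

    @staticmethod
    def base_cell():
        raise NotImplementedError

class SpikingGRUSketch(RNNBase):
    @staticmethod
    def base_cell():
        return GRUCell

rsnn = SpikingGRUSketch(num_layers=2)
print(all(isinstance(c, GRUCell) for c in rsnn.cells))  # True
```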

@fangwei123456
Owner

fangwei123456 commented Aug 26, 2021

Should the correct code be self.base_cell() (line 297) instead of SpikingLSTMCell?

Yes, I will fix it.

@fangwei123456
Owner

I understand what you mean, but this bug occurs in the GRU layer. Using 'pytorch==1.6.0' and 'spikingjelly==0.0.0.4', the following code reports an error:
import torch
from spikingjelly.clock_driven import rnn
rsnn = rnn.SpikingGRU(4, 64, 1)
x = torch.zeros(10, 32, 4)
y = rsnn(x)

This problem is caused by

cells.append(SpikingLSTMCell(self.input_size, self.hidden_size, self.bias, *args, **kwargs))

The code now runs correctly.

@1y9y9l4
Author

1y9y9l4 commented Aug 26, 2021

Thank you very much. It solved my problem perfectly.

@1y9y9l4 1y9y9l4 closed this as completed Aug 26, 2021