FastCell Example Fixes, Generalized trainer for both batch_first args #174
SachinG007 wants to merge 7 commits into microsoft:master from SachinG007:master
Conversation
adityakusupati
left a comment
PR #173 supersedes this PR.
@oindrilasaha, it would be good if you could have a look at this PR.
@SachinG007 please incorporate the changes from #173 that fix the optimizer for FC. Earlier, gradient updates were not happening for FC due to some convention mismatch. Please look at this: https://github.com/microsoft/EdgeML/pull/173/files#diff-7b39dde7dda6360cbf530db88f5b9f8dR12-R62 and incorporate it. Also, try to keep PRs for different things separate. When we are fixing fastcell_Example.py, let us stick to that. Thanks.
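One generic way to confirm that gradient updates actually reach the FC layer is to compare its parameters before and after a single optimizer step. The sketch below is only a sanity check under stated assumptions, not the fix from #173: `params_changed` is a hypothetical helper, and an `nn.Linear` head with SGD stands in for the FC classifier and the trainer's optimizer.

```python
import torch
import torch.nn as nn


def params_changed(module, optimizer, loss_fn, inputs, targets):
    """Run one optimizer step and return the names of parameters that changed.

    Hypothetical helper for illustration; not part of EdgeML.
    """
    # Snapshot every registered parameter before the update.
    before = {n: p.detach().clone() for n, p in module.named_parameters()}

    optimizer.zero_grad()
    loss = loss_fn(module(inputs), targets)
    loss.backward()
    optimizer.step()

    # A parameter that the optimizer does not track stays identical.
    return [n for n, p in module.named_parameters()
            if not torch.equal(before[n], p.detach())]


torch.manual_seed(0)
head = nn.Linear(8, 3)  # assumed stand-in for the FC classifier
opt = torch.optim.SGD(head.parameters(), lr=0.1)
x = torch.randn(4, 8)
y = torch.tensor([0, 1, 2, 1])

changed = params_changed(head, opt, nn.CrossEntropyLoss(), x, y)
print(changed)
```

If a parameter name is missing from the returned list, it is either not passed to the optimizer or not receiving a gradient, which is the kind of symptom described above.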
@adityakusupati, incorporated the changes from PR #173.
adityakusupati
left a comment
I have reviewed everything except the bidirectional stuff. @oindrilasaha has to review the bi-directional stuff.
class SimpleFC(nn.Module):
    def __init__(self, input_size, num_classes, name="SimpleFC"):
        super(SimpleFC, self).__init__()
        self.FC = nn.Parameter(torch.randn([input_size, num_classes]))
        self.FCbias = nn.Parameter(torch.randn([num_classes]))

    def forward(self, input):
        return torch.matmul(input, self.FC) + self.FCbias
I am not sure if we should place this here or make a different file or place in rnn.py. Ideally, this is the same as Bonsai.py or ProtoNN.py. We need to make a decision about placing this at the right place.
self.FCbias = nn.Parameter(torch.randn(
    [self.numClasses])).to(self.device)

self.simpleFC = SimpleFC(self.FastObj.output_size, self.numClasses).to(self.device)
Use a better name for the instance instead of self.simpleFC, so that we are generic enough to fit the classifier() function.
TODO: Make this a separate class if needed
'''
return torch.matmul(feats, self.FC) + self.FCbias
return self.simpleFC(feats)
Look at the above comment.
pytorch/edgeml_pytorch/graph/rnn.py
Outdated
hiddenStates = torch.zeros(
    [input.shape[0], input.shape[1],
     self._RNNCell.output_size]).to(self.device)
     self.RNNCell.output_size]).to(self.device)
All the lines after this are about bi-directional stuff. I would defer to @oindrilasaha to check all this stuff. She has to sign off on it as she uses them the most.
@oindrilasaha any comments on the later section of the code?
self.FC = nn.Parameter(torch.randn([input_size, num_classes]))
self.FCbias = nn.Parameter(torch.randn([num_classes]))
@SachinG007 change this to self.weight and self.bias.
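A minimal sketch of what the rename could look like, assuming nothing else about the class changes. It reuses the `SimpleFC` definition from the diff with only the attribute names swapped; the sizes 16 and 4 are arbitrary illustration values, not taken from the PR.

```python
import torch
import torch.nn as nn


class SimpleFC(nn.Module):
    def __init__(self, input_size, num_classes, name="SimpleFC"):
        super(SimpleFC, self).__init__()
        # Renamed from self.FC / self.FCbias to follow the usual
        # nn.Linear-style naming.
        self.weight = nn.Parameter(torch.randn([input_size, num_classes]))
        self.bias = nn.Parameter(torch.randn([num_classes]))

    def forward(self, input):
        return torch.matmul(input, self.weight) + self.bias


fc = SimpleFC(16, 4)  # arbitrary sizes for illustration
names = [n for n, _ in fc.named_parameters()]
out = fc(torch.randn(2, 16))
print(names, tuple(out.shape))
```

Because both tensors are `nn.Parameter` attributes of an `nn.Module`, they show up in `named_parameters()` under the new names, so an optimizer built from `fc.parameters()` picks them up without extra wiring.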
@SachinG007 any updates on this PR?
Checked the changes with Aditya