Skip to content

Commit

Permalink
Misc cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
vjoki committed Jan 13, 2021
1 parent 06e775b commit 13fbf07
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 16 deletions.
3 changes: 2 additions & 1 deletion snn/librispeech/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def setup(self, stage: Optional[str] = None):
max_sample_length=self.max_sample_length)

if self.num_train != 0:
self.training_sampler = torch.utils.data.RandomSampler(self.training_set, replacement=True,
self.training_sampler = torch.utils.data.RandomSampler(self.training_set, # type: ignore
replacement=True,
num_samples=self.num_train)

self.validation_set = PairDataset(val_dataset, max_sample_length=self.max_sample_length)
Expand Down
2 changes: 1 addition & 1 deletion snn/librispeech/loss/angularproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class AngularPrototypicalLoss(nn.Module):
def __init__(self, init_scale=10.0, init_bias=-5.0):
def __init__(self, init_scale=10.0, init_bias=-5.0, **kwargs):
super().__init__()
self.w = nn.Parameter(torch.as_tensor(init_scale))
self.b = nn.Parameter(torch.as_tensor(init_bias))
Expand Down
16 changes: 2 additions & 14 deletions snn/librispeech/model/snn_angularproto.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Tuple, List
from typing import Tuple
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from .base import BaseNet
from snn.librispeech.loss.angularproto import AngularPrototypicalLoss
from resnet.utils import accuracy


class SNNAngularProto(BaseNet):
Expand All @@ -22,7 +21,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore[override]
def training_step(self, # type: ignore[override]
batch: Tuple[torch.Tensor, torch.Tensor],
batch_idx: int) -> torch.Tensor:
support_sets, labels = batch
support_sets, _ = batch
#print('support_sets', support_sets.shape)
#print('labels', labels.shape)

Expand Down Expand Up @@ -50,17 +49,6 @@ def training_step(self, # type: ignore[override]
support.size(-1))
assert support.size(1) == self.num_ways

# supports = []
# for shots in support_sets:
# s = []
# for waveform in shots:
# x = self.spectogram_transform(waveform, augment=self.specaugment)
# print('ResNet (spectro)', x.shape)
# s.append(self.cnn(x))
# supports.append(torch.cat(s, dim=0))

# support = torch.stack(supports, dim=0)

loss, acc = self.training_loss_fn(support)

self.log('train_acc_step', acc, on_step=True, on_epoch=False)
Expand Down

0 comments on commit 13fbf07

Please sign in to comment.