diff --git a/snn/librispeech/datamodule.py b/snn/librispeech/datamodule.py index d7869c0..3a468a1 100644 --- a/snn/librispeech/datamodule.py +++ b/snn/librispeech/datamodule.py @@ -109,7 +109,8 @@ def setup(self, stage: Optional[str] = None): max_sample_length=self.max_sample_length) if self.num_train != 0: - self.training_sampler = torch.utils.data.RandomSampler(self.training_set, replacement=True, + self.training_sampler = torch.utils.data.RandomSampler(self.training_set, # type: ignore + replacement=True, num_samples=self.num_train) self.validation_set = PairDataset(val_dataset, max_sample_length=self.max_sample_length) diff --git a/snn/librispeech/loss/angularproto.py b/snn/librispeech/loss/angularproto.py index d353c1f..bd14465 100644 --- a/snn/librispeech/loss/angularproto.py +++ b/snn/librispeech/loss/angularproto.py @@ -6,7 +6,7 @@ class AngularPrototypicalLoss(nn.Module): - def __init__(self, init_scale=10.0, init_bias=-5.0): + def __init__(self, init_scale=10.0, init_bias=-5.0, **kwargs): super().__init__() self.w = nn.Parameter(torch.as_tensor(init_scale)) self.b = nn.Parameter(torch.as_tensor(init_bias)) diff --git a/snn/librispeech/model/snn_angularproto.py b/snn/librispeech/model/snn_angularproto.py index 8cddcb2..df87037 100644 --- a/snn/librispeech/model/snn_angularproto.py +++ b/snn/librispeech/model/snn_angularproto.py @@ -1,11 +1,10 @@ -from typing import Tuple, List +from typing import Tuple import torch import torch.nn.functional as F import pytorch_lightning as pl from .base import BaseNet from snn.librispeech.loss.angularproto import AngularPrototypicalLoss -from resnet.utils import accuracy class SNNAngularProto(BaseNet): @@ -22,7 +21,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore[override] def training_step(self, # type: ignore[override] batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: - support_sets, labels = batch + support_sets, _ = batch #print('support_sets', support_sets.shape) #print('labels', labels.shape) @@ -50,17 +49,6 @@ def training_step(self, # type: ignore[override] support.size(-1)) assert support.size(1) == self.num_ways - # supports = [] - # for shots in support_sets: - # s = [] - # for waveform in shots: - # x = self.spectogram_transform(waveform, augment=self.specaugment) - # print('ResNet (spectro)', x.shape) - # s.append(self.cnn(x)) - # supports.append(torch.cat(s, dim=0)) - - # support = torch.stack(supports, dim=0) - loss, acc = self.training_loss_fn(support) self.log('train_acc_step', acc, on_step=True, on_epoch=False)