Adapt transpose to correspond to numpy default
TCord authored and Alexander Werning committed Feb 12, 2024
1 parent c96cfb8 commit 267f0e7
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions padertorch/contrib/tcl/speaker_embeddings/loss.py
@@ -84,13 +84,13 @@ def forward(self, embeddings, labels):
         logits = self.fc(logits)
 
         if self.loss_type == 'aam':
-            numerator = self.s * (torch.diagonal(logits.transpose(0, 1)[labels]) - self.m)
+            numerator = self.s * (torch.diagonal(logits.transpose(1, 0)[labels]) - self.m)
         elif self.loss_type == 'arcface':
             numerator = self.s * torch.cos(torch.acos(
-                torch.clamp(torch.diagonal(logits.transpose(0, 1)[labels]), -1. + self.eps, 1 - self.eps)) + self.m)
+                torch.clamp(torch.diagonal(logits.transpose(1, 0)[labels]), -1. + self.eps, 1 - self.eps)) + self.m)
         elif self.loss_type == 'sphereface':
             numerator = self.s * torch.cos(self.m * torch.acos(
-                torch.clamp(torch.diagonal(logits.transpose(0, 1)[labels]), -1. + self.eps, 1 - self.eps)))
+                torch.clamp(torch.diagonal(logits.transpose(1, 0)[labels]), -1. + self.eps, 1 - self.eps)))
         else:
             return NotImplementedError
         excl = torch.cat([torch.cat((logits[i, :y], logits[i, y + 1:])).unsqueeze(0) for i, y in enumerate(labels)],
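Note on the change: for a 2-D logits tensor, transpose(0, 1) and transpose(1, 0) swap the same pair of dimensions and return identical results, so the commit only makes the argument order read like numpy's default axis reversal (np.transpose with no axes argument). A minimal sketch of that equivalence; the batch of 4 utterances and 10 speaker classes below is an illustrative assumption, not a value from the repository:

import numpy as np
import torch

# Illustrative shape only: 4 utterances, 10 speaker classes.
logits = torch.randn(4, 10)

# For 2-D tensors both argument orders swap the same two dimensions.
assert torch.equal(logits.transpose(0, 1), logits.transpose(1, 0))

# Both also match numpy's default transpose, which reverses all axes.
assert np.array_equal(logits.transpose(1, 0).numpy(), np.transpose(logits.numpy()))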
