[ssl/bestrq] happy ending---stable training #2614

Merged 2 commits on Aug 16, 2024
48 changes: 27 additions & 21 deletions wenet/ssl/bestrq/bestrq_model.py
@@ -3,7 +3,7 @@
 import torch
 
 from wenet.ssl.bestrq.mask import compute_mask_indices_v2
-from wenet.utils.mask import make_pad_mask
+from wenet.utils.mask import make_non_pad_mask, make_pad_mask
 from wenet.transformer.attention import RelPositionMultiHeadedAttention
 from wenet.transformer.encoder_layer import ConformerEncoderLayer
 
@@ -96,12 +96,7 @@ def __init__(
         # stack input: eg: fbank
         self.stack_frames = self.encoder.embed.right_context + 1
         self.stride = self.encoder.embed.subsampling_rate
-        input_dim = num_mel_bins * self.stack_frames
-
-        # norm input
-        self.norm = torch.nn.LayerNorm(
-            input_dim, eps=norm_epsilon, elementwise_affine=False
-        ) if self.stack_frames > 1 else torch.nn.Identity()
+        input_dim = num_mel_bins * self.stride
 
         # random projectoin
         self.projection = torch.nn.parameter.Parameter(
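With `_stack_features` now unfolding with `size=self.stride` (see the hunk further down), each stacked step carries `stride` consecutive fbank frames, so the projection input width becomes `num_mel_bins * self.stride` instead of `num_mel_bins * self.stack_frames`. A minimal sketch of that shape arithmetic, with hypothetical dimensions that are not taken from this PR:

```python
import torch

# Hypothetical sizes, only to illustrate the shape arithmetic.
num_mel_bins, stride = 80, 4
fbank = torch.randn(2, 100, num_mel_bins)            # (batch, frames, mels)

stacked = fbank.unfold(1, size=stride, step=stride)  # (B, T//stride, mels, stride)
stacked = stacked.transpose(-1, -2)                  # (B, T//stride, stride, mels)
b, n, f, d = stacked.size()
stacked = stacked.reshape(b, n, f * d)               # (B, T//stride, stride * mels)

assert stacked.size(-1) == num_mel_bins * stride     # 320, i.e. the new input_dim
```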
@@ -177,11 +172,12 @@ def forward(
         xs, code_ids_mask = self._apply_mask_signal(xs, xs_lens)
 
         # 2.0 stack fbank
-        unmasked_xs = self._stack_features(input)
+        unmasked_xs = self._stack_features(input, xs_lens)
         masked_xs = xs
 
         # 2.1 get nearest embedding
         target_ids = self._nearest_embedding_idx(unmasked_xs)
+        target_ids = target_ids[:, :code_ids_mask.size(1), :]
 
         # 3 forward xxx-formaer block and its subsampling layer
         out, out_mask = self.encoder(masked_xs, xs_lens)
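The new `target_ids = target_ids[:, :code_ids_mask.size(1), :]` line guards against a length mismatch: the number of stacked steps produced by `unfold` and the number of frames the encoder's subsampling emits can differ at the sequence tail, so the quantizer targets are clipped to the mask length before the loss. A toy illustration with assumed lengths (the exact off-by-one depends on the subsampling module):

```python
import torch

# Assumed sizes for illustration only.
B, num_codebooks, codebook_size = 2, 8, 8192
num_stacked = 25      # steps produced by unfold on the fbank
num_subsampled = 24   # frames the encoder/subsampling happens to emit

target_ids = torch.randint(0, codebook_size, (B, num_stacked, num_codebooks))
code_ids_mask = torch.ones(B, num_subsampled, dtype=torch.bool)

# Clip targets so they line up with the subsampled mask / encoder output.
target_ids = target_ids[:, :code_ids_mask.size(1), :]
assert target_ids.size(1) == code_ids_mask.size(1)
```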
@@ -258,30 +254,40 @@ def _apply_mask_signal(
         xs = torch.where(masks_expand, mask_emb, input)
         return xs, subsampling_mask
 
-    def _stack_features(self, input: torch.Tensor) -> torch.Tensor:
+    def _stack_features(self, input: torch.Tensor,
+                        input_lens: torch.Tensor) -> torch.Tensor:
 
-        stack_input = input.unfold(1, size=self.stack_frames, step=self.stride)
+        stack_input = input.unfold(1, size=self.stride, step=self.stride)
         stack_input = stack_input.transpose(-1, -2)
         b, n, f, d = stack_input.size()
         stack_input = stack_input.reshape(b, n, f * d)
 
-        return stack_input
+        # NOTE(Mddct): important!!!
+        # norm stack features
+        mask = make_non_pad_mask(input_lens)
+        stack_mask = mask.unfold(1, size=self.stride, step=self.stride)
+        stack_mask, _ = torch.min(stack_mask, dim=-1)
+
+        stack_input = stack_input * stack_mask.unsqueeze(2)
+        mean = stack_input.sum(1, keepdim=True) / stack_mask.sum(
+            dim=1, keepdim=True).unsqueeze(1)
+        std = torch.sqrt(((stack_input - mean)**2).sum(dim=1, keepdim=True) /
+                         stack_mask.sum(dim=1, keepdim=True).unsqueeze(1))
+        norm_stack_input = (stack_input - mean) / (std + 1e-5)
+        return norm_stack_input
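The stability fix itself: the stacked features are normalized per utterance with padding excluded from the statistics, replacing the old `LayerNorm` on the stacked input. A self-contained sketch of that masked normalization under assumed shapes (not a line-for-line copy of the diff):

```python
import torch

# Assumed shapes: 2 utterances, 6 stacked steps, 4 feature dims.
B, T, D = 2, 6, 4
x = torch.randn(B, T, D)
lengths = torch.tensor([6, 4])                        # assumed valid lengths
mask = torch.arange(T)[None, :] < lengths[:, None]    # (B, T), True on real steps

x = x * mask.unsqueeze(-1)                            # zero out padded steps
denom = mask.sum(dim=1, keepdim=True).unsqueeze(-1)   # (B, 1, 1) valid-step counts
mean = x.sum(dim=1, keepdim=True) / denom             # per-utterance, per-dim mean
var = (((x - mean) ** 2) * mask.unsqueeze(-1)).sum(dim=1, keepdim=True) / denom
x_norm = (x - mean) / (var.sqrt() + 1e-5)             # normalized stacked features
```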

     def _compute_loss(self, input: torch.Tensor, target: torch.Tensor,
                       mask: torch.Tensor) -> torch.Tensor:
-        log_probs = torch.log_softmax(input, dim=-1).transpose(
-            1, 2)  # [B, T', num_codebooks, num_embeddings]
-
-        per_example_n_loss = -log_probs.gather(3, target.unsqueeze(3)).squeeze(
-            3)  # [B, T', num_codebooks]
-
-        numerator = torch.sum(per_example_n_loss * mask.unsqueeze(2))
-        denominator = torch.sum(mask) + 1e-5
-        loss = numerator / (denominator * self.num_codebooks)
+        logits = input.transpose(1, 2).contiguous().view(-1, input.size(-1))
+        loss = torch.nn.functional.cross_entropy(
+            logits,
+            target.contiguous().view(-1),
+            reduction='none',
+        )
+        loss = (loss * mask.view(-1)).sum() / mask.sum()
         return loss
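`_compute_loss` now reduces to a standard masked cross-entropy over flattened logits, essentially equivalent to the old gather-based negative log-likelihood but using the fused `cross_entropy` kernel. A small sketch of the pattern with assumed sizes:

```python
import torch
import torch.nn.functional as F

# Assumed sizes: batch, frames, codebooks, codebook entries.
B, T, C, V = 2, 10, 8, 8192
logits = torch.randn(B, T, C, V)           # per-frame, per-codebook logits
targets = torch.randint(0, V, (B, T, C))
mask = (torch.rand(B, T, C) > 0.2).float() # 1.0 where the loss should count (e.g. masked positions)

loss = F.cross_entropy(
    logits.reshape(-1, V),                 # (B*T*C, V)
    targets.reshape(-1),                   # (B*T*C,)
    reduction='none',
)
loss = (loss * mask.reshape(-1)).sum() / mask.sum()
```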

     def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor:
-        xs = self.norm(xs)
         xs = torch.matmul(xs, self.projection.to(xs.device))
         xs = xs / (xs.norm(dim=-1, p=2, keepdim=True) + 1e-8)
         codebooks = self.embeddings
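For reference, `_nearest_embedding_idx` implements the BEST-RQ style random-projection quantizer: project the stacked features through a frozen random matrix, L2-normalize, and pick the closest entry in each frozen codebook. The tail of the function is cut off in this hunk, so the distance computation below is an illustrative sketch under assumed shapes, not necessarily the exact code:

```python
import torch

# Assumed sizes: feature dim D, projection dim P, C codebooks of V entries each.
B, T, D, P, C, V = 2, 25, 320, 16, 8, 8192
xs = torch.randn(B, T, D)                     # stacked (and normalized) features
projection = torch.randn(D, P)                # frozen random projection
codebooks = torch.randn(C, V, P)              # frozen random codebooks

xs = torch.matmul(xs, projection)                               # (B, T, P)
xs = xs / (xs.norm(dim=-1, p=2, keepdim=True) + 1e-8)           # unit-length features
cb = codebooks / (codebooks.norm(dim=-1, p=2, keepdim=True) + 1e-8)

# With unit vectors, the nearest codebook entry is the most similar one.
sim = torch.einsum('btp,cvp->btcv', xs, cb)                     # cosine similarity
target_ids = sim.argmax(dim=-1)                                 # (B, T, C) code indices
```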