diff --git a/wenet/ssl/nestrq/nestrq_model.py b/wenet/ssl/nestrq/nestrq_model.py
index 8d7d06a0d..09f05fbdb 100644
--- a/wenet/ssl/nestrq/nestrq_model.py
+++ b/wenet/ssl/nestrq/nestrq_model.py
@@ -5,6 +5,7 @@
 from wenet.transformer.attention import RelPositionMultiHeadedAttention
 from wenet.transformer.encoder_layer import ConformerEncoderLayer
+from wenet.utils.mask import make_non_pad_mask


 class NestRQModel(torch.nn.Module):
@@ -18,6 +19,7 @@ def __init__(
         embedding_dim: int = 16,
         num_embeddings: int = 8192,
         num_codebooks: int = 1,
+        n_subsequent: int = 1,
         out_bias: bool = False,
     ) -> None:
         super().__init__()
@@ -28,13 +30,13 @@ def __init__(
         self.encoder = encoder
         # n softmax
         self.encoder_top_n_out = torch.nn.parameter.Parameter(
-            torch.empty(self.num_codebooks, self.encoder.output_size(),
-                        num_embeddings))
+            torch.empty(n_subsequent, self.num_codebooks,
+                        self.encoder.output_size(), num_embeddings))
         torch.nn.init.trunc_normal_(self.encoder_top_n_out, std=0.02)
         self.out_bias = out_bias
         if self.out_bias:
             self.encoder_top_n_out_bias = torch.nn.parameter.Parameter(
-                torch.empty(self.num_codebooks, num_embeddings))
+                torch.empty(n_subsequent, self.num_codebooks, num_embeddings))
             torch.nn.init.zeros_(self.encoder_top_n_out_bias)

         # stack input: eg: fbank
@@ -52,6 +54,8 @@ def __init__(
                                                  eps=1e-6,
                                                  elementwise_affine=False,
                                                  bias=False)
+        # Section: 1B
+        self.n_subsequent = n_subsequent

         # codebook
         # [num_embeddings, num_codebooks, num_embeddings] means
@@ -114,28 +118,31 @@ def forward(
         # 1 stack fbank, out_mask is for compute loss (NPT)
         stack_input, stack_out_mask = self._stack_features(input, xs_lens)

-        masked_xs = xs
         # 2 get nearest embedding
         target_ids = self._nearest_embedding_idx(stack_input)
-        target_ids = target_ids[:, :out_mask.size(1), :]
+        target_ids = target_ids[:, :stack_out_mask.size(1), :]
+        target_ids = target_ids.unfold(1, size=self.n_subsequent,
+                                       step=1).transpose(-1,
+                                                         -2)  # (B,T,-1, vocab)

         # 3 forward xxx-formaer block and its subsampling layer
         # TODO(mddct): encoder causal mask
-        out, out_mask = self.encoder(masked_xs, xs_lens)
+        out, out_mask = self.encoder(xs, xs_lens)

         # 4 get logits
-        out = out.unsqueeze(1)  # [B, 1, T', dim]
+        out = out.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T', dim]
         top_n_out = self.encoder_top_n_out.unsqueeze(
-            0)  # [1, num_codebooks, dim, num_embeddings]
-        out = torch.matmul(out,
-                           top_n_out)  # [B, num_codebooks, T', num_embeddings]
+            0)  # [1, n_subsequent, num_codebooks, dim, num_embeddings]
+        out = torch.matmul(
+            out,
+            top_n_out)  # [B, n_subsequent, num_codebooks, T', num_embeddings]
         if self.out_bias:
-            out = out + self.encoder_top_n_out_bias.unsqueeze(0).unsqueeze(2)
+            out = out + self.encoder_top_n_out_bias.unsqueeze(0).unsqueeze(3)

         # shift input and target for next token prediction
-        out = out[:, :, :-1:]
-        target_ids = target_ids[:, 1:, :]
+        out = out[:, :, :, :target_ids.size(1), :]
+        target_ids = target_ids[:, 1:, :, :]

         masks = out_mask.squeeze(1) * stack_out_mask
         masks = masks[:, 1:]

@@ -160,7 +167,6 @@ def forward(
     def _stack_features(
             self, input: torch.Tensor,
             input_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-
         mask = make_non_pad_mask(input_lens)
         mask_stride = mask.unfold(
             1,
@@ -178,8 +184,10 @@ def _stack_features(

     def _compute_loss(self, input: torch.Tensor, target: torch.Tensor,
                       mask: torch.Tensor) -> torch.Tensor:
-        logits = input.transpose(1, 2).contiguous().view(-1, input.size(-1))
-        mask = mask.unsqueeze(2).repeat(1, 1, self.num_codebooks)
+        logits = input.contiguous().permute(
+            (0, 3, 1, 2, 4)).view(-1, input.size(-1))
+        mask = mask.unsqueeze(2).unsqueeze(2).repeat(1, 1, self.n_subsequent,
+                                                     self.num_codebooks)
         loss = torch.nn.functional.cross_entropy(
             logits,
             target.contiguous().view(-1),
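
Note (not part of the patch): a minimal standalone shape sketch, assuming toy sizes, of how the added n_subsequent dimension flows through the top-n projection and the unfolded targets introduced above.

    import torch

    # toy sizes, chosen only for illustration
    B, T, dim = 2, 7, 4
    n_subsequent, num_codebooks, num_embeddings = 3, 1, 11

    out = torch.randn(B, T, dim)  # encoder output
    top_n_out = torch.randn(n_subsequent, num_codebooks, dim, num_embeddings)

    # [B, 1, 1, T, dim] x [1, n_subsequent, num_codebooks, dim, num_embeddings]
    logits = torch.matmul(out.unsqueeze(1).unsqueeze(1), top_n_out.unsqueeze(0))
    print(logits.shape)  # torch.Size([2, 3, 1, 7, 11])

    # one codebook id per frame, then unfold so position t carries the ids of the
    # next n_subsequent frames: (B, T - n_subsequent + 1, n_subsequent, num_codebooks)
    target_ids = torch.randint(num_embeddings, (B, T, num_codebooks))
    target_ids = target_ids.unfold(1, size=n_subsequent, step=1).transpose(-1, -2)
    print(target_ids.shape)  # torch.Size([2, 5, 3, 1])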