[ssl] add nestrq #2632

Open
wants to merge 4 commits into main
207 changes: 207 additions & 0 deletions wenet/ssl/nestrq/nestrq_model.py
@@ -0,0 +1,207 @@
import math
from typing import Dict, Tuple

import torch

from wenet.ssl.bestrq.bestrq_model import quantize_vector
from wenet.transformer.attention import RelPositionMultiHeadedAttention
from wenet.transformer.encoder_layer import ConformerEncoderLayer
from wenet.utils.mask import make_non_pad_mask

class NestRQModel(torch.nn.Module):
    """NEST-RQ: next-token-prediction speech SSL with a frozen
    random-projection quantizer. See https://arxiv.org/pdf/2409.08680
    """

    def __init__(
        self,
        encoder: torch.nn.Module,
        num_mel_bins: int = 80,
        embedding_dim: int = 16,
        num_embeddings: int = 8192,
        num_codebooks: int = 1,
        n_subsequent: int = 1,
        out_bias: bool = False,
    ) -> None:
        super().__init__()
        self.num_codebooks = num_codebooks
        self.num_embeddings = num_embeddings

        # encoder
        self.encoder = encoder
        # one classification head per (subsequent step, codebook) pair
        self.encoder_top_n_out = torch.nn.parameter.Parameter(
            torch.empty(n_subsequent, self.num_codebooks,
                        self.encoder.output_size(), num_embeddings))
        torch.nn.init.trunc_normal_(self.encoder_top_n_out, std=0.02)
        self.out_bias = out_bias
        if self.out_bias:
            self.encoder_top_n_out_bias = torch.nn.parameter.Parameter(
                torch.empty(n_subsequent, self.num_codebooks, num_embeddings))
            torch.nn.init.zeros_(self.encoder_top_n_out_bias)

        # stack input features, e.g. fbank
        self.stack_frames = self.encoder.embed.right_context + 1
        self.stride = self.encoder.embed.subsampling_rate
        # must match the stacked-feature dim consumed by the projection
        input_dim = num_mel_bins * self.stack_frames

        # random projection
        self.projection = torch.nn.parameter.Parameter(
            torch.empty(input_dim, embedding_dim * self.num_codebooks),
            requires_grad=False,
        )
        torch.nn.init.xavier_uniform_(self.projection)
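        # the projection stays frozen (requires_grad=False) so the
        # quantization targets remain stable throughout training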
        self.norm = torch.nn.LayerNorm(self.stack_frames * num_mel_bins,
                                       eps=1e-6,
                                       elementwise_affine=False,
                                       bias=False)
        # Section: 1B
        self.n_subsequent = n_subsequent

        # codebook
        # [num_embeddings, num_codebooks, embedding_dim] means
        # [C, G, D], see quantize_vector
        self.embeddings = torch.nn.parameter.Parameter(
            torch.empty(num_embeddings, self.num_codebooks, embedding_dim),
            requires_grad=False,
        )
        torch.nn.init.normal_(self.embeddings)
        self.embeddings /= (self.embeddings.norm(dim=-1, p=2, keepdim=True) +
                            1e-8)
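        # with unit-norm codewords (and unit-norm projected features, see
        # _nearest_embedding_idx), the nearest-neighbour search in
        # quantize_vector is equivalent to a cosine-similarity argmax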

        # force reset of encoder parameters
        self.reset_encoder_parameter()

    def reset_encoder_parameter(self):

        def _reset_parameter(module: torch.nn.Module):
            if isinstance(module, torch.nn.Linear):
                torch.nn.init.trunc_normal_(module.weight.data,
                                            mean=0.0,
                                            std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, torch.nn.Conv1d):
                torch.nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    k = math.sqrt(module.groups /
                                  (module.in_channels * module.kernel_size[0]))
                    torch.nn.init.uniform_(module.bias, a=-k, b=k)
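                    # same bound as PyTorch's default Conv1d bias reset:
                    # U(-sqrt(k), sqrt(k)) with
                    # k = groups / (in_channels * kernel_size)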
            elif isinstance(module, torch.Tensor):
                torch.nn.init.trunc_normal_(module)
            else:
                raise NotImplementedError(
                    "other module types are not supported yet")

        for layer in self.encoder.encoders:
            self_attn = layer.self_attn
            _reset_parameter(self_attn.linear_q)
            _reset_parameter(self_attn.linear_k)
            _reset_parameter(self_attn.linear_v)
            _reset_parameter(self_attn.linear_out)
            if isinstance(self_attn, RelPositionMultiHeadedAttention):
                _reset_parameter(self_attn.pos_bias_u)
                _reset_parameter(self_attn.pos_bias_v)
            if isinstance(layer, ConformerEncoderLayer):
                conv1, conv2 = (layer.conv_module.pointwise_conv1,
                                layer.conv_module.depthwise_conv)
                _reset_parameter(conv1)
                _reset_parameter(conv2)

    def forward(
        self,
        batch: Dict,
        device: torch.device,
    ) -> Dict[str, torch.Tensor]:
        xs = batch['feats'].to(device)
        xs_lens = batch['feats_lengths'].to(device)

        # 1 stack fbank frames; stack_out_mask masks the
        #   next-token-prediction loss
        stack_input, stack_out_mask = self._stack_features(xs, xs_lens)

        # 2 get nearest embedding ids
        target_ids = self._nearest_embedding_idx(stack_input)
        target_ids = target_ids[:, :stack_out_mask.size(1), :]
        # sliding window of n_subsequent codes per frame:
        # (B, T, num_codebooks) -> (B, T', n_subsequent, num_codebooks)
        target_ids = target_ids.unfold(1, size=self.n_subsequent,
                                       step=1).transpose(-1, -2)
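        # e.g. with n_subsequent=2, position t now carries the codes of
        # frames t and t+1; the shift below turns these into the codes of
        # frames t+1 and t+2, i.e. pure next-token targets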

        # 3 forward the xxx-former blocks and their subsampling layer
        # TODO(mddct): encoder causal mask
        out, out_mask = self.encoder(xs, xs_lens)

        # 4 get logits
        out = out.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T', dim]
        top_n_out = self.encoder_top_n_out.unsqueeze(
            0)  # [1, n_subsequent, num_codebooks, dim, num_embeddings]
        out = torch.matmul(
            out,
            top_n_out)  # [B, n_subsequent, num_codebooks, T', num_embeddings]
        if self.out_bias:
            out = out + self.encoder_top_n_out_bias.unsqueeze(0).unsqueeze(3)

        # shift targets for next token prediction: frame t predicts the
        # codes of frames t+1 .. t+n_subsequent
        target_ids = target_ids[:, 1:, :, :]
        out = out[:, :, :, :target_ids.size(1), :]
        masks = out_mask.squeeze(1) * stack_out_mask
        # align the mask with the shifted targets
        masks = masks[:, 1:target_ids.size(1) + 1]

        # 5 compute loss
        loss = self._compute_loss(out, target_ids, mask=masks)

        # 6 other info: num codes used in batch, unique num codes used in batch
        num_codes = masks.sum() * self.num_codebooks * self.n_subsequent
        uniq_num_codes = torch.tensor(
            torch.unique(target_ids *
                         masks.unsqueeze(-1).unsqueeze(-1)).numel()).detach()
        # [B, n_subsequent, num_codebooks, T'] -> [B, T', n_subsequent, num_codebooks]
        ids_corr = out.argmax(dim=-1, keepdim=False).permute(
            0, 3, 1, 2) == target_ids
        codes_acc = (ids_corr *
                     masks.unsqueeze(-1).unsqueeze(-1)).sum() / num_codes
        return {
            "codes_acc": codes_acc,
            "loss": loss,
            "num_codes": num_codes,
            "uniq_num_codes": uniq_num_codes,
            "th_accuracy": codes_acc,
        }

    def _stack_features(
            self, input: torch.Tensor,
            input_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        mask = make_non_pad_mask(input_lens)
        # a stacked frame is valid only if every frame in its window is valid
        mask_stride = mask.unfold(
            1,
            size=self.stack_frames,
            step=self.stride,
        )
        subsampling_mask, _ = torch.min(mask_stride, dim=-1)

        stack_input = input.unfold(1, size=self.stack_frames, step=self.stride)
        stack_input = stack_input.transpose(-1, -2)
        b, n, f, d = stack_input.size()
        stack_input = stack_input.reshape(b, n, f * d)

        return self.norm(stack_input), subsampling_mask
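    # shape example for _stack_features with Conv2dSubsampling4
    # (subsampling_rate=4, right_context=6): windows of 7 frames are taken
    # every 4 frames, so 80-dim fbank becomes 7 * 80 = 560-dim stacked
    # features at the encoder's output frame rate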

    def _compute_loss(self, input: torch.Tensor, target: torch.Tensor,
                      mask: torch.Tensor) -> torch.Tensor:
        # [B, n_subsequent, num_codebooks, T, V]
        #   -> [B * T * n_subsequent * num_codebooks, V]
        logits = input.contiguous().permute(
            (0, 3, 1, 2, 4)).view(-1, input.size(-1))
        mask = mask.unsqueeze(2).unsqueeze(2).repeat(1, 1, self.n_subsequent,
                                                     self.num_codebooks)
        loss = torch.nn.functional.cross_entropy(
            logits,
            target.contiguous().view(-1),
            reduction='none',
        )
        loss = (loss * mask.view(-1)).sum() / mask.sum()
        return loss

    def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor:
        xs = torch.matmul(xs, self.projection.to(xs.device))
        xs = xs / (xs.norm(dim=-1, p=2, keepdim=True) + 1e-8)
        codebooks = self.embeddings
        B, T, C = xs.size()
        xs_flatten = xs.view(B * T, C)
        _, codes, _ = quantize_vector(xs_flatten, codebooks)

        return codes.reshape(B, T, -1)  # [B, T, num_codebooks]
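
A minimal smoke-test sketch of how this model is driven, assuming a standard wenet ConformerEncoder; the constructor arguments and batch keys mirror the diff above, while the encoder config, shapes, and lengths are illustrative:

    import torch
    from wenet.transformer.encoder import ConformerEncoder
    from wenet.ssl.nestrq.nestrq_model import NestRQModel

    # small illustrative conformer over 80-dim fbank
    encoder = ConformerEncoder(input_size=80,
                               output_size=256,
                               attention_heads=4,
                               linear_units=1024,
                               num_blocks=2)
    model = NestRQModel(encoder, num_mel_bins=80, n_subsequent=1)

    # dummy padded batch: 2 utterances of 100 and 80 frames
    batch = {
        'feats': torch.randn(2, 100, 80),
        'feats_lengths': torch.tensor([100, 80]),
    }
    out = model(batch, device=torch.device('cpu'))
    print(out['loss'], out['codes_acc'])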