From 06025169e6a32cb59e7c9bcdcd74ee4ba321aa93 Mon Sep 17 00:00:00 2001 From: Shuai Date: Wed, 7 Aug 2024 14:50:40 +0800 Subject: [PATCH 1/6] [feature] add redimnet --- examples/voxceleb/v2/conf/redimnet.yaml | 86 ++ wespeaker/models/redimnet.py | 1037 +++++++++++++++++++++++ 2 files changed, 1123 insertions(+) create mode 100644 examples/voxceleb/v2/conf/redimnet.yaml create mode 100644 wespeaker/models/redimnet.py diff --git a/examples/voxceleb/v2/conf/redimnet.yaml b/examples/voxceleb/v2/conf/redimnet.yaml new file mode 100644 index 00000000..c392b3de --- /dev/null +++ b/examples/voxceleb/v2/conf/redimnet.yaml @@ -0,0 +1,86 @@ +exp_dir: exp/RedimnetB2-emb192-fbank72-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch120 +gpus: "[0,1]" +num_avg: 10 +enable_amp: False # whether enable automatic mixed precision training + +seed: 42 +num_epochs: 120 +save_epoch_interval: 5 # save model every 5 epochs +log_batch_interval: 100 # log every 100 batchs + +dataloader_args: + batch_size: 256 + num_workers: 4 + pin_memory: false + prefetch_factor: 4 + drop_last: true + +dataset_args: + # the sample number which will be traversed within one epoch, if the value equals to 0, + # the utterance number in the dataset will be used as the sample_num_per_epoch. + sample_num_per_epoch: 0 + shuffle: True + shuffle_args: + shuffle_size: 2500 + filter: True + filter_args: + min_num_frames: 100 + max_num_frames: 800 + resample_rate: 16000 + speed_perturb: True + num_frms: 200 + aug_prob: 0.6 # prob to add reverb & noise aug per sample + fbank_args: + num_mel_bins: 72 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + spec_aug: False + spec_aug_args: + num_t_mask: 1 + num_f_mask: 1 + max_t: 10 + max_f: 8 + prob: 0.6 + +model: ReDimNetB2 +model_init: null +model_args: + feat_dim: 72 + embed_dim: 192 + pooling_func: "ASTP" # TSTP, ASTP, MQMHASTP + two_emb_layer: False + + +projection_args: + project_type: "sphereface2" # add_margin, arc_margin, sphere, sphereface2, softmax, arc_margin_intertopk_subcenter + scale: 32.0 + easy_margin: False + + +margin_scheduler: MarginScheduler +margin_update: + initial_margin: 0.0 + final_margin: 0.2 + increase_start_epoch: 20 + fix_start_epoch: 40 + update_margin: True + increase_type: "exp" # exp, linear + update_margin: true + +loss: CrossEntropyLoss +loss_args: {} + +optimizer: SGD +optimizer_args: + momentum: 0.9 + nesterov: True + weight_decay: 2.0e-05 + +scheduler: ExponentialDecrease +scheduler_args: + initial_lr: 0.1 + final_lr: 0.00005 + warm_up_epoch: 6 + warm_from_zero: True + diff --git a/wespeaker/models/redimnet.py b/wespeaker/models/redimnet.py new file mode 100644 index 00000000..5f0a45c5 --- /dev/null +++ b/wespeaker/models/redimnet.py @@ -0,0 +1,1037 @@ +# Copyright (c) 2024 https://github.com/IDRnD/ReDimNet +# 2024 Shuai Wang (wsstriving@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Redimnet in pytorch. 
+ +Reference: +Paper: "Reshape Dimensions Network for Speaker Recognition" +Repo: https://github.com/IDRnD/ReDimNet + +Cite: +@misc{yakovlev2024reshapedimensionsnetworkspeaker, + title={Reshape Dimensions Network for Speaker Recognition}, + author={Ivan Yakovlev and Rostislav Makarov and Andrei Balykin and Pavel Malov and Anton Okhotnikov and Nikita Torgashov}, + year={2024}, + eprint={2407.18223}, + archivePrefix={arXiv}, + primaryClass={eess.AS}, + url={https://arxiv.org/abs/2407.18223}, +} +""" +import math + +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F +import wespeaker.models.pooling_layers as pooling_layers + + +MaxPoolNd = {1: nn.MaxPool1d, 2: nn.MaxPool2d} +ConvNd = {1: nn.Conv1d, 2: nn.Conv2d} +BatchNormNd = {1: nn.BatchNorm1d, 2: nn.BatchNorm2d} + + +class to1d(nn.Module): + def forward(self, x): + size = x.size() + bs, c, f, t = tuple(size) + return x.permute((0, 2, 1, 3)).reshape((bs, c * f, t)) + + +class NewGELUActivation(nn.Module): + def forward(self, input): + return ( + 0.5 + * input + * ( + 1.0 + + torch.tanh( + math.sqrt(2.0 / math.pi) + * (input + 0.044715 * torch.pow(input, 3.0)) + ) + ) + ) + + +class LayerNorm(nn.Module): + """ + LayerNorm that supports two data formats: channels_last or channels_first. + The ordering of the dimensions in the inputs. + channels_last corresponds to inputs with shape (batch_size, T, channels) + while channels_first corresponds to shape (batch_size, channels, T). + """ + + def __init__(self, C, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(C)) + self.bias = nn.Parameter(torch.zeros(C)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.C = (C,) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.C, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + + w = self.weight + b = self.bias + for _ in range(x.ndim - 2): + w = w.unsqueeze(-1) + b = b.unsqueeze(-1) + x = w * x + b + return x + + def extra_repr(self) -> str: + return ", ".join( + [ + f"{k}={v}" + for k, v in { + "C": self.C, + "data_format": self.data_format, + "eps": self.eps, + }.items() + ] + ) + + +class GRU(nn.Module): + def __init__(self, *args, **kwargs): + super(GRU, self).__init__() + self.gru = nn.GRU(*args, **kwargs) + + def forward(self, x): + # x : (bs,C,T) + return self.gru(x.permute((0, 2, 1)))[0].permute((0, 2, 1)) + + +class PosEncConv(nn.Module): + def __init__(self, C, ks, groups=None): + super().__init__() + assert ks % 2 == 1 + self.conv = nn.Conv1d( + C, C, ks, padding=ks // 2, groups=C if groups is None else groups + ) + self.norm = LayerNorm(C, eps=1e-6, data_format="channels_first") + + def forward(self, x): + return x + self.norm(self.conv(x)) + + +class ConvNeXtLikeBlock(nn.Module): + def __init__( + self, + C, + dim=2, + kernel_sizes=[ + (3, 3), + ], + group_divisor=1, + padding="same", + ): + super().__init__() + self.dwconvs = nn.ModuleList( + modules=[ + ConvNd[dim]( + C, + C, + kernel_size=ks, + padding=padding, + groups=C // group_divisor if group_divisor is not None else 1, + ) + for ks in kernel_sizes + ] + ) + self.norm = BatchNormNd[dim](C * len(kernel_sizes)) + self.gelu = nn.GELU() + self.pwconv1 = ConvNd[dim](C * len(kernel_sizes), C, 1) + + def forward(self, 
x): + skip = x + x = torch.cat([dwconv(x) for dwconv in self.dwconvs], dim=1) + x = self.gelu(self.norm(x)) + x = self.pwconv1(x) + x = skip + x + return x + + +class ConvBlock2d(nn.Module): + def __init__(self, c, f, block_type="convnext_like", group_divisor=1): + super().__init__() + if block_type == "convnext_like": + self.conv_block = ConvNeXtLikeBlock( + c, + dim=2, + kernel_sizes=[(3, 3)], + group_divisor=group_divisor, + padding="same", + ) + elif block_type == "basic_resnet": + self.conv_block = ResBasicBlock( + c, + c, + f, + stride=1, + se_channels=min(64, max(c, 32)), + group_divisor=group_divisor, + use_fwSE=False, + ) + elif block_type == "basic_resnet_fwse": + self.conv_block = ResBasicBlock( + c, + c, + f, + stride=1, + se_channels=min(64, max(c, 32)), + group_divisor=group_divisor, + use_fwSE=True, + ) + else: + raise NotImplementedError() + + def forward(self, x): + return self.conv_block(x) + + +# ------------------------------------------------------------- +# Copy multi-head attention module from hugginface wav2vec2 +# ------------------------------------------------------------- +# Copied from https://github.com/huggingface/transformers/blob/v4.26.1/src/transformers/models/wav2vec2/modeling_wav2vec2.py +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Wav2Vec2 +class MultiHeadAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len, bsz): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + attn_weights = F.softmax(attn_weights, dim=-1) + + attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. 
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + return attn_output + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + n_state, + n_mlp, + n_head, + channel_last=False, + act_do=0.0, + att_do=0.0, + hid_do=0.0, + ln_eps=1e-6, + ): + + hidden_size = n_state + num_attention_heads = n_head + intermediate_size = n_mlp + activation_dropout = act_do + attention_dropout = att_do + hidden_dropout = hid_do + layer_norm_eps = ln_eps + + super().__init__() + self.channel_last = channel_last + self.attention = MultiHeadAttention( + embed_dim=hidden_size, + num_heads=num_attention_heads, + dropout=attention_dropout, + ) + self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + self.feed_forward = FeedForward( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + activation_dropout=activation_dropout, + hidden_dropout=hidden_dropout, + ) + self.final_layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + + def forward(self, hidden_states): + if not self.channel_last: + hidden_states = hidden_states.permute(0, 2, 1) + attn_residual = hidden_states + hidden_states = self.attention(hidden_states) + hidden_states = attn_residual + hidden_states + + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states + self.feed_forward(hidden_states) + hidden_states = self.final_layer_norm(hidden_states) + + outputs = hidden_states + if not self.channel_last: + outputs = outputs.permute(0, 2, 1) + return outputs + + +class FeedForward(nn.Module): + def __init__( + self, + hidden_size, + intermediate_size, + activation_dropout=0.0, + hidden_dropout=0.0, + ): + super().__init__() + self.intermediate_dropout = nn.Dropout(activation_dropout) + self.intermediate_dense = nn.Linear(hidden_size, intermediate_size) + self.intermediate_act_fn = NewGELUActivation() + self.output_dense = nn.Linear(intermediate_size, hidden_size) + self.output_dropout = nn.Dropout(hidden_dropout) + + def forward(self, hidden_states): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states) + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +class BasicBlock(nn.Module): + """ + Key difference with the BasicBlock in resnet.py: + 1. If use group convolution, conv1 have same number of input/output channels + 2. 
No stride to downsample + """ + + def __init__( + self, + in_planes, + planes, + stride=1, + group_divisor=4, + ): + super().__init__() + self.conv1 = nn.Conv2d( + in_planes, + in_planes if group_divisor is not None else planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + groups=in_planes // group_divisor if group_divisor is not None else 1, + ) + + # If using group convolution, add point-wise conv to reshape + if group_divisor is not None: + self.conv1pw = nn.Conv2d(in_planes, planes, 1) + else: + self.conv1pw = nn.Identity() + + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + padding=1, + bias=False, + groups=planes // group_divisor if group_divisor is not None else 1, + ) + + # If using group convolution, add point-wise conv to reshape + if group_divisor is not None: + self.conv2pw = nn.Conv2d(planes, planes, 1) + else: + self.conv2pw = nn.Identity() + + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + + if planes != in_planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes), + ) + else: + self.shortcut = nn.Identity() + + def forward(self, x): + residual = x + + out = self.conv1pw(self.conv1(x)) + out = self.relu(out) + out = self.bn1(out) + + out = self.conv2pw(self.conv2(out)) + out = self.bn2(out) + + out += self.shortcut(residual) + out = self.relu(out) + return out + + +class fwSEBlock(nn.Module): + """ + Squeeze-and-Excitation block + link: https://arxiv.org/pdf/1709.01507.pdf + PyTorch implementation + """ + + def __init__(self, num_freq, num_feats=64): + super(fwSEBlock, self).__init__() + self.squeeze = nn.Linear(num_freq, num_feats) + self.exitation = nn.Linear(num_feats, num_freq) + + self.activation = nn.ReLU() # Assuming ReLU, modify as needed + + def forward(self, inputs): + # [bs, C, F, T] + x = torch.mean(inputs, dim=[1, 3]) + x = self.squeeze(x) + x = self.activation(x) + x = self.exitation(x) + x = torch.sigmoid(x) + # Reshape and apply excitation + x = x[:, None, :, None] + x = inputs * x + return x + + +class ResBasicBlock(nn.Module): + def __init__( + self, + in_planes, + planes, + num_freq, + stride=1, + se_channels=64, + group_divisor=4, + use_fwSE=False, + ): + super().__init__() + self.conv1 = nn.Conv2d( + in_planes, + in_planes if group_divisor is not None else planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + groups=in_planes // group_divisor if group_divisor is not None else 1, + ) + if group_divisor is not None: + self.conv1pw = nn.Conv2d(in_planes, planes, 1) + else: + self.conv1pw = nn.Identity() + + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + padding=1, + bias=False, + groups=planes // group_divisor if group_divisor is not None else 1, + ) + + if group_divisor is not None: + self.conv2pw = nn.Conv2d(planes, planes, 1) + else: + self.conv2pw = nn.Identity() + + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + + if use_fwSE: + self.se = fwSEBlock(num_freq, se_channels) + else: + self.se = nn.Identity() + + if planes != in_planes: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes), + ) + else: + self.downsample = nn.Identity() + + def forward(self, x): + residual = x + + out = self.conv1pw(self.conv1(x)) + out = self.relu(out) + out = self.bn1(out) + + out = self.conv2pw(self.conv2(out)) + out = 
self.bn2(out) + out = self.se(out) + + out += self.downsample(residual) + out = self.relu(out) + return out + + +class TimeContextBlock1d(nn.Module): + """ """ + + def __init__( + self, + C, + hC, + pos_ker_sz=59, + block_type="att", + ): + super().__init__() + assert pos_ker_sz + + self.red_dim_conv = nn.Sequential( + nn.Conv1d(C, hC, 1), LayerNorm(hC, eps=1e-6, data_format="channels_first") + ) + + if block_type == "fc": + self.tcm = nn.Sequential( + nn.Conv1d(hC, hC * 2, 1), + LayerNorm(hC * 2, eps=1e-6, data_format="channels_first"), + nn.GELU(), + nn.Conv1d(hC * 2, hC, 1), + ) + elif block_type == "gru": + # Just GRU + self.tcm = nn.Sequential( + GRU( + input_size=hC, + hidden_size=hC, + num_layers=1, + bias=True, + batch_first=False, + dropout=0.0, + bidirectional=True, + ), + nn.Conv1d(2 * hC, hC, 1), + ) + elif block_type == "att": + # Basic Transformer self-attention encoder block + self.tcm = nn.Sequential( + PosEncConv(hC, ks=pos_ker_sz, groups=hC), + TransformerEncoderLayer(n_state=hC, n_mlp=hC * 2, n_head=4), + ) + elif block_type == "conv+att": + # Basic Transformer self-attention encoder block + self.tcm = nn.Sequential( + ConvNeXtLikeBlock( + hC, dim=1, kernel_sizes=[7], group_divisor=1, padding="same" + ), + ConvNeXtLikeBlock( + hC, dim=1, kernel_sizes=[19], group_divisor=1, padding="same" + ), + ConvNeXtLikeBlock( + hC, dim=1, kernel_sizes=[31], group_divisor=1, padding="same" + ), + ConvNeXtLikeBlock( + hC, dim=1, kernel_sizes=[59], group_divisor=1, padding="same" + ), + TransformerEncoderLayer(n_state=hC, n_mlp=hC, n_head=4), + ) + else: + raise NotImplementedError() + + self.exp_dim_conv = nn.Conv1d(hC, C, 1) + + def forward(self, x): + skip = x + x = self.red_dim_conv(x) + x = self.tcm(x) + x = self.exp_dim_conv(x) + return skip + x + + +class ReDimNetBone(nn.Module): + def __init__( + self, + F=72, + C=16, + block_1d_type="conv+att", + block_2d_type="basic_resnet", + stages_setup=[ + # stride, num_blocks, conv_exp, kernel_size, att_block_red + (1, 2, 1, [(3, 3)], None), # 16 + (2, 3, 1, [(3, 3)], None), # 32 + # 64, (72*12 // 8) = 108 - channels in attention block + (3, 4, 1, [(3, 3)], 8), + (2, 5, 1, [(3, 3)], 8), # 128 + (1, 5, 1, [(7, 1)], 8), # 128 # TDNN - time context + (2, 3, 1, [(3, 3)], 8), # 256 + ], + group_divisor=1, + out_channels=512, + ): + super().__init__() + self.F = F + self.C = C + + self.block_1d_type = block_1d_type + self.block_2d_type = block_2d_type + + self.stages_setup = stages_setup + self.build(stages_setup, group_divisor, out_channels) + + def build(self, stages_setup, group_divisor, out_channels): + self.num_stages = len(stages_setup) + + cur_c = self.C + cur_f = self.F + # Weighting the inputs + # TODO: ask authors about the impact of this pre-weighting + self.inputs_weights = torch.nn.ParameterList( + [nn.Parameter(torch.ones(1, 1, 1, 1), requires_grad=False)] + + [ + nn.Parameter( + torch.zeros(1, num_inputs + 1, self.C * self.F, 1), + requires_grad=True, + ) + for num_inputs in range(1, len(stages_setup) + 1) + ] + ) + + self.stem = nn.Sequential( + nn.Conv2d(1, int(cur_c), kernel_size=3, stride=1, padding="same"), + LayerNorm(int(cur_c), eps=1e-6, data_format="channels_first"), + ) + + Block1d = functools.partial(TimeContextBlock1d, block_type=self.block_1d_type) + Block2d = functools.partial(ConvBlock2d, block_type=self.block_2d_type) + + self.stages_cfs = [] + for stage_ind, ( + stride, + num_blocks, + conv_exp, + kernel_sizes, # TODO: Why the kernel_sizes are not used? 
+ att_block_red, + ) in enumerate(stages_setup): + assert stride in [1, 2, 3] + # Pool frequencies & expand channels if needed + layers = [ + nn.Conv2d( + int(cur_c), + int(stride * cur_c * conv_exp), + kernel_size=(stride, 1), + stride=(stride, 1), + padding=0, + groups=1, + ), + ] + + self.stages_cfs.append((cur_c, cur_f)) + + cur_c = stride * cur_c + assert cur_f % stride == 0 + cur_f = cur_f // stride + + for _ in range(num_blocks): + # ConvBlock2d(f, c, block_type="convnext_like", group_divisor=1) + layers.append( + Block2d( + c=int(cur_c * conv_exp), f=cur_f, group_divisor=group_divisor + ) + ) + + if conv_exp != 1: + # Squeeze back channels to align with ReDimNet c+f reshaping: + _group_divisor = group_divisor + # if c // group_divisor == 0: + # _group_divisor = c + layers.append( + nn.Sequential( + nn.Conv2d( + int(cur_c * conv_exp), + cur_c, + kernel_size=(3, 3), + stride=1, + padding="same", + groups=( + cur_c // _group_divisor + if _group_divisor is not None + else 1 + ), + ), + nn.BatchNorm2d( + cur_c, + eps=1e-6, + ), + nn.GELU(), + nn.Conv2d(cur_c, cur_c, 1), + ) + ) + + layers.append(to1d()) + + # reduce block? + if att_block_red is not None: + layers.append( + Block1d(self.C * self.F, hC=(self.C * self.F) // att_block_red) + ) + + setattr(self, f"stage{stage_ind}", nn.Sequential(*layers)) + + if out_channels is not None: + self.mfa = nn.Sequential( + nn.Conv1d(self.F * self.C, out_channels, kernel_size=1, padding="same"), + nn.BatchNorm1d(out_channels, affine=True), + ) + else: + self.mfa = nn.Identity() + + def to1d(self, x): + size = x.size() + bs, c, f, t = tuple(size) + return x.permute((0, 2, 1, 3)).reshape((bs, c * f, t)) + + def to2d(self, x, c, f): + size = x.size() + bs, cf, t = tuple(size) + return x.reshape((bs, f, c, t)).permute((0, 2, 1, 3)) + + def weigth1d(self, outs_1d, i): + xs = torch.cat([t.unsqueeze(1) for t in outs_1d], dim=1) + w = F.softmax(self.inputs_weights[i], dim=1) + x = (w * xs).sum(dim=1) + return x + + def run_stage(self, prev_outs_1d, stage_ind): + stage = getattr(self, f"stage{stage_ind}") + c, f = self.stages_cfs[stage_ind] + + x = self.weigth1d(prev_outs_1d, stage_ind) + x = self.to2d(x, c, f) + x = stage(x) + return x + + def forward(self, inp): + x = self.stem(inp) + outputs_1d = [self.to1d(x)] + for stage_ind in range(self.num_stages): + outputs_1d.append(self.run_stage(outputs_1d, stage_ind)) + x = self.weigth1d(outputs_1d, -1) + x = self.mfa(x) + return x + + +class ReDimNet(nn.Module): + def __init__( + self, + feat_dim=72, + C=16, + block_1d_type="conv+att", + block_2d_type="basic_resnet", + # Default setup: M version: + stages_setup=[ + # stride, num_blocks, kernel_sizes, layer_ext, att_block_red + (1, 2, 1, [(3, 3)], 12), + (2, 2, 1, [(3, 3)], 12), + (1, 3, 1, [(3, 3)], 12), + (2, 4, 1, [(3, 3)], 8), + (1, 4, 1, [(3, 3)], 8), + (2, 4, 1, [(3, 3)], 4), + ], + group_divisor=4, + out_channels=None, + # ------------------------- + embed_dim=192, + pooling_func="ASTP", + global_context_att=False, + two_emb_layer=False, + ): + + super().__init__() + self.two_emb_layer = two_emb_layer + self.backbone = ReDimNetBone( + feat_dim, + C, + block_1d_type, + block_2d_type, + stages_setup, + group_divisor, + out_channels, + ) + + if out_channels is None: + out_channels = C * feat_dim + + self.pool = getattr(pooling_layers, pooling_func)( + in_dim=out_channels, global_context_att=global_context_att + ) + + self.pool_out_dim = self.pool.get_out_dim() + self.seg_1 = nn.Linear(self.pool_out_dim, embed_dim) + if self.two_emb_layer: + self.seg_bn_1 = 
nn.BatchNorm1d(embed_dim, affine=False) + self.seg_2 = nn.Linear(embed_dim, embed_dim) + else: + self.seg_bn_1 = nn.Identity() + self.seg_2 = nn.Identity() + + def forward(self, x): + # x = self.spec(x).unsqueeze(1) + x = x.permute(0, 2, 1) # (B,F,T) => (B,T,F) + x = x.unsqueeze_(1) + out = self.backbone(x) + + stats = self.pool(out) + embed_a = self.seg_1(stats) + if self.two_emb_layer: + out = F.relu(embed_a) + out = self.seg_bn_1(out) + embed_b = self.seg_2(out) + return embed_a, embed_b + else: + return torch.tensor(0.0), embed_a + + +def ReDimNetB0(feat_dim=60, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): + return ReDimNet( + feat_dim=feat_dim, + C=10, + block_1d_type="conv+att", + block_2d_type="basic_resnet", + stages_setup=[ + (1, 2, 1, [(3, 3)], 30), + (2, 3, 2, [(3, 3)], 30), + (1, 3, 3, [(3, 3)], 30), + (2, 4, 2, [(3, 3)], 10), + (1, 3, 1, [(3, 3)], 10), + ], + group_divisor=1, + out_channels=None, + embed_dim=embed_dim, + pooling_func=pooling_func, + global_context_att=True, + two_emb_layer=two_emb_layer, + ) + + +def ReDimNetB1(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): + return ReDimNet( + feat_dim=feat_dim, + C=12, + block_1d_type="conv+att", + block_2d_type="convnext_like", + stages_setup=[ + (1, 2, 1, [(3, 3)], None), + (2, 3, 1, [(3, 3)], None), + (3, 4, 1, [(3, 3)], 12), + (2, 5, 1, [(3, 3)], 12), + (2, 3, 1, [(3, 3)], 8), + ], + group_divisor=8, + out_channels=None, + embed_dim=embed_dim, + pooling_func=pooling_func, + global_context_att=False, + two_emb_layer=two_emb_layer, + ) + + +def ReDimNetB2(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): + return ReDimNet( + feat_dim=feat_dim, + C=16, + block_1d_type="conv+att", + block_2d_type="convnext_like", + stages_setup=[ + (1, 2, 1, [(3, 3)], 12), + (2, 2, 1, [(3, 3)], 12), + (1, 3, 1, [(3, 3)], 12), + (2, 4, 1, [(3, 3)], 8), + (1, 4, 1, [(3, 3)], 8), + (2, 4, 1, [(3, 3)], 4), + ], + group_divisor=4, + out_channels=None, + embed_dim=embed_dim, + pooling_func=pooling_func, + global_context_att=True, + two_emb_layer=two_emb_layer, + ) + + +def ReDimNetB3(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): + return ReDimNet( + feat_dim=feat_dim, + C=16, + block_1d_type="conv+att", + block_2d_type="basic_resnet_fwse", + stages_setup=[ + (1, 6, 4, [(3, 3)], 32), + (2, 6, 2, [(3, 3)], 32), + (1, 8, 2, [(3, 3)], 32), + (2, 10, 2, [(3, 3)], 16), + (1, 10, 1, [(3, 3)], 16), + (2, 8, 1, [(3, 3)], 16), + ], + group_divisor=1, + out_channels=None, + embed_dim=embed_dim, + pooling_func=pooling_func, + global_context_att=True, + two_emb_layer=two_emb_layer, + ) + + +def ReDimNetB4(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): + return ReDimNet( + feat_dim=feat_dim, + C=32, + block_1d_type="conv+att", + block_2d_type="basic_resnet_fwse", + stages_setup=[ + (1, 4, 2, [(3, 3)], 48), + (2, 4, 2, [(3, 3)], 48), + (1, 6, 2, [(3, 3)], 48), + (2, 6, 1, [(3, 3)], 32), + (1, 8, 1, [(3, 3)], 24), + (2, 4, 1, [(3, 3)], 16), + ], + group_divisor=1, + out_channels=None, + embed_dim=embed_dim, + pooling_func=pooling_func, + global_context_att=True, + two_emb_layer=two_emb_layer, + ) + + +def ReDimNetB5(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): + return ReDimNet( + feat_dim=feat_dim, + C=32, + block_1d_type="conv+att", + block_2d_type="basic_resnet_fwse", + stages_setup=[ + (1, 4, 2, [(3, 3)], 48), + (2, 4, 2, [(3, 3)], 48), + (1, 6, 2, [(3, 3)], 48), + (2, 6, 1, [(3, 3)], 32), + (1, 8, 1, [(3, 3)], 24), + (2, 4, 1, 
[(3, 3)], 16), + ], + group_divisor=16, + out_channels=None, + embed_dim=embed_dim, + pooling_func=pooling_func, + global_context_att=True, + two_emb_layer=two_emb_layer, + ) + + +def ReDimNetB6(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=False): + return ReDimNet( + feat_dim=feat_dim, + C=32, + block_1d_type="conv+att", + block_2d_type="basic_resnet", + stages_setup=[ + (1, 4, 4, [(3, 3)], 32), + (2, 6, 2, [(3, 3)], 32), + (1, 6, 2, [(3, 3)], 24), + (3, 8, 1, [(3, 3)], 24), + (1, 8, 1, [(3, 3)], 16), + (2, 8, 1, [(3, 3)], 16), + ], + group_divisor=32, + out_channels=None, + embed_dim=embed_dim, + pooling_func=pooling_func, + global_context_att=True, + two_emb_layer=two_emb_layer, + ) + + +if __name__ == "__main__": + # x = torch.zeros(1, 200, 72) + # model = ReDimNet(feat_dim=72, embed_dim=192, two_emb_layer=False) + # model.eval() + # out = model(x) + # print(out[-1].size()) + + # num_params = sum(p.numel() for p in model.parameters()) + # print("{} M".format(num_params / 1e6)) + + # Currently, the model sizes are not exactly the same with the ones in the paper + model_classes = [ + ReDimNetB0, # 1.0M v.s. 1.0M + ReDimNetB1, # 1.9M v.s. 2.2M + ReDimNetB2, # 4.9M v.s. 4.7M + ReDimNetB3, # 3.2M v.s. 3.0M + ReDimNetB4, # 6.4M v.s. 6.3M + ReDimNetB5, # 7.65M v.s. 9.2M + ReDimNetB6, # 15.0M v.s. 15.0M + ] + + for i, model_class in enumerate(model_classes): + model = model_class() + num_params = sum(p.numel() for p in model.parameters()) + print("{} M of Model B{}".format(num_params / 1e6, i)) From 3139604e3301464b9408dcefe1b928384c647962 Mon Sep 17 00:00:00 2001 From: Shuai Date: Wed, 7 Aug 2024 15:07:10 +0800 Subject: [PATCH 2/6] fix lint errors --- wespeaker/models/redimnet.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/wespeaker/models/redimnet.py b/wespeaker/models/redimnet.py index 5f0a45c5..f4a38672 100644 --- a/wespeaker/models/redimnet.py +++ b/wespeaker/models/redimnet.py @@ -20,13 +20,14 @@ Cite: @misc{yakovlev2024reshapedimensionsnetworkspeaker, - title={Reshape Dimensions Network for Speaker Recognition}, - author={Ivan Yakovlev and Rostislav Makarov and Andrei Balykin and Pavel Malov and Anton Okhotnikov and Nikita Torgashov}, + title={Reshape Dimensions Network for Speaker Recognition}, + author={Ivan Yakovlev and Rostislav Makarov and Andrei Balykin + and Pavel Malov and Anton Okhotnikov and Nikita Torgashov}, year={2024}, eprint={2407.18223}, archivePrefix={arXiv}, primaryClass={eess.AS}, - url={https://arxiv.org/abs/2407.18223}, + url={https://arxiv.org/abs/2407.18223}, } """ import math @@ -210,11 +211,6 @@ def forward(self, x): return self.conv_block(x) -# ------------------------------------------------------------- -# Copy multi-head attention module from hugginface wav2vec2 -# ------------------------------------------------------------- -# Copied from https://github.com/huggingface/transformers/blob/v4.26.1/src/transformers/models/wav2vec2/modeling_wav2vec2.py -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Wav2Vec2 class MultiHeadAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -233,8 +229,8 @@ def __init__( if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." + f"embed_dim must be divisible by num_heads (got " + f"`embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})." 
) self.scaling = self.head_dim**-0.5 @@ -275,7 +271,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.transpose(1, 2) - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # Use the `embed_dim` from the config (stored in the class) + # rather than `hidden_state` because `attn_output` can be # partitioned aross GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) @@ -1020,7 +1017,7 @@ def ReDimNetB6(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=Fa # num_params = sum(p.numel() for p in model.parameters()) # print("{} M".format(num_params / 1e6)) - # Currently, the model sizes are not exactly the same with the ones in the paper + # Currently, the model sizes differ from the ones in the paper model_classes = [ ReDimNetB0, # 1.0M v.s. 1.0M ReDimNetB1, # 1.9M v.s. 2.2M From 9b6536e052304f9be395640b65cc290bea42ea16 Mon Sep 17 00:00:00 2001 From: Shuai Date: Wed, 7 Aug 2024 15:25:34 +0800 Subject: [PATCH 3/6] fix B1 config --- wespeaker/models/redimnet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/wespeaker/models/redimnet.py b/wespeaker/models/redimnet.py index f4a38672..82970a5d 100644 --- a/wespeaker/models/redimnet.py +++ b/wespeaker/models/redimnet.py @@ -799,7 +799,7 @@ def __init__( # ------------------------- embed_dim=192, pooling_func="ASTP", - global_context_att=False, + global_context_att=True, two_emb_layer=False, ): @@ -887,7 +887,7 @@ def ReDimNetB1(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=Fa out_channels=None, embed_dim=embed_dim, pooling_func=pooling_func, - global_context_att=False, + global_context_att=True, two_emb_layer=two_emb_layer, ) @@ -1020,7 +1020,7 @@ def ReDimNetB6(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=Fa # Currently, the model sizes differ from the ones in the paper model_classes = [ ReDimNetB0, # 1.0M v.s. 1.0M - ReDimNetB1, # 1.9M v.s. 2.2M + ReDimNetB1, # 2.1M v.s. 2.2M ReDimNetB2, # 4.9M v.s. 4.7M ReDimNetB3, # 3.2M v.s. 3.0M ReDimNetB4, # 6.4M v.s. 
6.3M From c347bd03c49d61a55456377ee531aab976cf1fb9 Mon Sep 17 00:00:00 2001 From: Shuai Date: Wed, 7 Aug 2024 19:07:09 +0800 Subject: [PATCH 4/6] fix the speaker_encoder import --- wespeaker/models/redimnet.py | 16 ++++++++-------- wespeaker/models/speaker_model.py | 3 +++ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/wespeaker/models/redimnet.py b/wespeaker/models/redimnet.py index 82970a5d..ae4ca3b9 100644 --- a/wespeaker/models/redimnet.py +++ b/wespeaker/models/redimnet.py @@ -1008,14 +1008,14 @@ def ReDimNetB6(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=Fa if __name__ == "__main__": - # x = torch.zeros(1, 200, 72) - # model = ReDimNet(feat_dim=72, embed_dim=192, two_emb_layer=False) - # model.eval() - # out = model(x) - # print(out[-1].size()) - - # num_params = sum(p.numel() for p in model.parameters()) - # print("{} M".format(num_params / 1e6)) + x = torch.zeros(1, 200, 72) + model = ReDimNet(feat_dim=72, embed_dim=192, two_emb_layer=False) + model.eval() + out = model(x) + print(out[-1].size()) + + num_params = sum(p.numel() for p in model.parameters()) + print("{} M".format(num_params / 1e6)) # Currently, the model sizes differ from the ones in the paper model_classes = [ diff --git a/wespeaker/models/speaker_model.py b/wespeaker/models/speaker_model.py index 8475f1ae..4ef4a311 100644 --- a/wespeaker/models/speaker_model.py +++ b/wespeaker/models/speaker_model.py @@ -20,6 +20,7 @@ import wespeaker.models.eres2net as eres2net import wespeaker.models.gemini_dfresnet as gemini import wespeaker.models.res2net as res2net +import wespeaker.models.redimnet as redimnet def get_speaker_model(model_name: str): @@ -39,6 +40,8 @@ def get_speaker_model(model_name: str): return getattr(res2net, model_name) elif model_name.startswith("Gemini"): return getattr(gemini, model_name) + elif model_name.startswith("ReDimNet"): + return getattr(redimnet, model_name) else: # model_name error !!! 
print(model_name + " not found !!!") exit(1) From 7100c4c66003a5d7f52c64f56c1a0573948b643c Mon Sep 17 00:00:00 2001 From: wsstriving Date: Tue, 27 Aug 2024 12:00:08 +0800 Subject: [PATCH 5/6] fix the lint error --- wespeaker/models/redimnet.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/wespeaker/models/redimnet.py b/wespeaker/models/redimnet.py index ae4ca3b9..ecd7db17 100644 --- a/wespeaker/models/redimnet.py +++ b/wespeaker/models/redimnet.py @@ -141,9 +141,7 @@ def __init__( self, C, dim=2, - kernel_sizes=[ - (3, 3), - ], + kernel_sizes=((3, 3),), group_divisor=1, padding="same", ): @@ -612,7 +610,7 @@ def __init__( C=16, block_1d_type="conv+att", block_2d_type="basic_resnet", - stages_setup=[ + stages_setup=( # stride, num_blocks, conv_exp, kernel_size, att_block_red (1, 2, 1, [(3, 3)], None), # 16 (2, 3, 1, [(3, 3)], None), # 32 @@ -621,7 +619,7 @@ def __init__( (2, 5, 1, [(3, 3)], 8), # 128 (1, 5, 1, [(7, 1)], 8), # 128 # TDNN - time context (2, 3, 1, [(3, 3)], 8), # 256 - ], + ), group_divisor=1, out_channels=512, ): @@ -785,7 +783,7 @@ def __init__( block_1d_type="conv+att", block_2d_type="basic_resnet", # Default setup: M version: - stages_setup=[ + stages_setup=( # stride, num_blocks, kernel_sizes, layer_ext, att_block_red (1, 2, 1, [(3, 3)], 12), (2, 2, 1, [(3, 3)], 12), @@ -793,7 +791,7 @@ def __init__( (2, 4, 1, [(3, 3)], 8), (1, 4, 1, [(3, 3)], 8), (2, 4, 1, [(3, 3)], 4), - ], + ), group_divisor=4, out_channels=None, # ------------------------- From 28f97639f909c6cd4528c747cd5168bafeda0852 Mon Sep 17 00:00:00 2001 From: BingHan <879090429@qq.com> Date: Tue, 27 Aug 2024 15:38:22 +0800 Subject: [PATCH 6/6] Update redimnet.yaml change to arc_margin --- examples/voxceleb/v2/conf/redimnet.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/voxceleb/v2/conf/redimnet.yaml b/examples/voxceleb/v2/conf/redimnet.yaml index c392b3de..6e8fb295 100644 --- a/examples/voxceleb/v2/conf/redimnet.yaml +++ b/examples/voxceleb/v2/conf/redimnet.yaml @@ -53,7 +53,7 @@ model_args: projection_args: - project_type: "sphereface2" # add_margin, arc_margin, sphere, sphereface2, softmax, arc_margin_intertopk_subcenter + project_type: "arc_margin" # add_margin, arc_margin, sphere, sphereface2, softmax, arc_margin_intertopk_subcenter scale: 32.0 easy_margin: False