added cosine loss to dino #142
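This PR adds an auxiliary cosine loss (CosineDINOLoss) to DINO x-vector training so that the learned speaker embeddings are also well suited for cosine scoring. The new loss is exported from hyperion.torch.losses, wired into hyperion/bin/train_dino_wav2xvector.py with its own cosine_loss argument group, and enabled in a new v1.2.2 VoxCeleb SSL recipe config with scale 0.1 and a 20-epoch warm-up.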

Merged: 1 commit, May 3, 2024
96 changes: 96 additions & 0 deletions egs/voxceleb/ssl.v1/conf/train_fwseresnet34_dino_v1.2.2.yaml
@@ -0,0 +1,96 @@
data:
  train:
    dataset:
      teacher_aug_cfg: conf/teacher_reverb_noise_aug.yaml
      student_aug_cfg: conf/reverb_noise_aug.yaml
      student_chunk_length: 2.
      teacher_chunk_length: 4.
      num_teacher_chunks: 2
      num_student_chunks: 4
      same_teacher_student_chunks: false
    sampler:
      sampler_type: seg_chunk_sampler
      min_batch_size: 16
      max_chunk_length: 12.0
      min_chunk_length: 6.0
    data_loader:
      num_workers: 8
  val:
    dataset:
      teacher_aug_cfg: conf/teacher_reverb_noise_aug.yaml
      student_aug_cfg: conf/reverb_noise_aug.yaml
      student_chunk_length: 2.
      teacher_chunk_length: 4.
      num_teacher_chunks: 2
      num_student_chunks: 4
      same_teacher_student_chunks: false
    sampler:
      sampler_type: seg_chunk_sampler
      min_batch_size: 16
      max_chunk_length: 12.0
      min_chunk_length: 6.0
    data_loader:
      num_workers: 8
student_model:
  feats: fbank80_specaug1_stmn_16k.yaml
  xvector:
    resnet_type: fwseresnet34
    in_channels: 1
    in_feats: 80
    in_kernel_size: 3
    in_stride: 1
    no_maxpool: true
    pool_net:
      pool_type: ch-wise-att-mean+stddev
      inner_feats: 128
    dropout_rate: 0.01
    norm_before: false
    hid_act: swish
    se_r: 4
    head_type: dino
    embed_dim: 192
    num_embed_layers: 3
    loss_type: softmax
    head_use_norm: true
    head_hid_dim: 768
    head_bottleneck_dim: 192
    proj_head_use_norm: true
    proj_head_norm_before: false
teacher_model:
  xvector:
    override_dropouts: true
    dropout_rate: 0.0
dino_loss:
  num_classes: 65536
  temp_warmup_epochs: 0
  teacher_temp: 0.04
cosine_loss:
  warmup_epochs: 20
  scale: 0.1
trainer:
  optim:
    opt_type: adamw
    lr: 0.005
    amsgrad: false
    beta1: 0.9
    beta2: 0.99
    weight_decay: 1e-1
  lrsched:
    lrsch_type: exp_lr
    decay_rate: 0.5
    decay_steps: 60000
    hold_steps: 15000
    min_lr: 1.0e-05
    warmup_steps: 15000
    update_lr_on_opt_step: true
  teacher_optim:
    init_momentum: 0.996
    momentum: 1.0
    warmup_steps: 500000
  grad_clip: 15
  use_amp: true
  log_interval: 1000
  epochs: 100
  eff_batch_size: 256
  train_mode: full
  freeze_output_layer_steps: 1500
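The new cosine_loss block is what this PR adds: the weight of the auxiliary cosine loss ramps up linearly over the first 20 epochs and then stays at 0.1. A minimal sketch of that schedule, mirroring CosineDINOLoss.update_scale shown later in this diff (the standalone helper is illustrative only):

def cosine_loss_scale(epoch: int, scale: float = 0.1, warmup_epochs: int = 20) -> float:
    # linear warm-up, then constant, as in CosineDINOLoss.update_scale
    if epoch < warmup_epochs:
        return scale * epoch / warmup_epochs
    return scale

# epoch 0 -> 0.0, epoch 10 -> 0.05, epoch >= 20 -> 0.1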
@@ -0,0 +1,67 @@
# FwSE-ResNet34 DINO v1.2.2

# acoustic features
feat_config=conf/fbank80_stmn_16k.yaml
feat_type=fbank80_stmn

#vad
vad_config=conf/vad_16k.yaml

# x-vector training
nnet_data=voxceleb2cat_train

# x-vector cfg
nnet_type=resnet
nnet_name=${feat_type}_fwseresnet34_dino.v1.2.2

nnet_s1_base_cfg=conf/train_fwseresnet34_dino_v1.2.2.yaml
nnet_s1_name=$nnet_name.s1
nnet_s1_dir=exp/xvector_nnets/$nnet_s1_name
#nnet_s1=$nnet_s1_dir/teacher_model_ep0034.pth
nnet_s1=$nnet_s1_dir/teacher_model_ep0025.pth

# clustering of dino embeddings
cluster_method=cos_ahc_plda_ahc
cluster_cfg=conf/cluster_lresnet34_v1.2_cos_ahc_plda_ahc.yaml
cluster_name=${cluster_method}
cluster_dir=exp/clustering/$nnet_s1_name/$cluster_name

# plda
plda_cfg=conf/plda.yaml

# finetuning stage 1.1
nnet_ft_s1_1_base_cfg=conf/train_lresnet34_xvec_stage1.1_v1.2.yaml
nnet_ft_s1_1_name=$nnet_name.s1.ft.s1.1
nnet_ft_s1_1_dir=exp/xvector_nnets/$nnet_ft_s1_1_name
nnet_ft_s1_1=$nnet_ft_s1_1_dir/model_ep0030.pth

# finetuning stage 1.2
nnet_ft_s1_2_base_cfg=conf/train_lresnet34_xvec_stage1.2_v1.2.yaml
nnet_ft_s1_2_name=$nnet_name.s1.ft.s1.2
nnet_ft_s1_2_dir=exp/xvector_nnets/$nnet_ft_s1_2_name
nnet_ft_s1_2=$nnet_ft_s1_2_dir/model_ep0070.pth

# clustering of ft embeddings from stage 1.2
cluster_ft_s1_method=cos_ahc
cluster_ft_s1_cfg=conf/cluster_lresnet34_v1.2_ft1_cos_ahc.yaml
cluster_ft_s1_name=${cluster_ft_s1_method}
cluster_ft_s1_dir=exp/clustering/$nnet_ft_s1_2_name/$cluster_ft_s1_name

# finetuning stage 2.1
nnet_ft_s2_1_base_cfg=conf/train_lresnet34_xvec_stage1.1_v1.2.yaml
nnet_ft_s2_1_name=$nnet_name.s1.ft.s2.1
nnet_ft_s2_1_dir=exp/xvector_nnets/$nnet_ft_s2_1_name
nnet_ft_s2_1=$nnet_ft_s2_1_dir/model_ep0030.pth

# finetuning stage 2.2
nnet_ft_s2_2_base_cfg=conf/train_lresnet34_xvec_stage1.2_v1.2.yaml
nnet_ft_s2_2_name=$nnet_name.s1.ft.s2.2
nnet_ft_s2_2_dir=exp/xvector_nnets/$nnet_ft_s2_2_name
nnet_ft_s2_2=$nnet_ft_s2_2_dir/model_ep0070.pth

# clustering of ft embeddings from stage 2.2
cluster_ft_s2_method=cos_ahc
cluster_ft_s2_cfg=conf/cluster_lresnet34_v1.2_ft1_cos_ahc.yaml
cluster_ft_s2_name=${cluster_ft_s2_method}
cluster_ft_s2_dir=exp/clustering/$nnet_ft_s2_2_name/$cluster_ft_s2_name
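Taken together, these variables chain the recipe stages: DINO pre-training (s1), clustering of the DINO embeddings into pseudo-speaker labels, fine-tuning on those labels in two rounds (s1.1/s1.2 and s2.1/s2.2), and re-clustering between rounds with the improved embeddings.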

20 changes: 19 additions & 1 deletion hyperion/bin/train_dino_wav2xvector.py
@@ -19,7 +19,7 @@
from hyperion.hyp_defs import config_logger, set_float_cpu
from hyperion.torch.data import DINOAudioDataset as AD
from hyperion.torch.data import SegSamplerFactory
-from hyperion.torch.losses import DINOLoss
+from hyperion.torch.losses import CosineDINOLoss, DINOLoss
from hyperion.torch.metrics import CategoricalAccuracy

# from hyperion.torch.models import EfficientNetXVector as EXVec
@@ -109,6 +109,21 @@ def init_dino_loss(rank, **kwargs):
    return loss


def init_cosine_loss(rank, **kwargs):
    loss_args = kwargs["cosine_loss"]
    if rank == 0:
        logging.info(f"cosine loss args={loss_args}")

    if loss_args["scale"] <= 0:
        return None

    loss = CosineDINOLoss(**loss_args)
    if rank == 0:
        logging.info(f"cosine-loss={loss}")

    return loss


def train_xvec(gpu_id, args):
    config_logger(args.verbose)
    del args.verbose
@@ -126,6 +141,7 @@ def train_xvec(gpu_id, args):
    val_loader = init_data(partition="val", **kwargs)

    dino_loss = init_dino_loss(**kwargs)
    cosine_loss = init_cosine_loss(**kwargs)
    student_model = init_student_xvector(num_classes=dino_loss.num_classes, **kwargs)
    kwargs["student_model"] = student_model
    teacher_model = init_teacher_xvector(**kwargs)
@@ -138,6 +154,7 @@ def train_xvec(gpu_id, args):
        student_model,
        teacher_model,
        dino_loss,
        cosine_loss=cosine_loss,
        device=device,
        metrics=metrics,
        ddp=world_size > 1,
@@ -185,6 +202,7 @@ def make_parser(xvec_class):
    xvec_class.add_class_args(parser, prefix="student_model")
    xvec_class.add_dino_teacher_args(parser, prefix="teacher_model")
    DINOLoss.add_class_args(parser, prefix="dino_loss")
    CosineDINOLoss.add_class_args(parser, prefix="cosine_loss")
    Trainer.add_class_args(
        parser, prefix="trainer", train_modes=xvec_class.valid_train_modes()
    )
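With the CLI default of scale 0 (see the parser defaults in dino_loss.py below), init_cosine_loss returns None and training is pure DINO; the v1.2.2 YAML above turns the extra term on. A rough sketch of how a training step can combine the two objectives, assuming the Trainer simply adds the already-scaled cosine term to the DINO loss (variable names here are illustrative, not hyperion's actual trainer code):

# hypothetical training step, after the student/teacher forward passes
loss = dino_loss(student_preds, teacher_preds, num_student_crops, num_teacher_crops)
if cosine_loss is not None:
    scaled_cos, raw_cos = cosine_loss(
        student_embed, teacher_embed, num_student_crops, num_teacher_crops
    )
    loss = loss + scaled_cos  # scaled_cos already includes cur_scale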
2 changes: 1 addition & 1 deletion hyperion/torch/losses/__init__.py
@@ -4,4 +4,4 @@
"""

from .bce_with_llr import BCEWithLLR
-from .dino_loss import DINOLoss
+from .dino_loss import CosineDINOLoss, DINOLoss
81 changes: 81 additions & 0 deletions hyperion/torch/losses/dino_loss.py
@@ -2,6 +2,7 @@
Copyright 2023 Johns Hopkins University (Author: Jesus Villalba)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""

import logging

import torch
@@ -162,3 +163,83 @@ def add_class_args(parser, prefix=None):

        if prefix is not None:
            outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))


class CosineDINOLoss(nn.Module):
    """Cosine loss to regularize DINO and enforce DINO embeddings
    to be suitable for cosine scoring.
    """

    def __init__(
        self,
        scale: float = 1.0,
        warmup_epochs: int = 30,
    ):
        super().__init__()
        self.scale = scale
        self.warmup_epochs = warmup_epochs
        self.cur_scale = scale

    def update_scale(self, epoch: int):
        if epoch < self.warmup_epochs:
            self.cur_scale = self.scale * epoch / self.warmup_epochs
            logging.info("updating cosine-loss scale=%.3f", self.cur_scale)
        else:
            self.cur_scale = self.scale

    def forward(
        self,
        student_embed: torch.Tensor,
        teacher_embed: torch.Tensor,
        num_student_crops: int,
        num_teacher_crops: int,
    ):
        """Cosine scoring between embeddings of the teacher and student networks."""
        if self.scale == 0:
            # keep the return type consistent with the scaled case below
            return 0.0, 0.0

        student_embed = torch.nn.functional.normalize(student_embed, dim=-1)
        teacher_embed = torch.nn.functional.normalize(teacher_embed, dim=-1)
        student_embed = student_embed.chunk(num_student_crops)
        teacher_embed = teacher_embed.detach()
        teacher_embed = teacher_embed.chunk(num_teacher_crops)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_embed):
            for ip, p in enumerate(student_embed):
                if ip == iq and num_teacher_crops > 1:
                    # we skip cases where student and teacher operate on the same view
                    continue
                loss = 1 - torch.sum(q * p, dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms

        return self.cur_scale * total_loss, total_loss

    @staticmethod
    def filter_args(**kwargs):
        return filter_func_args(CosineDINOLoss.__init__, kwargs)

    @staticmethod
    def add_class_args(parser, prefix=None):
        if prefix is not None:
            outer_parser = parser
            parser = ArgumentParser(prog="")

        parser.add_argument(
            "--scale",
            default=0,
            type=float,
            help="scale of the cosine loss used to regularize DINO",
        )
        parser.add_argument(
            "--warmup-epochs",
            default=30,
            type=int,
            help="warmup epochs for the scale",
        )

        if prefix is not None:
            outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))
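To close, a self-contained usage sketch of the new class (shapes are illustrative: batches of 16 utterances with 2 teacher and 4 student chunks, 192-dim embeddings as in the v1.2.2 config):

import torch

loss_fn = CosineDINOLoss(scale=0.1, warmup_epochs=20)
loss_fn.update_scale(10)  # cur_scale = 0.1 * 10 / 20 = 0.05

# crops are stacked along the batch dimension: num_crops * batch_size rows
student_embed = torch.randn(4 * 16, 192, requires_grad=True)
teacher_embed = torch.randn(2 * 16, 192)

scaled_loss, raw_loss = loss_fn(
    student_embed, teacher_embed, num_student_crops=4, num_teacher_crops=2
)
scaled_loss.backward()  # gradients reach only the student; the teacher is detached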