
Commit 6833401

Merge pull request #142 from hyperion-ml/persephone-refactor

added cosine loss to dino

2 parents 5960ec6 + dcabebe · commit 6833401

13 files changed: +481 −66 lines
Lines changed: 96 additions & 0 deletions

@@ -0,0 +1,96 @@
+data:
+  train:
+    dataset:
+      teacher_aug_cfg: conf/teacher_reverb_noise_aug.yaml
+      student_aug_cfg: conf/reverb_noise_aug.yaml
+      student_chunk_length: 2.
+      teacher_chunk_length: 4.
+      num_teacher_chunks: 2
+      num_student_chunks: 4
+      same_teacher_student_chunks: false
+    sampler:
+      sampler_type: seg_chunk_sampler
+      min_batch_size: 16
+      max_chunk_length: 12.0
+      min_chunk_length: 6.0
+    data_loader:
+      num_workers: 8
+  val:
+    dataset:
+      teacher_aug_cfg: conf/teacher_reverb_noise_aug.yaml
+      student_aug_cfg: conf/reverb_noise_aug.yaml
+      student_chunk_length: 2.
+      teacher_chunk_length: 4.
+      num_teacher_chunks: 2
+      num_student_chunks: 4
+      same_teacher_student_chunks: false
+    sampler:
+      sampler_type: seg_chunk_sampler
+      min_batch_size: 16
+      max_chunk_length: 12.0
+      min_chunk_length: 6.0
+    data_loader:
+      num_workers: 8
+student_model:
+  feats: fbank80_specaug1_stmn_16k.yaml
+  xvector:
+    resnet_type: fwseresnet34
+    in_channels: 1
+    in_feats: 80
+    in_kernel_size: 3
+    in_stride: 1
+    no_maxpool: true
+    pool_net:
+      pool_type: ch-wise-att-mean+stddev
+      inner_feats: 128
+    dropout_rate: 0.01
+    norm_before: false
+    hid_act: swish
+    se_r: 4
+    head_type: dino
+    embed_dim: 192
+    num_embed_layers: 3
+    loss_type: softmax
+    head_use_norm: true
+    head_hid_dim: 768
+    head_bottleneck_dim: 192
+    proj_head_use_norm: true
+    proj_head_norm_before: false
+teacher_model:
+  xvector:
+    override_dropouts: true
+    dropout_rate: 0.0
+dino_loss:
+  num_classes: 65536
+  temp_warmup_epochs: 0
+  teacher_temp: 0.04
+cosine_loss:
+  warmup_epochs: 20
+  scale: 0.1
+trainer:
+  optim:
+    opt_type: adamw
+    lr: 0.005
+    amsgrad: false
+    beta1: 0.9
+    beta2: 0.99
+    weight_decay: 1e-1
+  lrsched:
+    lrsch_type: exp_lr
+    decay_rate: 0.5
+    decay_steps: 60000
+    hold_steps: 15000
+    min_lr: 1.0e-05
+    warmup_steps: 15000
+    update_lr_on_opt_step: true
+  teacher_optim:
+    init_momentum: 0.996
+    momentum: 1.0
+    warmup_steps: 500000
+  grad_clip: 15
+  use_amp: true
+  log_interval: 1000
+  epochs: 100
+  eff_batch_size: 256
+  train_mode: full
+  freeze_output_layer_steps: 1500
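The `cosine_loss` block is the piece this commit adds to the training config; it feeds the new `CosineDINOLoss` shown further down in `hyperion/torch/losses/dino_loss.py`. A minimal sketch of what those two keys control, assuming the class behaves as committed (the epoch loop is illustrative only):

```python
# Sketch: how the cosine_loss section above is consumed (values from this YAML).
# CosineDINOLoss is the class added by this commit; the loop is for illustration.
from hyperion.torch.losses import CosineDINOLoss

cosine_loss = CosineDINOLoss(scale=0.1, warmup_epochs=20)

for epoch in range(100):
    # During the first 20 epochs cur_scale ramps linearly from 0 to 0.1,
    # then stays fixed at 0.1 for the remaining epochs.
    cosine_loss.update_scale(epoch)
```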
Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+# ECAPA-TDNN 512x3
+
+# acoustic features
+feat_config=conf/fbank80_stmn_16k.yaml
+feat_type=fbank80_stmn
+
+#vad
+vad_config=conf/vad_16k.yaml
+
+# x-vector training
+nnet_data=voxceleb2cat_train
+
+# x-vector cfg
+nnet_type=resnet
+nnet_name=${feat_type}_fwseresnet34_dino.v1.2.2
+
+nnet_s1_base_cfg=conf/train_fwseresnet34_dino_v1.2.2.yaml
+nnet_s1_name=$nnet_name.s1
+nnet_s1_dir=exp/xvector_nnets/$nnet_s1_name
+nnet_s1=$nnet_s1_dir/teacher_model_ep0034.pth
+nnet_s1=$nnet_s1_dir/teacher_model_ep0025.pth
+
+# clustering of dino embeddings
+cluster_method=cos_ahc_plda_ahc
+cluster_cfg=conf/cluster_lresnet34_v1.2_cos_ahc_plda_ahc.yaml
+cluster_name=${cluster_method}
+cluster_dir=exp/clustering/$nnet_s1_name/$cluster_name
+
+# plda
+plda_cfg=conf/plda.yaml
+
+# finetuning stage 1.1
+nnet_ft_s1_1_base_cfg=conf/train_lresnet34_xvec_stage1.1_v1.2.yaml
+nnet_ft_s1_1_name=$nnet_name.s1.ft.s1.1
+nnet_ft_s1_1_dir=exp/xvector_nnets/$nnet_ft_s1_1_name
+nnet_ft_s1_1=$nnet_ft_s1_1_dir/model_ep0030.pth
+
+# finetuning stage 1.2
+nnet_ft_s1_2_base_cfg=conf/train_lresnet34_xvec_stage1.2_v1.2.yaml
+nnet_ft_s1_2_name=$nnet_name.s1.ft.s1.2
+nnet_ft_s1_2_dir=exp/xvector_nnets/$nnet_ft_s1_2_name
+nnet_ft_s1_2=$nnet_ft_s1_2_dir/model_ep0070.pth
+
+# clustering of ft embeddings from stage 1.2
+cluster_ft_s1_method=cos_ahc
+cluster_ft_s1_cfg=conf/cluster_lresnet34_v1.2_ft1_cos_ahc.yaml
+cluster_ft_s1_name=${cluster_ft_s1_method}
+cluster_ft_s1_dir=exp/clustering/$nnet_ft_s1_2_name/$cluster_ft_s1_name
+
+# finetuning stage 2.1
+nnet_ft_s2_1_base_cfg=conf/train_lresnet34_xvec_stage1.1_v1.2.yaml
+nnet_ft_s2_1_name=$nnet_name.s1.ft.s2.1
+nnet_ft_s2_1_dir=exp/xvector_nnets/$nnet_ft_s2_1_name
+nnet_ft_s2_1=$nnet_ft_s2_1_dir/model_ep0030.pth
+
+# finetuning stage 2.2
+nnet_ft_s2_2_base_cfg=conf/train_lresnet34_xvec_stage1.2_v1.2.yaml
+nnet_ft_s2_2_name=$nnet_name.s1.ft.s2.2
+nnet_ft_s2_2_dir=exp/xvector_nnets/$nnet_ft_s2_2_name
+nnet_ft_s2_2=$nnet_ft_s2_2_dir/model_ep0070.pth
+
+# clustering of ft embeddings from stage 2.2
+cluster_ft_s2_method=cos_ahc
+cluster_ft_s2_cfg=conf/cluster_lresnet34_v1.2_ft1_cos_ahc.yaml
+cluster_ft_s2_name=${cluster_ft_s2_method}
+cluster_ft_s2_dir=exp/clustering/$nnet_ft_s2_2_name/$cluster_ft_s2_name
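Since `nnet_s1` is assigned twice, the second assignment wins, so stage 1 uses the epoch-25 teacher checkpoint. A hypothetical Python mirror of the path construction, just to make the shell variable expansion explicit:

```python
# Hypothetical mirror of the sh variables above, for illustration only.
feat_type = "fbank80_stmn"
nnet_name = f"{feat_type}_fwseresnet34_dino.v1.2.2"
nnet_s1_name = f"{nnet_name}.s1"
nnet_s1_dir = f"exp/xvector_nnets/{nnet_s1_name}"
nnet_s1 = f"{nnet_s1_dir}/teacher_model_ep0025.pth"  # the later assignment wins

print(nnet_s1)
# exp/xvector_nnets/fbank80_stmn_fwseresnet34_dino.v1.2.2.s1/teacher_model_ep0025.pth
```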

hyperion/bin/train_dino_wav2xvector.py

Lines changed: 19 additions & 1 deletion
@@ -19,7 +19,7 @@
 from hyperion.hyp_defs import config_logger, set_float_cpu
 from hyperion.torch.data import DINOAudioDataset as AD
 from hyperion.torch.data import SegSamplerFactory
-from hyperion.torch.losses import DINOLoss
+from hyperion.torch.losses import CosineDINOLoss, DINOLoss
 from hyperion.torch.metrics import CategoricalAccuracy

 # from hyperion.torch.models import EfficientNetXVector as EXVec
@@ -109,6 +109,21 @@ def init_dino_loss(rank, **kwargs):
     return loss


+def init_cosine_loss(rank, **kwargs):
+    loss_args = kwargs["cosine_loss"]
+    if rank == 0:
+        logging.info(f"cosine loss args={loss_args}")
+
+    if loss_args["scale"] <= 0:
+        return None
+
+    loss = CosineDINOLoss(**loss_args)
+    if rank == 0:
+        logging.info(f"cosine-loss={loss}")
+
+    return loss
+
+
 def train_xvec(gpu_id, args):
     config_logger(args.verbose)
     del args.verbose
@@ -126,6 +141,7 @@ def train_xvec(gpu_id, args):
     val_loader = init_data(partition="val", **kwargs)

     dino_loss = init_dino_loss(**kwargs)
+    cosine_loss = init_cosine_loss(**kwargs)
     student_model = init_student_xvector(num_classes=dino_loss.num_classes, **kwargs)
     kwargs["student_model"] = student_model
     teacher_model = init_teacher_xvector(**kwargs)
@@ -138,6 +154,7 @@ def train_xvec(gpu_id, args):
         student_model,
         teacher_model,
         dino_loss,
+        cosine_loss=cosine_loss,
         device=device,
         metrics=metrics,
         ddp=world_size > 1,
@@ -185,6 +202,7 @@ def make_parser(xvec_class):
     xvec_class.add_class_args(parser, prefix="student_model")
     xvec_class.add_dino_teacher_args(parser, prefix="teacher_model")
     DINOLoss.add_class_args(parser, prefix="dino_loss")
+    CosineDINOLoss.add_class_args(parser, prefix="cosine_loss")
     Trainer.add_class_args(
         parser, prefix="trainer", train_modes=xvec_class.valid_train_modes()
     )
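`init_cosine_loss` returns `None` when the configured scale is non-positive, so the regularizer is strictly opt-in. The trainer that consumes `cosine_loss` is not part of this diff, so the fragment below is only a sketch of a conventional way to combine the two objectives in a training step; `student_out`, `teacher_out`, `student_embed`, and `teacher_embed` are hypothetical names:

```python
# Hypothetical training-step fragment; the real DINO trainer class is not shown
# in this commit. It adds the scaled cosine regularizer to the DINO loss.
loss = dino_loss(student_out, teacher_out, num_student_crops, num_teacher_crops)
if cosine_loss is not None:
    scaled_cos, raw_cos = cosine_loss(
        student_embed, teacher_embed, num_student_crops, num_teacher_crops
    )
    loss = loss + scaled_cos  # raw_cos can be logged unscaled
loss.backward()
```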

hyperion/torch/losses/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@
 """

 from .bce_with_llr import BCEWithLLR
-from .dino_loss import DINOLoss
+from .dino_loss import CosineDINOLoss, DINOLoss

hyperion/torch/losses/dino_loss.py

Lines changed: 81 additions & 0 deletions
@@ -2,6 +2,7 @@
 Copyright 2023 Johns Hopkins University (Author: Jesus Villalba)
 Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
 """
+
 import logging

 import torch
@@ -162,3 +163,83 @@ def add_class_args(parser, prefix=None):

     if prefix is not None:
         outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))
+
+
+class CosineDINOLoss(nn.Module):
+    """Cosine loss to regularize DINO and enforce that DINO embeddings
+    are suitable for cosine scoring.
+    """
+
+    def __init__(
+        self,
+        scale: float = 1.0,
+        warmup_epochs: int = 30,
+    ):
+        super().__init__()
+        self.scale = scale
+        self.warmup_epochs = warmup_epochs
+        self.cur_scale = scale
+
+    def update_scale(self, epoch: int):
+        if epoch < self.warmup_epochs:
+            self.cur_scale = self.scale * epoch / self.warmup_epochs
+            logging.info("updating cosine-loss scale=%.3f", self.cur_scale)
+        else:
+            self.cur_scale = self.scale
+
+    def forward(
+        self,
+        student_embed: torch.Tensor,
+        teacher_embed: torch.Tensor,
+        num_student_crops: int,
+        num_teacher_crops: int,
+    ):
+        """Cosine scoring between embeddings of the teacher and student networks."""
+        if self.scale == 0:
+            return 0
+
+        student_embed = torch.nn.functional.normalize(student_embed, dim=-1)
+        teacher_embed = torch.nn.functional.normalize(teacher_embed, dim=-1)
+        student_embed = student_embed.chunk(num_student_crops)
+        teacher_embed = teacher_embed.detach()
+        teacher_embed = teacher_embed.chunk(num_teacher_crops)
+
+        total_loss = 0
+        n_loss_terms = 0
+        for iq, q in enumerate(teacher_embed):
+            for ip, p in enumerate(student_embed):
+                if ip == iq and num_teacher_crops > 1:
+                    # we skip cases where student and teacher operate on the same view
+                    continue
+                loss = 1 - torch.sum(q * p, dim=-1)
+                total_loss += loss.mean()
+                n_loss_terms += 1
+        total_loss /= n_loss_terms
+
+        return self.cur_scale * total_loss, total_loss
+
+    @staticmethod
+    def filter_args(**kwargs):
+        return filter_func_args(CosineDINOLoss.__init__, kwargs)
+
+    @staticmethod
+    def add_class_args(parser, prefix=None):
+        if prefix is not None:
+            outer_parser = parser
+            parser = ArgumentParser(prog="")
+
+        parser.add_argument(
+            "--scale", default=0, type=float, help="scale of cosine loss to regularize DINO"
+        )
+        parser.add_argument(
+            "--warmup-epochs",
+            default=30,
+            type=int,
+            help="warmup epochs for the scale",
+        )
+
+        if prefix is not None:
+            outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))