Skip to content

Commit 2cf461d

Browse files
Merge branch 'persephone-refactor' of https://github.com/hyperion-ml/hyperion into persephone-asr
adding changes in persephone-refactor to persephone-asr
2 parents 02bb457 + 3bfb7f0 commit 2cf461d

File tree

8 files changed

+53
-67
lines changed

8 files changed

+53
-67
lines changed

hyperion/bin/extract_wav2vec2xvectors.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pandas as pd
2020

2121
import torch
22+
import torchaudio.transforms as tat
2223

2324
from hyperion.hyp_defs import config_logger, float_cpu, set_float_cpu
2425
from hyperion.utils import Utt2Info
@@ -30,6 +31,25 @@
3031
from hyperion.torch.utils import open_device
3132
from hyperion.torch import TorchModelLoader as TML
3233

34+
# Cache of resampling callables, keyed by (source_fs, target_fs) as ints.
resamplers = {}


def get_resampler(source_fs, target_fs):
    """Return a cached callable that resamples a numpy waveform.

    The callable wraps a ``torchaudio.transforms.Resample`` module: it
    converts the numpy input to a torch tensor, resamples it, and converts
    the result back to numpy.

    Args:
      source_fs: sampling frequency of the input audio (cast to int).
      target_fs: sampling frequency to resample to (cast to int).

    Returns:
      Callable taking a numpy array sampled at ``source_fs`` and returning
      a numpy array sampled at ``target_fs``.
    """
    # Key on BOTH rates (normalized to int, matching what Resample receives).
    # Keying on source_fs alone would return a resampler built for a
    # different target rate whenever the same source rate is requested twice
    # with distinct targets, and would cache 16000.0 and 16000 separately.
    key = (int(source_fs), int(target_fs))
    cached = resamplers.get(key)
    if cached is not None:
        return cached

    resampler = tat.Resample(
        key[0],
        key[1],
        lowpass_filter_width=64,
        rolloff=0.9475937167399596,
        resampling_method="kaiser_window",
        beta=14.769656459379492,
    )

    def resampler_f(x):
        # numpy -> torch -> resample -> numpy
        return resampler(torch.from_numpy(x)).numpy()

    resamplers[key] = resampler_f
    return resampler_f
52+
3353

3454
def init_device(use_gpu):
3555
set_float_cpu("float32")
@@ -102,7 +122,7 @@ def extract_xvectors(
102122
num_augs,
103123
aug_info_path,
104124
use_gpu,
105-
**kwargs
125+
**kwargs,
106126
):
107127

108128
rng = np.random.RandomState(seed=1123581321 + kwargs["part_idx"])
@@ -122,12 +142,11 @@ def extract_xvectors(
122142
num_augs = 1
123143

124144
ar_args = AR.filter_args(**kwargs)
145+
ar_args["wav_scale"] = 1.0
125146
logging.info("opening output stream: %s", output_spec)
126147
with DWF.create(output_spec, scp_sep=scp_sep) as writer:
127148

128-
logging.info(
129-
"opening input stream: {} with args={}".format(input_spec, ar_args)
130-
)
149+
logging.info(f"opening input stream: {input_spec} with args={ar_args}")
131150
with AR(input_spec, **ar_args) as reader:
132151

133152
if vad_spec is not None:
@@ -146,6 +165,11 @@ def extract_xvectors(
146165
key0 = key[0]
147166
fs = fs[0]
148167
t2 = time.time()
168+
if fs != model.sample_frequency:
169+
resampler = get_resampler(fs, model.sample_frequency)
170+
print(f"x01 {x0.shape} {np.max(x0)}")
171+
x0 = resampler(x0)
172+
print(f"x01 {x0.shape} {np.max(x0)}")
149173

150174
logging.info("processing utt %s", key0)
151175
for aug_id in range(num_augs):

hyperion/bin/finetune_wav2vec2xvector.py

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
from hyperion.torch.utils import ddp
2626
from hyperion.torch.trainers import XVectorTrainer as Trainer
2727
from hyperion.torch.data import AudioDataset as AD
28-
from hyperion.torch.data import ClassWeightedSeqSampler as Sampler
28+
from hyperion.torch.data import SegSamplerFactory
29+
2930
from hyperion.torch.metrics import CategoricalAccuracy
3031
from hyperion.torch.models import (
3132
HFWav2Vec2ResNet1dXVector,
@@ -45,19 +46,21 @@ def init_data(partition, rank, num_gpus, **kwargs):
4546

4647
kwargs = kwargs["data"][partition]
4748
ad_args = AD.filter_args(**kwargs["dataset"])
48-
sampler_args = Sampler.filter_args(**kwargs["sampler"])
49+
sampler_args = kwargs["sampler"]
4950
if rank == 0:
5051
logging.info("{} audio dataset args={}".format(partition, ad_args))
5152
logging.info("{} sampler args={}".format(partition, sampler_args))
5253
logging.info("init %s dataset", partition)
5354

54-
ad_args["is_val"] = partition == "val"
55+
is_val = partition == "val"
56+
ad_args["is_val"] = is_val
57+
sampler_args["shuffle"] = not is_val
5558
dataset = AD(**ad_args)
5659

5760
if rank == 0:
5861
logging.info("init %s samplers", partition)
5962

60-
sampler = Sampler(dataset, **sampler_args)
63+
sampler = SegSamplerFactory.create(dataset, **sampler_args)
6164

6265
if rank == 0:
6366
logging.info("init %s dataloader", partition)
@@ -71,18 +74,6 @@ def init_data(partition, rank, num_gpus, **kwargs):
7174
return data_loader
7275

7376

74-
# def init_model(num_classes, in_model_file, rank, **kwargs):
75-
# xvec_args = kwargs["model"]["xvector"]
76-
# if rank == 0:
77-
# logging.info("xvector network ft args={}".format(xvec_args))
78-
# xvec_args["num_classes"] = num_classes
79-
# model = TML.load(in_model_file)
80-
# model.rebuild_output_layer(**xvec_args)
81-
# if rank == 0:
82-
# logging.info("model={}".format(model))
83-
# return model
84-
85-
8677
def init_model(num_classes, in_model_file, rank, **kwargs):
8778
model_args = kwargs["model"]
8879
if rank == 0:
@@ -127,19 +118,15 @@ def train_model(gpu_id, args):
127118

128119
train_loader = init_data(partition="train", **kwargs)
129120
val_loader = init_data(partition="val", **kwargs)
130-
model = init_model(train_loader.dataset.num_classes, **kwargs)
121+
model = init_model(list(train_loader.dataset.num_classes.values())[0], **kwargs)
131122
init_hard_prototype_mining(model, train_loader, val_loader, rank)
132123

133124
trn_args = Trainer.filter_args(**kwargs["trainer"])
134125
if rank == 0:
135126
logging.info("trainer args={}".format(trn_args))
136127
metrics = {"acc": CategoricalAccuracy()}
137128
trainer = Trainer(
138-
model,
139-
device=device,
140-
metrics=metrics,
141-
ddp=world_size > 1,
142-
**trn_args,
129+
model, device=device, metrics=metrics, ddp=world_size > 1, **trn_args,
143130
)
144131
trainer.load_last_checkpoint()
145132
trainer.fit(train_loader, val_loader)
@@ -153,7 +140,7 @@ def make_parser(model_class):
153140
parser.add_argument("--cfg", action=ActionConfigFile)
154141
train_parser = ArgumentParser(prog="")
155142
AD.add_class_args(train_parser, prefix="dataset", skip={})
156-
Sampler.add_class_args(train_parser, prefix="sampler")
143+
SegSamplerFactory.add_class_args(train_parser, prefix="sampler")
157144
train_parser.add_argument(
158145
"--data_loader.num-workers",
159146
type=int,
@@ -163,7 +150,7 @@ def make_parser(model_class):
163150

164151
val_parser = ArgumentParser(prog="")
165152
AD.add_class_args(val_parser, prefix="dataset", skip={})
166-
Sampler.add_class_args(val_parser, prefix="sampler")
153+
SegSamplerFactory.add_class_args(val_parser, prefix="sampler")
167154
val_parser.add_argument(
168155
"--data_loader.num-workers",
169156
type=int,
@@ -175,14 +162,11 @@ def make_parser(model_class):
175162
data_parser.add_argument("--val", action=ActionParser(parser=val_parser))
176163
parser.add_argument("--data", action=ActionParser(parser=data_parser))
177164
parser.link_arguments(
178-
"data.train.dataset.class_file", "data.val.dataset.class_file"
165+
"data.train.dataset.class_files", "data.val.dataset.class_files"
179166
)
180167
parser.link_arguments(
181168
"data.train.data_loader.num_workers", "data.val.data_loader.num_workers"
182169
)
183-
parser.link_arguments(
184-
"data.train.sampler.batch_size", "data.val.sampler.batch_size"
185-
)
186170

187171
parser.add_argument("--in-model-file", required=True)
188172
model_class.add_finetune_args(parser, prefix="model")

hyperion/bin/finetune_xvector_from_wav.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,9 @@
2121
from hyperion.hyp_defs import config_logger, set_float_cpu
2222
from hyperion.torch.utils import ddp
2323

24-
# from hyperion.torch.models import XVector as XVec
2524
from hyperion.torch.trainers import XVectorTrainerFromWav as Trainer
2625
from hyperion.torch.data import AudioDataset as AD
2726

28-
# from hyperion.torch.data import ClassWeightedSeqSampler as Sampler
2927
from hyperion.torch import TorchModelLoader as TML
3028
from hyperion.torch.data import SegSamplerFactory
3129
from hyperion.torch.metrics import CategoricalAccuracy

hyperion/bin/train_wav2vec2xvector.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from hyperion.torch.data import AudioDataset as AD
2828
from hyperion.torch.data import SegSamplerFactory
2929

30-
# from hyperion.torch.data import ClassWeightedSeqSampler as Sampler
3130
from hyperion.torch.metrics import CategoricalAccuracy
3231
from hyperion.torch.models import (
3332
HFWav2Vec2ResNet1dXVector,
@@ -74,36 +73,6 @@ def init_data(partition, rank, num_gpus, **kwargs):
7473
return data_loader
7574

7675

77-
# def init_data(partition, rank, num_gpus, **kwargs):
78-
79-
# kwargs = kwargs["data"][partition]
80-
# ad_args = AD.filter_args(**kwargs["dataset"])
81-
# sampler_args = Sampler.filter_args(**kwargs["sampler"])
82-
# if rank == 0:
83-
# logging.info("{} audio dataset args={}".format(partition, ad_args))
84-
# logging.info("{} sampler args={}".format(partition, sampler_args))
85-
# logging.info("init %s dataset", partition)
86-
87-
# ad_args["is_val"] = partition == "val"
88-
# dataset = AD(**ad_args)
89-
90-
# if rank == 0:
91-
# logging.info("init %s samplers", partition)
92-
93-
# sampler = Sampler(dataset, **sampler_args)
94-
95-
# if rank == 0:
96-
# logging.info("init %s dataloader", partition)
97-
98-
# num_workers = kwargs["data_loader"]["num_workers"]
99-
# num_workers_per_gpu = int((num_workers + num_gpus - 1) / num_gpus)
100-
# largs = (
101-
# {"num_workers": num_workers_per_gpu, "pin_memory": True} if num_gpus > 0 else {}
102-
# )
103-
# data_loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler, **largs)
104-
# return data_loader
105-
106-
10776
def init_model(num_classes, rank, model_class, **kwargs):
10877
model_args = model_class.filter_args(**kwargs["model"])
10978
if rank == 0:

hyperion/torch/data/audio_dataset.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,14 @@ def __getitem__(self, segment):
734734
else:
735735
r = [x]
736736

737+
# try:
738+
# import soundfile as sf
739+
740+
# for i, z in enumerate(r):
741+
# sf.write(f"file_{seg_id}.wav", z, fs, "PCM_16")
742+
# except:
743+
# print("soundfile failed", flush=True)
744+
737745
# adds the segment labels
738746
seg_info = self._get_segment_info(seg_id)
739747
r.extend(seg_info)

hyperion/torch/data/class_weighted_seg_chunk_sampler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,6 @@ def set_hard_prototypes(self, affinity_matrix):
235235
if np.all(mask_i == 0):
236236
affinity_matrix[:, i] = -1000
237237

238-
# affinity_matrix[np.diag(affinity_matrix.shape[0])] = -1.0
239238
# hard prototypes for a class are itself and k-1 closest to it.
240239
self.hard_prototypes = torch.topk(
241240
affinity_matrix, self.num_hard_prototypes, dim=-1

hyperion/torch/models/wav2xvectors/hf_wav2xvector.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ def _fuse_hid_feats(self, hid_feats):
8383

8484
return feats
8585

86+
@property
87+
def sample_frequency(self):
88+
return self.hf_feats.sample_frequency
89+
8690
def compute_prototype_affinity(self):
8791
return self.xvector.compute_prototype_affinity()
8892

hyperion/torch/models/xvectors/xvector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -892,7 +892,7 @@ def add_finetune_args(parser, prefix=None):
892892
parser.add_argument(
893893
"--num-subcenters",
894894
default=2,
895-
type=float,
895+
type=int,
896896
help="number of subcenters in subcenter losses",
897897
)
898898

0 commit comments

Comments
 (0)