-
Notifications
You must be signed in to change notification settings - Fork 0
/
segment.py
90 lines (77 loc) · 2.45 KB
/
segment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import argparse
from pathlib import Path
from tqdm import tqdm
from functools import partial
from multiprocessing import Pool
import torch
import numpy as np
def process_file(paths, codebook, segment, gamma):
    """Segment a single pre-computed feature file and write the result.

    Args:
        paths: an (in_path, out_path) tuple of Paths; in_path points to a
            saved .npy feature array, out_path is where results are written
            (its suffix is replaced with ".npz").
        codebook: cluster centers passed through to the segment function.
        segment: callable implementing the segmentation; called as
            segment(features, codebook, gamma) and expected to return
            (codes, boundaries).
        gamma: regularization weight forwarded to segment.

    Returns:
        A tuple (num_frames, mean_segment_length) where num_frames is the
        number of feature frames in the input and mean_segment_length is
        the average gap between consecutive boundaries, in frames.
        NOTE(review): presumably boundaries always has >= 2 entries;
        otherwise np.mean of an empty diff yields NaN — confirm upstream.
    """
    src, dst = paths
    features = np.load(src)
    codes, boundaries = segment(features, codebook, gamma)
    np.savez(dst.with_suffix(".npz"), codes=codes, boundaries=boundaries)
    mean_segment_frames = np.mean(np.diff(boundaries))
    return features.shape[0], mean_segment_frames
def segment_dataset(args):
    """Segment every .npy feature file under args.in_dir in parallel.

    Loads the k-means codebook and segmentation function from torch.hub,
    segments each file with a multiprocessing pool (writing mirrored .npz
    files under args.out_dir), then prints summary statistics.

    Args:
        args: parsed CLI namespace with in_dir, out_dir (Path), language
            (str), gamma (float) and processes (int) attributes.
    """
    in_paths = list(args.in_dir.rglob("*.npy"))
    # Guard the empty case early (before the hub download): previously
    # zip(*results) raised an opaque ValueError when no files were found.
    if not in_paths:
        print("No .npy feature files found — nothing to segment.")
        return

    kmeans, segment = torch.hub.load(
        "bshall/dusted:main", "kmeans", language=args.language, trust_repo=True
    )

    out_paths = [args.out_dir / path.relative_to(args.in_dir) for path in in_paths]

    segment_file = partial(
        process_file,
        codebook=kmeans.cluster_centers_,
        segment=segment,
        gamma=args.gamma,
    )

    # Pre-create all output directories so the worker processes never race
    # on mkdir.
    for path in tqdm(out_paths):
        path.parent.mkdir(exist_ok=True, parents=True)

    print("Segmenting dataset...")
    with Pool(processes=args.processes) as pool:
        # list(...) instead of the redundant [x for x in ...] wrapper.
        results = list(
            tqdm(
                pool.imap(segment_file, zip(in_paths, out_paths)),
                total=len(in_paths),
            )
        )

    frames, boundary_lengths = zip(*results)
    # 0.02: features assumed to be at a 20 ms frame rate — TODO confirm
    # against the HuBERT feature extractor used upstream.
    print(f"Segmented {sum(frames) * 0.02 / 60 / 60:.2f} hours of audio")
    print(f"Average segment length: {np.mean(boundary_lengths) * 0.02:.2f} seconds")
if __name__ == "__main__":
    # Command-line entry point: build the argument parser, then hand the
    # parsed namespace straight to segment_dataset.
    cli = argparse.ArgumentParser(
        description="Segment an audio dataset into phone-like units."
    )
    cli.add_argument(
        "in_dir",
        metavar="in-dir",
        type=Path,
        help="path to the speech features.",
    )
    cli.add_argument(
        "out_dir",
        metavar="out-dir",
        type=Path,
        help="path to the output directory.",
    )
    cli.add_argument(
        "language",
        choices=["english", "chinese", "french"],
        help="pre-training language of the HuBERT content encoder.",
    )
    cli.add_argument(
        "--gamma",
        type=float,
        default=0.2,
        help="regularization weight for segmentation (defaults to 0.2).",
    )
    cli.add_argument(
        "--processes",
        type=int,
        default=10,
        help="number of processes (defaults to 10).",
    )
    segment_dataset(cli.parse_args())