-
Notifications
You must be signed in to change notification settings - Fork 39
/
Copy pathoracle_separate.py
129 lines (119 loc) · 3.85 KB
/
oracle_separate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python
# coding=utf-8
# wujian@2018
import argparse
import os
import numpy as np
from dataset import SpectrogramReader
from utils import istft
def _safe_divide(numerator, denominator):
    """
    Element-wise numerator / denominator that yields 0 where denominator is 0.

    Silent time-frequency bins (e.g. digitally-zero frames) would otherwise
    give NaN/inf mask values, which then propagate into the masked STFT.
    """
    nonzero = denominator != 0
    # substitute 1 in zero bins so the division itself never hits 0/0
    return np.where(nonzero, numerator / np.where(nonzero, denominator, 1), 0)


def compute_mask(mixture, targets_list, mask_type):
    """
    Compute one oracle time-frequency mask per target from complex STFTs.

    Arguments:
        mixture: STFT of mixture signal (complex result)
        targets_list: python list of target signal's STFT results (complex
            result; presumably shaped like mixture)
        mask_type: one of ["irm", "ibm", "iam", "psm"]
            ibm: boolean mask, True where target i has the largest magnitude
                 among all targets
            irm: |S_i| / sum_j |S_j|
            iam: |S_i| / |Y|
            psm: |S_i| * cos(angle(Y) - angle(S_i)) / |Y|
    Returns:
        list of masks, one per entry of targets_list, in the same order.
        For irm/iam/psm, bins whose denominator is zero get mask value 0.
    Raises:
        ValueError: if mask_type is not one of the supported mask types.
    """
    if mask_type not in ("irm", "ibm", "iam", "psm"):
        raise ValueError("Unknown mask type: {}".format(mask_type))
    if mask_type == "ibm":
        # index of the dominant target in every T-F bin
        max_index = np.argmax(
            np.stack([np.abs(mat) for mat in targets_list]), 0)
        return [max_index == s for s in range(len(targets_list))]
    if mask_type == "irm":
        denominator = sum(np.abs(mat) for mat in targets_list)
    else:
        denominator = np.abs(mixture)
    if mask_type == "psm":
        # project each target onto the mixture phase
        mixture_phase = np.angle(mixture)
        numerators = [
            np.abs(mat) * np.cos(mixture_phase - np.angle(mat))
            for mat in targets_list
        ]
    else:
        numerators = [np.abs(mat) for mat in targets_list]
    return [_safe_divide(num, denominator) for num in numerators]
def run(args):
    """
    Oracle-separate every mixture utterance and dump one wav per target.

    For each utterance of args.mix_scp that is present in every reader built
    from args.ref_scp, compute oracle masks (compute_mask) from the reference
    STFTs, apply them to the mixture STFT and write
    "<dump_dir>/<key>.spk<N>.wav" through istft.

    Arguments:
        args: parsed command-line namespace; reads mix_scp, ref_scp,
            dump_dir, frame_length, frame_shift, window, mask and, if
            present, fs (falls back to 8000 when the attribute is missing).
    """
    # STFT settings shared by the readers (analysis) and istft (synthesis);
    # the readers return complex results
    reader_kwargs = {
        "frame_length": args.frame_length,
        "frame_shift": args.frame_shift,
        "window": args.window,
        "center": True
    }
    # namespaces built without an ``fs`` attribute keep the historical value
    fs = getattr(args, "fs", 8000)
    print("Using Mask: {}".format(args.mask.upper()))
    # the mixture reader also yields raw samples (used for norm and length)
    mixture_reader = SpectrogramReader(
        args.mix_scp, **reader_kwargs, return_samps=True)
    targets_reader = [
        SpectrogramReader(scp, **reader_kwargs) for scp in args.ref_scp
    ]
    # make sure the output location exists before writing any wav; an empty
    # dump_dir means "current directory" for os.path.join, so skip it then
    if args.dump_dir:
        os.makedirs(args.dump_dir, exist_ok=True)
    num_utts = 0
    for key, packed in mixture_reader:
        if any(key not in reader for reader in targets_reader):
            print("Skip utterance {}, missing targets".format(key))
            continue
        samps, mixture = packed
        # infinity norm of the mixture samples (peak absolute amplitude for
        # 1-D samples), forwarded to istft
        norm = np.linalg.norm(samps, np.inf)
        num_utts += 1
        if not num_utts % 1000:
            print("Processed {} utterance...".format(num_utts))
        targets_list = [reader[key] for reader in targets_reader]
        spk_masks = compute_mask(mixture, targets_list, args.mask)
        for index, mask in enumerate(spk_masks):
            istft(
                os.path.join(args.dump_dir, '{}.spk{}.wav'.format(
                    key, index + 1)),
                mixture * mask,
                **reader_kwargs,
                norm=norm,
                fs=fs,
                nsamps=samps.size)
    print("Processed {} utterance!".format(num_utts))
if __name__ == '__main__':
    # command-line entry point: parse options and run oracle separation
    parser = argparse.ArgumentParser(
        description=
        "Command to do oracle speech separation, using specified mask(IAM|IBM|IRM|PSM)"
    )
    parser.add_argument(
        "mix_scp",
        type=str,
        help="Location of mixture wave scripts in kaldi format")
    parser.add_argument(
        "ref_scp",
        nargs="+",
        help="List of reference speaker wave scripts in kaldi format")
    parser.add_argument(
        "--dump-dir",
        type=str,
        default="cache",
        dest="dump_dir",
        help="Location to dump separated speakers")
    parser.add_argument(
        "--frame-shift",
        type=int,
        default=128,
        dest="frame_shift",
        help="Number of samples shifted when splitting frames")
    parser.add_argument(
        "--frame-length",
        type=int,
        default=256,
        dest="frame_length",
        help="Frame length in number of samples")
    parser.add_argument(
        "--window",
        type=str,
        default="hann",
        dest="window",
        help="Type of window function, see scipy.signal.get_window")
    parser.add_argument(
        "--mask",
        type=str,
        dest="mask",
        default="irm",
        choices=["iam", "irm", "ibm", "psm"],
        help="Type of mask to use for speech separation")
    # previously hard-coded to 8000 inside run(); now configurable
    parser.add_argument(
        "--fs",
        type=int,
        default=8000,
        dest="fs",
        help="Sample rate used when writing the separated waves")
    args = parser.parse_args()
    run(args)