Skip to content

Commit

Permalink
[dataset] support speaker in dataset (#2292)
Browse files Browse the repository at this point in the history
  • Loading branch information
robin1001 authored Jan 11, 2024
1 parent c0f4194 commit 3f8f043
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
4 changes: 4 additions & 0 deletions wenet/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ def Dataset(data_type,
else:
dataset = Processor(dataset, processor.parse_raw)

speaker_conf = conf.get('speaker_conf', None)
if speaker_conf is not None:
dataset = Processor(dataset, processor.parse_speaker, **speaker_conf)

dataset = Processor(dataset, processor.tokenize, tokenizer)
filter_conf = conf.get('filter_conf', {})
dataset = Processor(dataset, processor.filter, **filter_conf)
Expand Down
29 changes: 23 additions & 6 deletions wenet/dataset/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import librosa
import logging
import json
Expand Down Expand Up @@ -145,15 +146,27 @@ def parse_raw(data):
frame_offset=start_frame)
else:
waveform, sample_rate = torchaudio.load(wav_file)
example = dict(key=key,
txt=txt,
wav=waveform,
sample_rate=sample_rate)
example = copy.deepcopy(obj) # copy and keep all the fields
example['wav'] = waveform # overwrite wav
example['sample_rate'] = sample_rate
yield example
except Exception as ex:
logging.warning('Failed to read {}'.format(wav_file))


def parse_speaker(data, speaker_table_path):
speaker_dict = {}
with open(speaker_table_path, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
speaker_dict[arr[0]] = int(arr[1])
for sample in data:
assert 'speaker' in sample
speaker = sample['speaker']
sample['speaker'] = speaker_dict.get(speaker, 0)
yield sample


def filter(data,
max_length=10240,
min_length=10,
Expand Down Expand Up @@ -628,8 +641,7 @@ def padding(data):
padded_wavs = pad_sequence(sorted_wavs,
batch_first=True,
padding_value=0)

yield {
batch = {
"keys": sorted_keys,
"feats": padded_feats,
"target": padding_labels,
Expand All @@ -638,3 +650,8 @@ def padding(data):
"pcm": padded_wavs,
"pcm_length": wav_lengths,
}
if 'speaker' in sample[0]:
speaker = torch.tensor([sample[i]['speaker'] for i in order],
dtype=torch.int32)
batch['speaker'] = speaker
yield batch

0 comments on commit 3f8f043

Please sign in to comment.