From 3f8f043736ab1289b5ce5e6d31ed5b46b663f288 Mon Sep 17 00:00:00 2001
From: Binbin Zhang
Date: Thu, 11 Jan 2024 17:05:48 +0800
Subject: [PATCH] [dataset] support speaker in dataset (#2292)

---
 wenet/dataset/dataset.py   |  4 ++++
 wenet/dataset/processor.py | 29 +++++++++++++++++++++++------
 2 files changed, 27 insertions(+), 6 deletions(-)

diff --git a/wenet/dataset/dataset.py b/wenet/dataset/dataset.py
index 9595011ec..693e0c617 100644
--- a/wenet/dataset/dataset.py
+++ b/wenet/dataset/dataset.py
@@ -146,6 +146,10 @@ def Dataset(data_type,
     else:
         dataset = Processor(dataset, processor.parse_raw)
 
+    speaker_conf = conf.get('speaker_conf', None)
+    if speaker_conf is not None:
+        dataset = Processor(dataset, processor.parse_speaker, **speaker_conf)
+
     dataset = Processor(dataset, processor.tokenize, tokenizer)
     filter_conf = conf.get('filter_conf', {})
     dataset = Processor(dataset, processor.filter, **filter_conf)
diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py
index d6bd1c166..6025c92a6 100644
--- a/wenet/dataset/processor.py
+++ b/wenet/dataset/processor.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import librosa
 import logging
 import json
@@ -145,15 +146,27 @@ def parse_raw(data):
                                               frame_offset=start_frame)
             else:
                 waveform, sample_rate = torchaudio.load(wav_file)
-            example = dict(key=key,
-                           txt=txt,
-                           wav=waveform,
-                           sample_rate=sample_rate)
+            example = copy.deepcopy(obj)  # copy and keep all the fields
+            example['wav'] = waveform  # overwrite wav
+            example['sample_rate'] = sample_rate
             yield example
         except Exception as ex:
             logging.warning('Failed to read {}'.format(wav_file))
 
 
+def parse_speaker(data, speaker_table_path):
+    speaker_dict = {}
+    with open(speaker_table_path, 'r', encoding='utf8') as fin:
+        for line in fin:
+            arr = line.strip().split()
+            speaker_dict[arr[0]] = int(arr[1])
+    for sample in data:
+        assert 'speaker' in sample
+        speaker = sample['speaker']
+        sample['speaker'] = speaker_dict.get(speaker, 0)
+        yield sample
+
+
 def filter(data,
            max_length=10240,
            min_length=10,
@@ -628,8 +641,7 @@ def padding(data):
         padded_wavs = pad_sequence(sorted_wavs,
                                    batch_first=True,
                                    padding_value=0)
-
-        yield {
+        batch = {
            "keys": sorted_keys,
            "feats": padded_feats,
            "target": padding_labels,
@@ -638,3 +650,8 @@ def padding(data):
            "pcm": padded_wavs,
            "pcm_length": wav_lengths,
         }
+        if 'speaker' in sample[0]:
+            speaker = torch.tensor([sample[i]['speaker'] for i in order],
+                                   dtype=torch.int32)
+            batch['speaker'] = speaker
+        yield batch
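
Note for reviewers: a minimal sketch of how the new parse_speaker stage behaves.
The file name, table contents, and sample dicts below are illustrative, not part
of the patch. The speaker table is whitespace-separated, one
"<speaker-label> <integer-id>" pair per line, and a label missing from the table
falls back to id 0:

# sketch_parse_speaker.py -- hypothetical harness; assumes wenet is importable
from wenet.dataset import processor

# Write a made-up speaker table: "<speaker-label> <integer-id>" per line.
with open('speaker_table.txt', 'w', encoding='utf8') as fout:
    fout.write('spk1 1\nspk2 2\n')

# Samples as they look after parse_raw: because parse_raw now deep-copies the
# original JSON record, a 'speaker' field in data.list survives to this stage
# (the assert in parse_speaker requires it to be present).
samples = [
    {'key': 'utt1', 'txt': 'hello', 'speaker': 'spk1'},
    {'key': 'utt2', 'txt': 'world', 'speaker': 'never-seen'},
]

for sample in processor.parse_speaker(samples, 'speaker_table.txt'):
    print(sample['key'], sample['speaker'])
# utt1 1
# utt2 0   <- unknown labels map to the default id 0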
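
On the config side: since speaker_conf is splatted into parse_speaker, its only
meaningful key is speaker_table_path. A hedged sketch of the corresponding
dataset conf entry, shown as the Python dict that Dataset receives (the path is
made up; in a YAML config this would sit under dataset_conf):

conf = {
    # ... existing entries such as filter_conf, fbank_conf, shuffle, ...
    'speaker_conf': {
        # **speaker_conf is forwarded to parse_speaker, so the only
        # accepted key is its keyword argument:
        'speaker_table_path': 'data/train/speaker_table.txt',  # made-up path
    },
}

Downstream, padding now builds the dict as batch before yielding; when the first
sample in a batch carries a 'speaker' field, the batch gains an int32 'speaker'
tensor reordered with the same order index used for keys and feats.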