Skip to content

Commit

Permalink
[ssl] support pack dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Aug 14, 2024
1 parent 0dac0b5 commit e33da5f
Showing 1 changed file with 56 additions and 1 deletion.
57 changes: 56 additions & 1 deletion wenet/ssl/init_dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,56 @@
from collections.abc import Callable
from functools import partial
import sys
from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import IterDataPipe, functional_datapipe
from wenet.dataset import processor
from wenet.dataset.datapipes import WenetRawDatasetSource, WenetTarShardDatasetSource
from wenet.dataset.datapipes import (WenetRawDatasetSource,
WenetTarShardDatasetSource)


@functional_datapipe("pack_speech")
class PackSpeechDatapipe(IterDataPipe):
    """Greedily pack consecutive elements of a datapipe into longer ones.

    Elements are collected in order. As soon as adding the next element
    would make the accumulated length reach or exceed ``max_length``, the
    elements collected so far are merged with ``merge_speech_fn`` and
    yielded, and the next element starts a new pack. Whatever remains when
    the source is exhausted is merged and yielded as a final pack.

    Args:
        dataset: Source datapipe to read elements from.
        length_fn: Callable returning the length of a single element, in
            the same unit as ``max_length``.
        merge_speech_fn: Callable receiving a non-empty list of elements
            and returning the merged element.
        max_length: Length threshold at which the current pack is flushed.
    """

    def __init__(
        self,
        dataset: IterDataPipe,
        length_fn: Callable,
        merge_speech_fn: Callable,
        max_length: int = 30000,
    ) -> None:
        super().__init__()
        self.dp = dataset
        self.length_fn = length_fn
        self.max_length = max_length
        self.merge_fn = merge_speech_fn

    def __iter__(self):
        # Packing state is local to each iteration so the datapipe can be
        # iterated again (e.g. once per epoch) without stale or missing
        # state.
        buf = []
        buf_length = 0
        for elem in self.dp:
            elem_length = self.length_fn(elem)
            # Only flush a non-empty buffer: an element that alone reaches
            # max_length must not trigger a merge of zero elements.
            if buf and buf_length + elem_length >= self.max_length:
                yield self.merge_fn(buf)
                buf = []
                buf_length = 0
            buf.append(elem)
            buf_length += elem_length
        if buf:
            yield self.merge_fn(buf)


def cat_speech(buffer: List):
    """Merge a list of samples into one by concatenating their waveforms.

    Args:
        buffer: Non-empty list of dict samples, each with a ``'wav'``
            tensor and a ``'sample_rate'`` entry.

    Returns:
        A dict with ``'wav'`` set to the waveforms concatenated along
        dim 1 and ``'sample_rate'`` taken from the first sample.
    """
    assert len(buffer) > 0
    waves = [sample['wav'] for sample in buffer]
    # The sample rate lives on the sample dict, not on the waveform tensor.
    # NOTE(review): assumes all samples share one sample rate - confirm.
    sample_rate = buffer[0]['sample_rate']
    wav = torch.cat(waves, dim=1)
    return {"wav": wav, "sample_rate": sample_rate}


def padding(data):
Expand Down Expand Up @@ -39,6 +85,11 @@ def padding(data):
return batch


def wav_length_fn(sample):
    """Return the size of dimension 1 of the sample's ``'wav'`` entry."""
    return sample['wav'].size(1)


def Dataset(data_type, data_list_file, conf=None, partition=True):
""" Construct dataset from arguments for ssl model
Expand Down Expand Up @@ -81,6 +132,10 @@ def Dataset(data_type, data_list_file, conf=None, partition=True):
dataset = dataset.map(
partial(processor.singal_channel, **singal_channel_conf))

pack_conf = conf.get('pack_conf', {})
if pack_conf:
dataset = dataset.pack_speech(wav_length_fn, cat_speech,
pack_conf['max_speech_length'])
filter_conf = conf.get('filter_conf', {})
dataset = dataset.filter(partial(processor.filter, **filter_conf))

Expand Down

0 comments on commit e33da5f

Please sign in to comment.