[add] Use shard data to calculate cmvn and extract text (#2609)
* [add] Use shard data to calculate cmvn and extract text

* [Fix] fix code style check

* Update compute_shard_cmvn_stats.py

* Update extract_shard_data.py

Changed to four extraction modes, selected by the flag parameter (see the sketch below):
Extract text only: text
Extract audio only: audio
Compute duration only: duration
Extract everything: content
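
As an illustration, a text-only extraction could be invoked as in the sketch below; only the mode values come from this message, while the remaining argument names and paths are assumptions, so check the script's --help for the real interface.

# Hypothetical usage sketch for tools/extract_shard_data.py: only the mode
# values (text/audio/duration/content) are taken from the commit message;
# the other argument names and paths are assumptions.
import subprocess

subprocess.run(
    [
        "python", "tools/extract_shard_data.py",
        "--flag", "text",                      # text | audio | duration | content
        "--in_shard", "data/train/data.list",  # assumed: shard list input
        "--out_dir", "exp/extracted_text",     # assumed: output directory
    ],
    check=True,
)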

---------

Co-authored-by: lsrami <getwebshells@gmail.com>
lsrami authored Aug 14, 2024
1 parent 8859bf2 commit 203e067
Showing 2 changed files with 443 additions and 0 deletions.
244 changes: 244 additions & 0 deletions tools/compute_shard_cmvn_stats.py
@@ -0,0 +1,244 @@
#!/usr/bin/env python3

# Copyright (c) 2024 Timekettle Inc. (authors: Sirui Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import json
import yaml
import tarfile
import logging
import argparse

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch.utils.data import IterableDataset, DataLoader
from urllib.parse import urlparse
from subprocess import Popen, PIPE

AUDIO_FORMAT_SETS = set(["flac", "mp3", "m4a", "ogg", "opus", "wav", "wma"])


class CollateFunc(object):
    """Collate function for AudioIterableDataset."""

def __init__(self, feat_dim, resample_rate):
self.feat_dim = feat_dim
self.resample_rate = resample_rate

def __call__(self, batch):
worker_info = torch.utils.data.get_worker_info()
if worker_info:
worker_id = worker_info.id
else:
worker_id = 0
mean_stat = torch.zeros(self.feat_dim)
var_stat = torch.zeros(self.feat_dim)
number = 0
batch_num = len(batch)
for item in batch:
try:
waveform = item["wav"]
sample_rate = item["sample_rate"]
resample_rate = sample_rate
except Exception as e:
                logging.warning(f"{item.get('key', 'unknown')} read failed: {e}")
continue
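            # torchaudio.load returns floats in [-1, 1]; scale to the 16-bit sample
            # range so the fbank features match Kaldi's convention.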
waveform = waveform * (1 << 15)
if self.resample_rate != 0 and self.resample_rate != sample_rate:
resample_rate = self.resample_rate
waveform = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=resample_rate
)(waveform)

mat = kaldi.fbank(
waveform,
num_mel_bins=self.feat_dim,
dither=0.0,
energy_floor=0.0,
sample_frequency=resample_rate,
)
mean_stat += torch.sum(mat, axis=0)
var_stat += torch.sum(torch.square(mat), axis=0)
number += mat.shape[0]
return number, mean_stat, var_stat, worker_id, batch_num


class AudioIterableDataset(IterableDataset):
def __init__(self, file_list):
self.file_list = file_list

def parse_file_list(self):
worker_info = torch.utils.data.get_worker_info()
with open(self.file_list, "r") as f:
parsed_files = [{"src": line.strip()} for line in f]

if worker_info:
# split workload
worker_id = worker_info.id
num_workers = worker_info.num_workers
parsed_files = parsed_files[worker_id::num_workers]

return parsed_files

def url_opener(self, data):
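        """Open each shard, either from the local filesystem or streamed via wget for remote URLs."""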
for sample in data:
assert "src" in sample
url = sample["src"]
try:
pr = urlparse(url)
# local file
if pr.scheme == "" or pr.scheme == "file":
stream = open(url, "rb")
# network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP
else:
cmd = f"wget -q -O - {url}"
process = Popen(cmd, shell=True, stdout=PIPE)
sample.update(process=process)
stream = process.stdout
sample.update(stream=stream)
yield sample
except Exception as ex:
                logging.warning("Failed to open {}: {}".format(url, ex))

def tar_file_and_group(self, sample):
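        """Iterate over a tar stream and group members that share a filename
        prefix (the utterance key) into one example with wav and sample_rate."""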
assert "stream" in sample
stream = None
results = []
try:
stream = tarfile.open(fileobj=sample["stream"], mode="r:*")
prev_prefix = None
example = {}
valid = True
for tarinfo in stream:
name = tarinfo.name
pos = name.rfind(".")
assert pos > 0
prefix, postfix = name[:pos], name[pos + 1 :]
if prev_prefix is not None and prefix != prev_prefix:
example["key"] = prev_prefix
if valid:
results.append(example)
example = {}
valid = True
with stream.extractfile(tarinfo) as file_obj:
try:
if postfix in AUDIO_FORMAT_SETS:
waveform, sample_rate = torchaudio.load(file_obj)
example["wav"] = waveform
example["sample_rate"] = sample_rate
except Exception as ex:
valid = False
                        logging.warning("Failed to parse {}: {}".format(name, ex))
prev_prefix = prefix
if prev_prefix is not None:
example["key"] = prev_prefix
results.append(example)
except Exception as ex:
logging.warning(
"In tar_file_and_group: {} when processing {}".format(ex, sample["src"])
)
finally:
if stream is not None:
stream.close()
if "process" in sample:
sample["process"].communicate()
sample["stream"].close()
return results

def __iter__(self):
parsed_files = self.parse_file_list()
for sample in self.url_opener(parsed_files):
yield from self.tar_file_and_group(sample)


def main():
parser = argparse.ArgumentParser(description="extract CMVN stats")
parser.add_argument(
"--num_workers",
default=8,
type=int,
help="num of subprocess workers for processing",
)
parser.add_argument(
"--batch_size", default=16, type=int, help="num of samples in a batch"
)
parser.add_argument("--train_config", default="", help="training yaml conf")
parser.add_argument("--in_shard", default=None, help="shard data list file")
parser.add_argument("--out_cmvn", default="global_cmvn", help="global cmvn file")
parser.add_argument(
"--log_interval",
type=int,
default=1000,
help="Print log after every log_interval audios are processed.",
)
args = parser.parse_args()

with open(args.train_config, "r") as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
feat_dim = (
configs.get("dataset_conf", {}).get("fbank_conf", {}).get("num_mel_bins", 80)
)
resample_rate = (
configs.get("dataset_conf", {}).get("resample_conf", {}).get("resample_rate", 0)
)
print(
"compute cmvn using feat_dim: {} resample rate: {}".format(
feat_dim, resample_rate
)
)
collate_func = CollateFunc(feat_dim, resample_rate)
dataset = AudioIterableDataset(args.in_shard)
batch_size = args.batch_size

data_loader = DataLoader(
dataset,
batch_size=batch_size,
num_workers=args.num_workers,
collate_fn=collate_func,
)

with torch.no_grad():
all_number = 0
all_mean_stat = torch.zeros(feat_dim)
all_var_stat = torch.zeros(feat_dim)
wav_number = 0
for i, batch in enumerate(data_loader):
number, mean_stat, var_stat, worker_id, batch_num = batch
all_mean_stat += mean_stat
all_var_stat += var_stat
all_number += number
wav_number += batch_num

            # Log when the running wav count crosses a multiple of log_interval
            # (wav_number grows in batch-sized steps, so exact equality may never hold).
            if wav_number % args.log_interval < batch_num:
print(
f"worker_id {worker_id} processed {wav_number} wavs "
f"{all_number} frames",
file=sys.stderr,
flush=True,
)

        cmvn_info = {
            "mean_stat": all_mean_stat.tolist(),
            "var_stat": all_var_stat.tolist(),
            "frame_num": all_number,
        }

with open(args.out_cmvn, "w") as fout:
fout.write(json.dumps(cmvn_info))


if __name__ == "__main__":
main()
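
For reference, a minimal sketch of how the stats file written above can be consumed: mean_stat holds per-dimension feature sums, var_stat the sums of squares, and frame_num the total frame count, so a global mean and inverse standard deviation follow directly (the path below is just the --out_cmvn default).

import json

import torch

# Load the stats produced by compute_shard_cmvn_stats.py (default --out_cmvn path).
with open("global_cmvn", "r") as fin:
    stats = json.load(fin)

frame_num = stats["frame_num"]
mean = torch.tensor(stats["mean_stat"]) / frame_num
# var = E[x^2] - E[x]^2, floored to keep the inverse std finite.
var = torch.tensor(stats["var_stat"]) / frame_num - mean ** 2
istd = 1.0 / torch.sqrt(torch.clamp(var, min=1e-20))
print(mean.shape, istd.shape)  # both are (feat_dim,)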