Generic NextGenHDF writer and a job for RASR caches dump #394
base: main
@@ -1,7 +1,8 @@
 import h5py
 import numpy as np
-from typing import Dict, Optional
-import sys
+import logging, sys, shutil, tempfile
+
+from typing import Dict, List, Optional


 def get_input_dict_from_returnn_hdf(hdf_file: h5py.File) -> Dict[str, np.ndarray]:

@@ -35,3 +36,95 @@ def get_returnn_simple_hdf_writer(returnn_root: Optional[str]):
    from returnn.datasets.hdf import SimpleHDFWriter

    return SimpleHDFWriter


class NextGenHDFWriter:
    """
    This class is a helper for writing HDF files for the RETURNN NextGenHDFDataset.
    """

    def __init__(
        self,
        filename: str,
        label_info_dict: Dict,
        feature_names: Optional[List[str]] = None,
        label_data_type: type = np.uint16,
        label_parser_name: str = "sparse",
        feature_parser_name: str = "feature_sequence",
    ):
        """
        :param label_info_dict: a dictionary with the label targets used in RETURNN training as keys and the number of label classes as values
        :param feature_names: additional feature data names
        :param label_data_type: type that is used to store the label data
        :param label_parser_name: this should be checked against the RETURNN implementation
        :param feature_parser_name: as above
        """
        self.label_info_dict = label_info_dict
        self.label_parser_name = label_parser_name
        self.feature_names = feature_names
        if feature_names is not None:
            self.feature_parser_name = feature_parser_name
        self.label_data_type = label_data_type
        self.string_data_type = h5py.special_dtype(vlen=str)
        self.sequence_names = []
        self.group_holder_dict = {}

        self.file_init()

    def file_init(self):
        self.temp_file = tempfile.NamedTemporaryFile(suffix="_NextGenHDFWriter_outHDF")
        self.temp_path = self.temp_file.name
        self.out_hdf = h5py.File(self.temp_path, "w")

        logging.info(f"processing temporary file {self.temp_path}")

        # root group holding all streams
        self.root_group = self.out_hdf.create_group("streams")

        for label_name, label_dim in self.label_info_dict.items():

Review comment: Following the init restructuring from above, I think it would be better to just loop here over all streams and have a separate …

Review comment: Still, I do not know how to cleanly deal with the "label_dim" argument in that case...

            self.group_holder_dict[label_name] = self._get_label_group(label_name, label_dim)

        if self.feature_names is not None:
            for feat_name in self.feature_names:
                self.group_holder_dict[feat_name] = self._get_feature_group(feat_name)

    def _get_label_group(self, label_name, label_dim):
        assert label_dim > 0, "you should have at least dim 1"
        label_group = self.root_group.create_group(label_name)
        label_group.attrs["parser"] = self.label_parser_name
        label_group.create_dataset(
            "feature_names",
            data=[b"label_%d" % l for l in range(label_dim)],
            dtype=self.string_data_type,
        )

        return label_group.create_group("data")

    def _get_feature_group(self, feature_name):
        feature_group = self.root_group.create_group(feature_name)
        feature_group.attrs["parser"] = self.feature_parser_name

        return feature_group.create_group("data")

    def add_sequence_name(self, seq_name):
        self.sequence_names.append(seq_name)

    def add_data_to_group(self, group_name, seq_name, data):
        if group_name in self.label_info_dict:
            data = np.array(data).astype(self.label_data_type)

Review comment: This should be left to the user to be defined from the outside. Then the …

        # the "/" in the string would lead to more hierarchies automatically, thus substitute
        self.group_holder_dict[group_name].create_dataset(seq_name.replace("/", "\\"), data=data)

    def finalize(self, filename):
        seq_name_set = set([s.replace("/", "\\") for s in self.sequence_names])

        for k, group in self.group_holder_dict.items():
            assert set(group.keys()) == seq_name_set, "The sequence names do not match between groups"

        self.out_hdf.create_dataset(
            "seq_names", data=[s.encode() for s in self.sequence_names], dtype=self.string_data_type
        )

        self.out_hdf.close()
        shutil.move(self.temp_path, filename)
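
For orientation, a minimal usage sketch of the writer as added in this diff; the stream names, dimensions and data below are made up for illustration and are not part of the PR:

    import numpy as np

    # "classes" is a hypothetical sparse label stream with 42 classes,
    # "features" a hypothetical dense feature stream.
    writer = NextGenHDFWriter(
        filename="data.hdf.0",
        label_info_dict={"classes": 42},
        feature_names=["features"],
        label_data_type=np.uint16,
    )

    seq_name = "corpus/recording/segment_1"
    writer.add_data_to_group("classes", seq_name, [3, 3, 7, 41])  # per-frame label ids, cast to uint16
    writer.add_data_to_group("features", seq_name, np.zeros((4, 80), dtype=np.float32))  # frames x feature dim
    writer.add_sequence_name(seq_name)

    writer.finalize("data.hdf.0")  # moves the temporary HDF file to its final location

The resulting file then contains one "streams/<stream>/data/<seq_name>" dataset per stream and sequence, plus a top-level "seq_names" dataset, which is the layout built by the class above.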

@@ -1,4 +1,10 @@
-__all__ = ["ReturnnDumpHDFJob", "ReturnnRasrDumpHDFJob", "BlissToPcmHDFJob", "RasrAlignmentDumpHDFJob"]
+__all__ = [
+    "ReturnnDumpHDFJob",
+    "ReturnnRasrDumpHDFJob",
+    "BlissToPcmHDFJob",
+    "RasrAlignmentDumpHDFJob",
+    "RasrDumpNextGenHDFJob",
+]

 from dataclasses import dataclass
 import glob

@@ -8,11 +14,11 @@
 import soundfile as sf
 import subprocess as sp
 import tempfile
-from typing import List, Optional
+from typing import Dict, List, Optional, Union

 from .rasr_training import ReturnnRasrTrainingJob
 from i6_core.lib import corpus
-from i6_core.lib.hdf import get_returnn_simple_hdf_writer
+from i6_core.lib.hdf import get_returnn_simple_hdf_writer, NextGenHDFWriter
 from i6_core.lib.rasr_cache import FileArchive
 import i6_core.rasr as rasr
 from i6_core.util import instanciate_delayed, uopen, write_paths_to_file

@@ -371,3 +377,135 @@ def run(self, task_id):

        if len(excluded_segments):
            write_paths_to_file(f"excluded_segments.{task_id}", excluded_segments)


class RasrDumpNextGenHDFJob(Job):
    """
    This job reads RASR alignment and feature caches and dumps them into HDF files for the NextGenHDFDataset class.
    """

    def __init__(
        self,
        alignment_caches_dict: Dict[str, List[tk.Path]],
        allophones: Union[tk.Path, Dict[str, tk.Path]],
        state_tyings: Union[tk.Path, Dict[str, tk.Path]],
        reference_target: str,
        data_type: type = np.uint16,
        feature_caches_dict: Optional[Dict[str, List[tk.Path]]] = None,
    ):
        """
        :param alignment_caches_dict: the dict keys are the target strings used in RETURNN training, the values are e.g. the output of an AlignmentJob
        :param allophones: e.g. output of a StoreAllophonesJob, or a dict with the same keys as alignment_caches_dict
        :param state_tyings: e.g. output of a DumpStateTyingJob, or a dict with the same keys as alignment_caches_dict
        :param reference_target: one of the keys of alignment_caches_dict, taken as reference for reading the segments
        :param data_type: type that is used to store the data
        :param feature_caches_dict: similar to alignment_caches_dict, just for features
        """
        self.alignment_caches_dict = alignment_caches_dict
        self.feature_caches_dict = feature_caches_dict
        self.allophones = allophones
        self.state_tyings = state_tyings
        self.reference_target = reference_target
        self.out_hdf_files = [
            self.output_path(f"data.hdf.{d}") for d in range(len(self.alignment_caches_dict[reference_target]))
        ]
        self.out_excluded_segments = self.output_path("excluded.segments")
        self.data_type = data_type
        self.rqmt = {"cpu": 1, "mem": 8, "time": 0.5}

    def tasks(self):
        yield Task("run", rqmt=self.rqmt, args=range(1, (len(self.out_hdf_files) + 1)))
        yield Task("merge", mini_task=True)

    def _get_state_tying(self, state_tying_file):
        # each line of the state tying file maps an allophone state (first column) to an integer class index (second column)
        return dict((k, int(v)) for l in open(state_tying_file.get_path()) for k, v in [l.strip().split()[0:2]])

    def _get_alignment_cache(self, task_id, alignment_name, allophones):
        alignment_cache = FileArchive(self.alignment_caches_dict[alignment_name][task_id - 1].get_path())
        _ = alignment_cache.setAllophones(allophones.get_path())

        return alignment_cache

    def _get_label_sequence(self, alignment_cache, file, state_tying):

Review comment: I wonder if it makes sense to move such functions to the FileArchive class itself at some point.

        targets = []
        alignment = alignment_cache.read(file, "align")
        if not len(alignment):
            return None
        alignmentStates = ["%s.%d" % (alignment_cache.allophones[t[1]], t[2]) for t in alignment]
        for allophone in alignmentStates:
            targets.append(state_tying[allophone])
        data = np.array(targets).astype(np.dtype(self.data_type))

        return data

    def merge(self):
        excluded_segments = []
        excluded_files = glob.glob("excluded_segments.*")
        for p in excluded_files:
            if os.path.isfile(p):
                with open(p, "r") as f:
                    segments = f.read().splitlines()
                excluded_segments.extend(segments)

        write_paths_to_file(self.out_excluded_segments, excluded_segments)

    def run(self, task_id):
        # this is first used to initialize the writer and then to contain the caches
        alignment_dict = dict.fromkeys(self.alignment_caches_dict.keys(), None)
        feature_names = list(self.feature_caches_dict.keys()) if self.feature_caches_dict is not None else []

        assert (
            self.reference_target in alignment_dict
        ), "you did not define a proper target for the reference alignment cache"

        allophones = {}
        state_tyings = {}
        for k in alignment_dict.keys():
            allophones[k] = self.allophones if not isinstance(self.allophones, dict) else self.allophones[k]
            state_tying_path = self.state_tyings if not isinstance(self.state_tyings, dict) else self.state_tyings[k]
            state_tyings[k] = self._get_state_tying(state_tying_path)
            alignment_dict[k] = max(state_tyings[k].values()) + 1  # max label class id + 1

        hdf_writer = NextGenHDFWriter(
            filename=f"hdf.{task_id - 1}",
            label_info_dict=alignment_dict,
            feature_names=feature_names,
            label_data_type=self.data_type,
        )

        for k in alignment_dict.keys():
            alignment_dict[k] = self._get_alignment_cache(task_id, k, allophones[k])

        feature_dict = dict(
            zip(
                feature_names, [FileArchive(self.feature_caches_dict[n][task_id - 1].get_path()) for n in feature_names]
            )
        )

        excluded_segments = []

        for file in alignment_dict[self.reference_target].ft:
            info = alignment_dict[self.reference_target].ft[file]
            if info.name.endswith(".attribs"):
                continue
            seq_name = info.name

            for align_k in alignment_dict.keys():
                label_seq = self._get_label_sequence(alignment_dict[align_k], file, state_tyings[align_k])
                if label_seq is None:
                    if seq_name not in excluded_segments:
                        excluded_segments.append(seq_name)
                    continue
                hdf_writer.add_data_to_group(align_k, seq_name, label_seq)

            for feat_key in feature_dict.keys():
                times, features = feature_dict[feat_key].read(file, "feat")
                hdf_writer.add_data_to_group(feat_key, seq_name, features)

            hdf_writer.add_sequence_name(seq_name)

        hdf_writer.finalize(self.out_hdf_files[task_id - 1])

        if len(excluded_segments):
            write_paths_to_file(f"excluded_segments.{task_id}", excluded_segments)
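
A hypothetical wiring of the new job in a Sisyphus setup could look as follows; all paths and the stream names "classes" and "features" are placeholders for illustration, not part of this PR:

    from sisyphus import tk
    import numpy as np

    # Placeholder inputs; in a real setup these come from e.g. an AlignmentJob,
    # StoreAllophonesJob, DumpStateTyingJob and a feature extraction job.
    alignment_caches = [tk.Path("/work/align/alignment.cache.1"), tk.Path("/work/align/alignment.cache.2")]
    feature_caches = [tk.Path("/work/feat/features.cache.1"), tk.Path("/work/feat/features.cache.2")]

    dump_job = RasrDumpNextGenHDFJob(
        alignment_caches_dict={"classes": alignment_caches},
        allophones=tk.Path("/work/lexicon/allophones"),
        state_tyings=tk.Path("/work/lexicon/state-tying"),
        reference_target="classes",
        data_type=np.uint16,
        feature_caches_dict={"features": feature_caches},
    )
    tk.register_output("nextgen_hdf/data.hdf.0", dump_job.out_hdf_files[0])

Each list of caches must be split into the same number of tasks as the reference target, since one HDF file is written per task.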
Review comment: The arguments are non-generic: you are assuming there is a feature and a label stream, and for the label stream you even require that it exists. This means that in its current form the writer would be unusable for e.g. my TTS setups, or at least very unintuitive in the handling. Following your design, I would imagine something like: …

Reply: Unfortunately, this does not work, for the reason that the sparse stream needs to know the extra size information. This is why I thought explicit groups with their own init params might be nicer... I am a little out of ideas here...
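
To make the design question concrete (the reviewer's actual snippet is not reproduced above): one hypothetical way to register streams generically while still carrying the sparse dimension would be to pass the dimension per stream, for illustration only and not part of this PR:

    from typing import Optional

    # Hypothetical sketch: streams are registered one by one, and only "sparse"
    # streams need to pass a dimension; dense feature streams do not.
    class StreamRegistry:
        def __init__(self):
            self.streams = {}  # name -> (parser_name, dim or None)

        def add_stream(self, name: str, parser: str, dim: Optional[int] = None):
            if parser == "sparse":
                assert dim is not None and dim > 0, "sparse streams need to know their dimension"
            self.streams[name] = (parser, dim)

    registry = StreamRegistry()
    registry.add_stream("classes", parser="sparse", dim=12001)
    registry.add_stream("features", parser="feature_sequence")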