Skip to content

Commit

Permalink
Datasets, remove some usages of num_output (#1601)
Browse files Browse the repository at this point in the history
* CombinedDataset: do not depend on num_outputs for dtype inference
* OggZipDataset: define `get_data_dim` to avoid reliance on `num_outputs`
* LmDataset: do not rely on num_outputs for data keys and data dim
  • Loading branch information
NeoLegends authored Oct 16, 2024
1 parent 61705b1 commit e095e05
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 11 deletions.
15 changes: 15 additions & 0 deletions returnn/datasets/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,21 @@ def get_total_num_seqs(self, *, fast: bool = False) -> int:
self._lazy_init()
return len(self._data)

def get_data_dim(self, key: str) -> int:
    """:return: feature/label dimension of the data entry `key`"""
    # fixed-dimension entries first: raw strings carry no feature dim,
    # "orth" is a byte sequence (256 possible values)
    if key == "raw":
        return 0
    if key == "orth":
        return 256
    if key == "data":
        assert self.feature_extractor is not None
        return self.feature_extractor.get_feature_dimension()
    if key == "classes":
        assert self.targets is not None
        return self.targets.num_labels
    raise ValueError(f"{self}: unknown data key: {key}")

def get_data_dtype(self, key: str) -> str:
""":return: dtype of data entry with `key`"""
if key == "data":
Expand Down
19 changes: 17 additions & 2 deletions returnn/datasets/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from __future__ import annotations

from typing import Optional, Union, Callable, List, Tuple, BinaryIO, cast
from typing import Optional, Union, Callable, Iterator, List, Tuple, BinaryIO, cast
import typing
import os
import sys
Expand Down Expand Up @@ -394,7 +394,16 @@ def get_data_keys(self):
"""
:rtype: list[str]
"""
return sorted(self.num_outputs.keys())

def _data_keys() -> Iterator[str]:
# yield the data keys in a fixed, deterministic order
if self.add_delayed_seq_data:
yield "delayed"
yield "data"
for i in range(self.add_random_phone_seqs):
yield f"random{i}"

return list(_data_keys())

def get_target_list(self):
"""
Expand Down Expand Up @@ -1626,6 +1635,12 @@ def is_data_sparse(self, key):
"""
return True # all is sparse

def get_data_dim(self, _key: str) -> int:
    """
    :return: dimension of the data entry `_key`; identical for every key
    """
    # every data key of this dataset shares the same dimension
    dim = self.num_inputs
    return dim

def get_data_dtype(self, key):
"""
:param str key:
Expand Down
22 changes: 13 additions & 9 deletions returnn/datasets/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,8 @@ def __init__(
self.num_outputs = self.data_dims

self.data_dtypes = {
data_key: _select_dtype(data_key, self.data_dims, data_dtypes) for data_key in self.data_keys
data_key: _select_dtype(data_key, dset_data_key, self.datasets[dset_key], data_dtypes)
for (dset_key, dset_data_key), data_key in data_map.items()
}

self.dataset_seq_idx_boundaries = None # type: typing.Optional[typing.List[int]]
Expand Down Expand Up @@ -2086,13 +2087,16 @@ def is_data_sparse(self, key: str) -> bool:
return self._data_keys[key].get("sparse", False)


def _select_dtype(key, data_dims, data_dtypes):
if data_dtypes and key in data_dtypes:
v = data_dtypes[key]
def _select_dtype(
data_key: str,
dataset_data_key: str,
dataset: Dataset,
existing_dtype_map: Optional[Dict[str, str]],
):
if existing_dtype_map and data_key in existing_dtype_map:
v = existing_dtype_map[data_key]
assert isinstance(v, str) # e.g. "int32" or "float32"
return v
assert key in data_dims
if data_dims[key][1] == 1: # sparse
return "int32" # standard for 1-of-k
else:
return "float32" # standard otherwise
if dataset.is_data_sparse(dataset_data_key):
return "int32" # standard 1-of-k
return "float32" # default float32 otherwise

0 comments on commit e095e05

Please sign in to comment.