Skip to content

Commit

Permalink
Restructure logic to minimize the number of file system accesses
Browse files Browse the repository at this point in the history
This also introduces a method that uses a glob to find all version folders instead of listing everything in a dir and then doing is_dir on all of them.

PiperOrigin-RevId: 694078611
  • Loading branch information
tomvdw authored and The TensorFlow Datasets Authors committed Nov 7, 2024
1 parent 978884e commit 9571ea4
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 22 deletions.
77 changes: 55 additions & 22 deletions tensorflow_datasets/core/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def list_data_dirs(
return sorted(d.expanduser() for d in all_data_dirs)


def get_default_data_dir(given_data_dir: str | None = None) -> Path:
def get_default_data_dir(given_data_dir: epath.PathLike | None = None) -> Path:
"""Returns the default data_dir."""
if given_data_dir:
data_dir = os.path.expanduser(given_data_dir)
Expand Down Expand Up @@ -202,29 +202,38 @@ def get_data_dir_and_dataset_dir(
dataset_dir: Dataset data directory (e.g.
`<data_dir>/<ds_name>/<config>/<version>`)
"""
all_data_dirs = list_data_dirs(given_data_dir=given_data_dir)
all_versions: set[version_lib.Version] = set()
dataset_dir_by_data_dir: dict[Path, Path] = {}
if version is not None:
version = version_lib.Version(version)

# If the data dir is explicitly given, no need to search everywhere.
if given_data_dir is not None:
given_data_dir = epath.Path(given_data_dir)
given_dataset_dir = get_dataset_dir(
data_dir=given_data_dir,
builder_name=builder_name,
config_name=config_name,
version=version,
)
return given_data_dir, given_dataset_dir

for data_dir in all_data_dirs:
# Check whether the dataset is in other data dirs.
dataset_dir_by_data_dir: dict[Path, Path] = {}
all_found_versions: set[version_lib.Version] = set()
for data_dir in list_data_dirs(given_data_dir=None):
data_dir = Path(data_dir)
# List all existing versions
dataset_config_dir = get_dataset_dir(
dataset_dir = get_dataset_dir(
data_dir=data_dir,
builder_name=builder_name,
config_name=config_name,
version=None,
)
versions = version_lib.list_all_versions(dataset_config_dir)
# Check for existence of the requested version
if version in versions:
dataset_dir_by_data_dir[data_dir] = get_dataset_dir(
data_dir=data_dir,
builder_name=builder_name,
config_name=config_name,
version=version,
)
all_versions.update(versions)
# Get all versions of the dataset in this dataset dir.
found_versions = list_dataset_versions(
dataset_config_dir=dataset_dir,
)
if version in found_versions:
dataset_dir_by_data_dir[data_dir] = dataset_dir / str(version)
all_found_versions.update(found_versions)

if len(dataset_dir_by_data_dir) > 1:
raise ValueError(
Expand All @@ -237,25 +246,25 @@ def get_data_dir_and_dataset_dir(
return next(iter(dataset_dir_by_data_dir.items()))

# No dataset found, use default directory
default_data_dir = get_default_data_dir(given_data_dir=given_data_dir)
default_data_dir = get_default_data_dir()
dataset_dir = get_dataset_dir(
data_dir=default_data_dir,
builder_name=builder_name,
config_name=config_name,
version=version,
)
if all_versions:
if all_found_versions:
logging.warning(
(
'Found a different version of the requested dataset'
' (given_data_dir=%s,dataset=%s, config=%s, version=%s):\n'
'%s\nUsing %s instead.'
' (given_data_dir=%s, dataset=%s, config=%s, version=%s):\n'
'%s\nUsing default data dir %s instead.'
),
given_data_dir,
builder_name,
config_name,
version,
'\n'.join(str(v) for v in sorted(all_versions)),
'\n'.join(str(v) for v in sorted(all_found_versions)),
dataset_dir,
)
return default_data_dir, dataset_dir
Expand Down Expand Up @@ -438,6 +447,30 @@ def _find_references_with_glob(
)


def list_dataset_versions(
dataset_config_dir: epath.PathLike,
) -> list[version_lib.Version]:
"""Returns all dataset versions found in `dataset_config_dir`.
Arguments:
dataset_config_dir: the folder that contains version subfolders.
use_gfile_glob: whether to use gfile.Glob instead of epath.Path.glob.
match_strict: whether to use gfile.Glob with match_strict=True.
"""
dataset_config_dir = epath.Path(dataset_config_dir)
glob = f'*/{constants.DATASET_INFO_FILENAME}'
found_versions: list[version_lib.Version] = []
for dataset_info in _find_files_with_glob(
dataset_config_dir,
globs=[glob],
file_names=[constants.DATASET_INFO_FILENAME],
):
version_folder = dataset_info.parent.name
if version_lib.Version.is_valid(version_folder):
found_versions.append(version_lib.Version(version_folder))
return sorted(found_versions)


def list_dataset_variants(
dataset_dir: epath.PathLike,
namespace: str | None = None,
Expand Down
19 changes: 19 additions & 0 deletions tensorflow_datasets/core/utils/file_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tensorflow_datasets.core import constants
from tensorflow_datasets.core import naming
from tensorflow_datasets.core.utils import file_utils
from tensorflow_datasets.core.utils import version as version_lib

_DATA_DIR = epath.Path('/a')
_DATASET_NAME = 'my_ds'
Expand Down Expand Up @@ -162,6 +163,24 @@ def _add_features(
mock_fs.add_file(dataset_dir / constants.FEATURES_FILENAME, content=content)


def test_list_dataset_versions(mock_fs: testing.MockFs):
_add_dataset_info(mock_fs, _DATASET_DIR / '1.0.0')
_add_dataset_info(mock_fs, _DATASET_DIR / '1.0.1')
_add_dataset_info(mock_fs, _DATASET_DIR / '3.0.0')
# Does not have dataset_info.json, so ignored.
mock_fs.add_file(_DATASET_DIR / '4.0.0' / 'other_file.json')
# Version folder is inside a subfolder, so ignored.
_add_dataset_info(mock_fs, _DATASET_DIR / 'should_be_ignored' / '1.0.0')
# Subfolder name is not a valid version, so ignored.
_add_dataset_info(mock_fs, _DATASET_DIR / 'not_valid_version')
actual_versions = file_utils.list_dataset_versions(_DATASET_DIR)
assert actual_versions == [
version_lib.Version('1.0.0'),
version_lib.Version('1.0.1'),
version_lib.Version('3.0.0'),
]


def test_list_dataset_variants_with_configs(mock_fs: testing.MockFs):
configs_and_versions = {
'x': [_VERSION, '1.0.1'],
Expand Down

0 comments on commit 9571ea4

Please sign in to comment.