Move extract_features to utils.
PiperOrigin-RevId: 614628473
fineguy authored and The TensorFlow Datasets Authors committed Mar 11, 2024
1 parent b4a9aaa commit 950fd5e
Showing 5 changed files with 120 additions and 95 deletions.
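The move also renames the public entry point: `extract_features` in `huggingface_dataset_builder` becomes `convert_hf_features` in `huggingface_utils`. A minimal before/after sketch of a call site, following the diff below (the feature spec here is illustrative):

import datasets as hf_datasets
# A tiny Hugging Face feature spec to convert.
hf_features = hf_datasets.Features({'id': hf_datasets.Value('string')})

# Before this commit, the converter lived in the builder module:
from tensorflow_datasets.core.dataset_builders import huggingface_dataset_builder
# tfds_spec = huggingface_dataset_builder.extract_features(hf_features)  # removed below

# After this commit, it is a shared utility with a new name:
from tensorflow_datasets.core.utils import huggingface_utils
tfds_spec = huggingface_utils.convert_hf_features(hf_features)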
8 changes: 6 additions & 2 deletions .github/workflows/pytest.yml
@@ -115,7 +115,8 @@ jobs:
--ignore="tensorflow_datasets/import_without_tf_test.py" \
--ignore="tensorflow_datasets/core/github_api/github_path_test.py" \
--ignore="tensorflow_datasets/translate/wmt19_test.py" \
--ignore="tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py"
--ignore="tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py" \
--ignore="tensorflow_datasets/core/utils/huggingface_utils_test.py"
# Run tests without any pytest plugins. The tests should be triggered for a single shard only.
- name: Run leftover tests
@@ -147,7 +148,10 @@ jobs:
extras: huggingface

- name: Run HuggingFace tests
run: pytest -vv -n auto tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py
run: |
pytest -vv -n auto \
tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py \
tensorflow_datasets/core/utils/huggingface_utils_test.py
githubapi-pytest-job:
needs: activate-tests
43 changes: 1 addition & 42 deletions tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py
@@ -37,7 +37,6 @@
from tensorflow_datasets.core import dataset_builder
from tensorflow_datasets.core import dataset_info as dataset_info_lib
from tensorflow_datasets.core import download
from tensorflow_datasets.core import features as feature_lib
from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core import lazy_imports_lib
from tensorflow_datasets.core import registered
@@ -48,49 +47,9 @@
from tensorflow_datasets.core.utils import version as version_lib
from tensorflow_datasets.core.utils.lazy_imports_utils import datasets as hf_datasets

_IMAGE_ENCODING_FORMAT = "png"
_EMPTY_SPLIT_WARNING_MSG = "%s split doesn't have any examples"


def extract_features(hf_features) -> feature_lib.FeatureConnector:
"""Converts Huggingface feature spec to TFDS feature spec."""
if isinstance(hf_features, (hf_datasets.Features, dict)):
return feature_lib.FeaturesDict({
name: extract_features(hf_inner_feature)
for name, hf_inner_feature in hf_features.items()
})
if isinstance(hf_features, hf_datasets.Sequence):
return feature_lib.Sequence(feature=extract_features(hf_features.feature))
if isinstance(hf_features, list):
if len(hf_features) != 1:
raise ValueError(f"List {hf_features} should have a length of 1.")
return feature_lib.Sequence(feature=extract_features(hf_features[0]))
if isinstance(hf_features, hf_datasets.Value):
return feature_lib.Scalar(
dtype=huggingface_utils.convert_to_np_dtype(hf_features.dtype)
)
if isinstance(hf_features, hf_datasets.ClassLabel):
if hf_features.names:
return feature_lib.ClassLabel(names=hf_features.names)
if hf_features.names_file:
return feature_lib.ClassLabel(names_file=hf_features.names_file)
if hf_features.num_classes:
return feature_lib.ClassLabel(num_classes=hf_features.num_classes)
if isinstance(hf_features, hf_datasets.Translation):
return feature_lib.Translation(
languages=hf_features.languages,
)
if isinstance(hf_features, hf_datasets.TranslationVariableLanguages):
return feature_lib.TranslationVariableLanguages(
languages=hf_features.languages,
)
if isinstance(hf_features, hf_datasets.Image):
return feature_lib.Image(encoding_format=_IMAGE_ENCODING_FORMAT)
if isinstance(hf_features, hf_datasets.Audio):
return feature_lib.Audio(sample_rate=hf_features.sampling_rate)
raise ValueError(f"Type {type(hf_features)} is not supported.")


def _from_tfds_to_hf(tfds_name: str) -> str:
"""Finds the original HF repo ID.
@@ -247,7 +206,7 @@ def _info(self) -> dataset_info_lib.DatasetInfo:
return dataset_info_lib.DatasetInfo(
builder=self,
description=self._hf_info.description,
features=extract_features(self._hf_features()),
features=huggingface_utils.convert_hf_features(self._hf_features()),
citation=self._hf_info.citation,
license=self._hf_info.license,
supervised_keys=_extract_supervised_keys(self._hf_info),
48 changes: 0 additions & 48 deletions tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder_test.py
@@ -17,9 +17,7 @@

from absl import logging
import datasets as hf_datasets
import numpy as np
import pytest
from tensorflow_datasets.core import features as feature_lib
from tensorflow_datasets.core import registered
from tensorflow_datasets.core.dataset_builders import huggingface_dataset_builder

@@ -61,45 +59,6 @@ def test_remove_empty_splits():
assert list(non_empty_splits["non_empty_split"]) == list(range(5))


# Encapsulate test parameters into a fixture to avoid `datasets` import during
# tests collection.
# https://docs.pytest.org/en/7.2.x/example/parametrize.html#deferring-the-setup-of-parametrized-resources
@pytest.fixture(params=["feat_dict", "audio"], name="features")
def get_features(request):
if request.param == "feat_dict":
return (
hf_datasets.Features({
"id": hf_datasets.Value("string"),
"meta": {
"left_context": hf_datasets.Value("string"),
"partial_evidence": [{
"start_id": hf_datasets.Value("int32"),
"meta": {"evidence_span": [hf_datasets.Value("string")]},
}],
},
}),
feature_lib.FeaturesDict({
"id": feature_lib.Scalar(dtype=np.str_),
"meta": feature_lib.FeaturesDict({
"left_context": feature_lib.Scalar(dtype=np.str_),
"partial_evidence": feature_lib.Sequence({
"meta": feature_lib.FeaturesDict({
"evidence_span": feature_lib.Sequence(
feature_lib.Scalar(dtype=np.str_)
),
}),
"start_id": feature_lib.Scalar(dtype=np.int32),
}),
}),
}),
)
elif request.param == "audio":
return (
hf_datasets.Audio(sampling_rate=48000),
feature_lib.Audio(sample_rate=48000),
)


@pytest.fixture(name="load_dataset_builder_mock")
def get_load_dataset_builder_mock():
with mock.patch.object(
@@ -154,12 +113,5 @@ def test_all_parameters_are_passed_down_to_hf(
)


def test_extract_features(features):
hf_features, tfds_features = features
assert repr(
huggingface_dataset_builder.extract_features(hf_features)
) == repr(tfds_features)


def test_hf_features(builder):
assert builder._hf_features() == {"feature": hf_datasets.Value("int32")}
54 changes: 53 additions & 1 deletion tensorflow_datasets/core/utils/huggingface_utils.py
@@ -25,6 +25,7 @@
from tensorflow_datasets.core import features as feature_lib
from tensorflow_datasets.core import lazy_imports_lib
from tensorflow_datasets.core.utils import dtype_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import datasets as hf_datasets
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf


@@ -36,9 +37,10 @@
'utf8': np.object_,
'string': np.object_,
})
_IMAGE_ENCODING_FORMAT = 'png'


def convert_to_np_dtype(hf_dtype: str) -> Type[np.generic]:
def _convert_to_np_dtype(hf_dtype: str) -> Type[np.generic]:
"""Returns the `np.dtype` scalar feature.
Args:
@@ -63,6 +65,56 @@ def convert_to_np_dtype(hf_dtype: str) -> Type[np.generic]:
)


def convert_hf_features(hf_features) -> feature_lib.FeatureConnector:
"""Converts Huggingface feature spec to a TFDS compatible feature spec.
Args:
hf_features: Huggingface feature spec.
Returns:
The TFDS compatible feature spec.
Raises:
ValueError: If the given Huggingface features is a list with length > 1.
TypeError: If couldn't recognize the given feature spec.
"""
match hf_features:
case hf_datasets.Features() | dict():
return feature_lib.FeaturesDict({
name: convert_hf_features(hf_inner_feature)
for name, hf_inner_feature in hf_features.items()
})
case hf_datasets.Sequence():
return feature_lib.Sequence(
feature=convert_hf_features(hf_features.feature)
)
case list():
if len(hf_features) != 1:
raise ValueError(f'List {hf_features} should have a length of 1.')
return feature_lib.Sequence(feature=convert_hf_features(hf_features[0]))
case hf_datasets.Value():
return feature_lib.Scalar(dtype=_convert_to_np_dtype(hf_features.dtype))
case hf_datasets.ClassLabel():
if hf_features.names:
return feature_lib.ClassLabel(names=hf_features.names)
if hf_features.names_file:
return feature_lib.ClassLabel(names_file=hf_features.names_file)
if hf_features.num_classes:
return feature_lib.ClassLabel(num_classes=hf_features.num_classes)
case hf_datasets.Translation():
return feature_lib.Translation(languages=hf_features.languages)
case hf_datasets.TranslationVariableLanguages():
return feature_lib.TranslationVariableLanguages(
languages=hf_features.languages
)
case hf_datasets.Image():
return feature_lib.Image(encoding_format=_IMAGE_ENCODING_FORMAT)
case hf_datasets.Audio():
return feature_lib.Audio(sample_rate=hf_features.sampling_rate)

raise TypeError(f'Type {type(hf_features)} is not supported.')


def _get_default_value(
feature: feature_lib.FeatureConnector,
) -> Mapping[str, Any] | Sequence[Any] | bytes | int | float | bool:
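The moved converter can be exercised directly. A short usage sketch, assuming the `datasets` package is available; the feature names are illustrative, while the conversion rules are exactly the `match` arms above (note that the `match`/`case` dispatch requires Python 3.10+):

import datasets as hf_datasets
from tensorflow_datasets.core.utils import huggingface_utils

hf_spec = hf_datasets.Features({
    'id': hf_datasets.Value('string'),
    'label': hf_datasets.ClassLabel(names=['negative', 'positive']),
    'tokens': [hf_datasets.Value('string')],  # one-element list -> Sequence
})
tfds_spec = huggingface_utils.convert_hf_features(hf_spec)
# tfds_spec is a feature_lib.FeaturesDict whose members mirror the structure
# above: Scalar for 'id', ClassLabel for 'label', Sequence for 'tokens'.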
62 changes: 60 additions & 2 deletions tensorflow_datasets/core/utils/huggingface_utils_test.py
@@ -15,6 +15,7 @@

import datetime

import datasets as hf_datasets
import numpy as np
import pytest
from tensorflow_datasets.core import features as feature_lib
@@ -25,7 +26,7 @@

def test_convert_to_np_dtype_raises():
with pytest.raises(TypeError, match='Unrecognized type.+'):
huggingface_utils.convert_to_np_dtype('I am no dtype')
huggingface_utils._convert_to_np_dtype('I am no dtype')


@pytest.mark.parametrize(
@@ -49,7 +50,64 @@ def test_convert_to_np_dtype_raises():
],
)
def test_convert_to_np_dtype(hf_dtype, np_dtype):
assert huggingface_utils.convert_to_np_dtype(hf_dtype) is np_dtype
assert huggingface_utils._convert_to_np_dtype(hf_dtype) is np_dtype


def test_convert_hf_features_raises_type_error():
with pytest.raises(TypeError, match='Type <.+> is not supported.'):
huggingface_utils.convert_hf_features('I am no features')


def test_convert_hf_features_raises_value_error():
with pytest.raises(
ValueError, match=r'List \[.+\] should have a length of 1.'
):
huggingface_utils.convert_hf_features(
[hf_datasets.Value('int32'), hf_datasets.Value('int32')]
)


@pytest.mark.parametrize(
'hf_features,tfds_features',
[
(
hf_datasets.Features(
id=hf_datasets.Value('string'),
meta={
'left_context': hf_datasets.Value('string'),
'partial_evidence': [{
'start_id': hf_datasets.Value('int32'),
'meta': {
'evidence_span': [hf_datasets.Value('string')]
},
}],
},
),
feature_lib.FeaturesDict({
'id': feature_lib.Scalar(dtype=np.str_),
'meta': feature_lib.FeaturesDict({
'left_context': feature_lib.Scalar(dtype=np.str_),
'partial_evidence': feature_lib.Sequence({
'meta': feature_lib.FeaturesDict({
'evidence_span': feature_lib.Sequence(
feature_lib.Scalar(dtype=np.str_)
),
}),
'start_id': feature_lib.Scalar(dtype=np.int32),
}),
}),
}),
),
(
hf_datasets.Audio(sampling_rate=48000),
feature_lib.Audio(sample_rate=48000),
),
],
)
def test_convert_hf_features(hf_features, tfds_features):
assert repr(huggingface_utils.convert_hf_features(hf_features)) == repr(
tfds_features
)


@pytest.mark.parametrize(
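The two error paths covered by the new tests can be reproduced directly; a sketch using the same inputs as the tests above (note that unrecognized specs now raise `TypeError`, where the removed `extract_features` raised `ValueError`):

import datasets as hf_datasets
from tensorflow_datasets.core.utils import huggingface_utils

# A bare list must wrap exactly one inner feature.
try:
    huggingface_utils.convert_hf_features(
        [hf_datasets.Value('int32'), hf_datasets.Value('int32')]
    )
except ValueError as e:
    print(e)  # List [...] should have a length of 1.

# Anything the match statement does not recognize raises TypeError.
try:
    huggingface_utils.convert_hf_features('I am no features')
except TypeError as e:
    print(e)  # Type <class 'str'> is not supported.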
