Skip to content

Commit

Permalink
Move _convert_value to utils.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 614604761
  • Loading branch information
fineguy authored and The TensorFlow Datasets Authors committed Mar 11, 2024
1 parent 6a69719 commit 350455f
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 161 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@

from __future__ import annotations

import datetime
import functools
import io
import itertools
import multiprocessing
import os
Expand Down Expand Up @@ -121,78 +119,14 @@ def _from_tfds_to_hf(tfds_name: str) -> str:
)


def _convert_value(hf_value: Any, feature: feature_lib.FeatureConnector) -> Any:
"""Converts a Huggingface value to a TFDS compatible value."""
# See the docstring of huggingface_utils._get_default_value for explanations.
if hf_value is None:
return huggingface_utils._get_default_value(feature) # pylint: disable=protected-access
if isinstance(hf_value, datetime.datetime):
return int(hf_value.timestamp())
elif isinstance(feature, feature_lib.ClassLabel):
return hf_value
elif isinstance(feature, feature_lib.Scalar):
return hf_value
elif isinstance(feature, feature_lib.Translation):
if isinstance(hf_value, dict):
# Replaces `None` values with the default value.
return {
key: (
value
if value is not None
else huggingface_utils._get_default_value(feature[key]) # pylint: disable=protected-access
)
for key, value in hf_value.items()
}
elif isinstance(feature, feature_lib.FeaturesDict):
if isinstance(hf_value, dict):
return {k: _convert_value(v, feature[k]) for k, v in hf_value.items()}
raise ValueError(f"The feature is {feature}, but the value is: {hf_value}")
elif isinstance(feature, feature_lib.Sequence):
if isinstance(hf_value, dict):
# Should be a dict of lists:
return {
k: [_convert_value(el, feature.feature[k]) for el in v]
for k, v in hf_value.items()
}
if isinstance(hf_value, list):
return [_convert_value(v, feature.feature) for v in hf_value]
else:
return [hf_value]
elif isinstance(feature, feature_lib.Audio):
assert isinstance(hf_value, dict), f"Audio {hf_value} should be a dict"
if "array" in hf_value:
sample_rate = feature.sample_rate
# Hugging Face uses float, TFDS uses integers.
return [int(s * sample_rate) for s in hf_value["array"]]
if "path" in hf_value:
path = epath.Path(hf_value["path"])
if path.exists():
return path
else:
raise ValueError(f"{hf_value} is not a valid audio feature.")
elif isinstance(hf_value, lazy_imports_lib.lazy_imports.PIL_Image.Image):
buffer = io.BytesIO()
if hf_value.mode == "CMYK":
# Convert CMYK images to RGB.
hf_value = hf_value.convert("RGB")
hf_value.save(fp=buffer, format=_IMAGE_ENCODING_FORMAT)
return buffer.getvalue()
elif isinstance(feature, feature_lib.Tensor):
return hf_value
raise ValueError(
f"Type {type(hf_value)} of value {hf_value} "
f"for feature {type(feature)} is not supported."
)


def _convert_example(
index: int,
example: Mapping[str, Any],
features: feature_lib.FeaturesDict,
) -> Tuple[int, Mapping[str, Any]]:
"""Converts an example from Huggingface format to TFDS format."""
converted_example = {
name: _convert_value(value, features[name])
name: huggingface_utils.convert_hf_value(value, features[name])
for name, value in example.items()
}
return index, converted_example
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
from unittest import mock

from absl import logging
import numpy as np
import pytest
import tensorflow as tf
from tensorflow_datasets.core import features as feature_lib
from tensorflow_datasets.core import lazy_imports_lib
from tensorflow_datasets.core import registered
Expand Down Expand Up @@ -64,84 +62,6 @@ def test_from_tfds_to_hf():
assert huggingface_dataset_builder._from_tfds_to_hf("z")


def test_convert_value_datetime():
feature = feature_lib.Scalar(dtype=np.int64)
epoch_start = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
assert huggingface_dataset_builder._convert_value(epoch_start, feature) == 0
assert (
huggingface_dataset_builder._convert_value(
datetime.datetime(1970, 1, 2, tzinfo=datetime.timezone.utc), feature
)
== 86400
)


def test_convert_value_scalar():
int64_feature = feature_lib.Scalar(dtype=np.int64)
assert huggingface_dataset_builder._convert_value(42, int64_feature) == 42

int32_feature = feature_lib.Scalar(dtype=np.int32)
assert huggingface_dataset_builder._convert_value(42, int32_feature) == 42

string_feature = feature_lib.Scalar(dtype=np.object_)
assert (
huggingface_dataset_builder._convert_value("abc", string_feature) == "abc"
)

bool_feature = feature_lib.Scalar(dtype=np.bool_)
assert huggingface_dataset_builder._convert_value(True, bool_feature)
assert not huggingface_dataset_builder._convert_value(False, bool_feature)

float_feature = feature_lib.Scalar(dtype=np.float32)
assert huggingface_dataset_builder._convert_value(42.0, float_feature) == 42.0


def test_convert_value_sequence():
sequence_feature = feature_lib.Sequence(feature=tf.int64)
assert huggingface_dataset_builder._convert_value([42], sequence_feature) == [
42
]
assert huggingface_dataset_builder._convert_value(42, sequence_feature) == [
42
]
assert (
huggingface_dataset_builder._convert_value(None, sequence_feature) == [] # pylint: disable=g-explicit-bool-comparison
)


def test_convert_value_empty_sequence():
assert huggingface_dataset_builder._convert_value(
[None, "string"], feature_lib.Sequence(feature=np.str_)
) == [b"", "string"]


def test_convert_value_sequence_of_dict():
sequence_feature = feature_lib.Sequence(
{"someint": feature_lib.Scalar(dtype=np.str_)}
)
assert huggingface_dataset_builder._convert_value(
{"someint": [None, "string", None]}, sequence_feature
) == {"someint": [b"", "string", b""]}


def test_convert_value_image():
image_feature = feature_lib.Image()
image = lazy_imports_lib.lazy_imports.PIL_Image.new(mode="RGB", size=(4, 4))
assert huggingface_dataset_builder._convert_value(image, image_feature)


def test_convert_value_dict():
translation_feature = feature_lib.Translation(languages=["en", "fr", "de"])
translation = {
"de": b"Hallo Welt",
"en": b"Hello world",
"fr": None, # Hugging Face supports `None` values
}
assert huggingface_dataset_builder._convert_value(
translation, translation_feature
) == {"de": b"Hallo Welt", "en": b"Hello world", "fr": b""}


def test_remove_empty_splits():
splits = {"non_empty_split": range(5), "empty_split": range(0)}
with mock.patch.object(logging, "log"):
Expand Down
8 changes: 1 addition & 7 deletions tensorflow_datasets/core/features/image_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ def encode_image_or_path(self, image_or_path_or_fobj):
encoded_image = self._encode_image(image_or_path_or_fobj)
elif isinstance(image_or_path_or_fobj, epath.PathLikeCls):
image_or_path_or_fobj = os.fspath(image_or_path_or_fobj)
with tf.io.gfile.GFile(image_or_path_or_fobj, 'rb') as image_f:
encoded_image = image_f.read()
encoded_image = epath.Path(image_or_path_or_fobj).read_bytes()
elif isinstance(image_or_path_or_fobj, bytes):
encoded_image = image_or_path_or_fobj
elif PIL_Image is not None and isinstance(
Expand Down Expand Up @@ -145,11 +144,6 @@ def _encode_pil_image(self, pil_image) -> bytes:
"""
check_pil_import_or_raise_error()
buffer = io.BytesIO()
if self.encoding_format and pil_image.format != self.encoding_format:
raise ValueError(
f'PIL Image format {pil_image.format} does not match encoding format '
f'{self.encoding_format}'
)
pil_image.save(buffer, format=self.encoding_format or pil_image.format)
return buffer.getvalue()

Expand Down
82 changes: 76 additions & 6 deletions tensorflow_datasets/core/utils/huggingface_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
"""Utility functions for huggingface_dataset_builder."""

from collections.abc import Mapping, Sequence
import datetime
from typing import Any, Type

from etils import epath
import immutabledict
import numpy as np
from tensorflow_datasets.core import features as feature_lib
from tensorflow_datasets.core import lazy_imports_lib
from tensorflow_datasets.core.utils import dtype_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf

Expand All @@ -42,7 +45,7 @@ def convert_to_np_dtype(hf_dtype: str) -> Type[np.generic]:
hf_dtype: Huggingface dtype.
Raises:
ValueError: If couldn't recognize Huggingface dtype.
TypeError: If couldn't recognize Huggingface dtype.
"""
if np_dtype := _HF_DTYPE_TO_NP_DTYPE.get(hf_dtype):
return np_dtype
Expand All @@ -54,7 +57,7 @@ def convert_to_np_dtype(hf_dtype: str) -> Type[np.generic]:
elif hasattr(tf.dtypes, hf_dtype):
return getattr(tf.dtypes, hf_dtype)
else:
raise ValueError(
raise TypeError(
f'Unrecognized type {hf_dtype}. Please open an issue if you think '
'this is a bug.'
)
Expand Down Expand Up @@ -83,13 +86,13 @@ def _get_default_value(
feature: The TFDS feature from which we want the default value.
Raises:
ValueError: If couldn't recognize feature dtype.
TypeError: If couldn't recognize feature dtype.
"""
match feature:
case feature_lib.FeaturesDict():
return {
name: _get_default_value(sub_feature)
for name, sub_feature in feature.items()
name: _get_default_value(inner_feature)
for name, inner_feature in feature.items()
}
case feature_lib.Sequence():
return []
Expand All @@ -103,7 +106,74 @@ def _get_default_value(
elif dtype_utils.is_bool(feature.np_dtype):
return False
else:
raise ValueError(f'Could not get default value for {feature}')
raise TypeError(f'Could not recognize the dtype of {feature}')


def convert_hf_value(
hf_value: Any, feature: feature_lib.FeatureConnector
) -> Any:
"""Converts Huggingface value to a TFDS compatible value.
Args:
hf_value: Huggingface value.
feature: The TFDS feature for which we want the compatible value.
Returns:
The TFDS compatible value.
Raises:
TypeError: If couldn't recognize the given feature type.
"""
match hf_value:
case None:
return _get_default_value(feature)
case datetime.datetime():
return int(hf_value.timestamp())

match feature:
case feature_lib.ClassLabel() | feature_lib.Scalar():
return hf_value
case feature_lib.FeaturesDict():
return {
name: convert_hf_value(hf_value.get(name), inner_feature)
for name, inner_feature in feature.items()
}
case feature_lib.Sequence():
match hf_value:
case dict():
# Should be a dict of lists:
return {
name: [
convert_hf_value(inner_hf_value, inner_feature)
for inner_hf_value in hf_value.get(name)
]
for name, inner_feature in feature.feature.items()
}
case list():
return [
convert_hf_value(inner_hf_value, feature.feature)
for inner_hf_value in hf_value
]
case _:
return [hf_value]
case feature_lib.Audio():
if array := hf_value.get('array'):
# Hugging Face uses floats, TFDS uses integers.
return [int(sample * feature.sample_rate) for sample in array]
elif (path := hf_value.get('path')) and (
path := epath.Path(path)
).exists():
return path
case feature_lib.Image():
hf_value: lazy_imports_lib.lazy_imports.PIL_Image.Image
# Ensure RGB format for PNG encoding.
return hf_value.convert('RGB')
case feature_lib.Tensor():
return hf_value

raise TypeError(
f'Conversion of value {hf_value} to feature {feature} is not supported.'
)


def convert_hf_dataset_name(hf_dataset_name: str) -> str:
Expand Down
Loading

0 comments on commit 350455f

Please sign in to comment.