Skip to content

Commit

Permalink
Test features are pickable (for Beam). Fix where needed.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 245906912
  • Loading branch information
pierrot0 authored and copybara-github committed Apr 30, 2019
1 parent b1734ba commit 24d683c
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 10 deletions.
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

REQUIRED_PKGS = [
'absl-py',
'dill', # TODO(tfds): move to TESTS_REQUIRE.
'future',
'numpy',
'promise',
Expand All @@ -54,9 +55,9 @@
]

TESTS_REQUIRE = [
'apache-beam',
'jupyter',
'pytest',
'apache-beam',
]

if sys.version_info.major == 3:
Expand Down
15 changes: 8 additions & 7 deletions tensorflow_datasets/core/features/image_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(self, shape=None, encoding_format=None):
"""
self._encoding_format = None
self._shape = None
self._runner = None

# Set and validate values
self.set_encoding_format(encoding_format or 'png')
Expand Down Expand Up @@ -118,20 +119,20 @@ def get_serialized_info(self):
# Only store raw image (includes size).
return tf.io.FixedLenFeature(tuple(), tf.string)

@utils.memoized_property
def _runner(self):
# TODO(epot): Should clear the runner once every image has been encoded.
# TODO(epot): Better support for multi-shape image (instead of re-building
# a new graph every time)
return utils.TFGraphRunner()

def _encode_image(self, np_image):
"""Returns np_image encoded as jpeg or png."""
if not self._runner:
self._runner = utils.TFGraphRunner()
if np_image.dtype != np.uint8:
raise ValueError('Image should be uint8. Detected: %s.' % np_image.dtype)
utils.assert_shape_match(np_image.shape, self._shape)
return self._runner.run(ENCODE_FN[self._encoding_format], np_image)

def __getstate__(self):
state = self.__dict__.copy()
state['_runner'] = None
return state

def encode_example(self, image_or_path_or_fobj):
"""Convert the given image into a dict convertible to tf example."""
if isinstance(image_or_path_or_fobj, np.ndarray):
Expand Down
12 changes: 12 additions & 0 deletions tensorflow_datasets/core/features/sequence_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,18 @@ def __getattr__(self, key):
"""Allow to access the underlying attributes directly."""
return getattr(self._seq_feature['inner'], key)

# The __getattr__ method triggers an infinite recursion loop when loading a
# pickled instance. So we override that name in the instance dict, and remove
# it when unplickling.
def __getstate__(self):
state = self.__dict__.copy()
state['__getattr__'] = 0
return state

def __setstate__(self, state):
del state['__getattr__']
self.__dict__.update(state)

def get_tensor_info(self):
return self._seq_feature.get_tensor_info()['inner']

Expand Down
5 changes: 3 additions & 2 deletions tensorflow_datasets/core/features/text/text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,6 @@ def __init__(self, alphanum_only=True, reserved_tokens=None):
reserved_tokens, self._reserved_tokens_re = _prepare_reserved_tokens(
reserved_tokens)
self._reserved_tokens = set(reserved_tokens)
self._alphanum_re = ALPHANUM_REGEX if self._alphanum_only else ALL_REGEX

@property
def alphanum_only(self):
Expand All @@ -389,8 +388,10 @@ def tokenize(self, s):
for substr in substrs:
if substr in self.reserved_tokens:
toks.append(substr)
elif self._alphanum_only:
toks.extend(ALPHANUM_REGEX.split(substr))
else:
toks.extend(self._alphanum_re.split(substr))
toks.extend(ALL_REGEX.split(substr))

# Filter out empty strings
toks = [t for t in toks if t]
Expand Down
4 changes: 4 additions & 0 deletions tensorflow_datasets/testing/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from absl.testing import absltest

import dill
import numpy as np
import tensorflow as tf

Expand Down Expand Up @@ -220,6 +221,9 @@ def assertFeature(self, feature, shape, dtype, tests, serialized_info=None):
def assertFeatureTest(self, fdict, test, feature, shape, dtype):
"""Test that encode=>decoding of a value works correctly."""

# test feature.encode_example can be pickled and unpickled for beam.
dill.loads(dill.dumps(feature.encode_example))

# self._process_subtest_exp(e)
input_value = {"inner": test.value}

Expand Down

0 comments on commit 24d683c

Please sign in to comment.