Skip to content

Commit

Permalink
[TF Agents] This tests the flatten_multi_batched_nested_tensors() met…
Browse files Browse the repository at this point in the history
…hod with placeholders with partially unknown shapes (i.e., for the batch dimensions).

PiperOrigin-RevId: 262440810
Change-Id: I726eed4780937fd94eb826461463471fd8f4b32e
  • Loading branch information
TF-Agents Team authored and copybara-github committed Aug 8, 2019
1 parent 473b155 commit 08ba598
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions tf_agents/utils/nest_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,28 @@ def zeros_from_spec(self, spec, batch_size=None, extra_sizes=None):

return tf.nest.pack_sequence_as(spec, tensors)

def placeholders_from_spec(self, spec):
"""Return tensors matching spec with an added unknown batch dimension.
Args:
spec: A `tf.TypeSpec`, e.g. `tf.TensorSpec` or `tf.SparseTensorSpec`.
Returns:
A possibly nested tuple of Tensors matching the spec.
"""
tensors = []
for s in tf.nest.flatten(spec):
if isinstance(s, tf.SparseTensorSpec):
raise NotImplementedError(
'Support for SparseTensor placeholders not implemented.')
elif isinstance(s, tf.TensorSpec):
shape = tf.TensorShape([None]).concatenate(s.shape)
tensors.append(tf.placeholder(dtype=s.dtype, shape=shape))
else:
raise TypeError('Unexpected spec type: {}'.format(s))

return tf.nest.pack_sequence_as(spec, tensors)

def testGetOuterShapeNotBatched(self):
tensor = tf.zeros([2, 3], dtype=tf.float32)
spec = tensor_spec.TensorSpec([2, 3], dtype=tf.float32)
Expand Down Expand Up @@ -522,6 +544,22 @@ def testFlattenMultiBatchedNestedTensors(self):
batch_dims_ = self.evaluate(batch_dims)
self.assertAllEqual(batch_dims_, [7, 5])

def testFlattenMultiBatchedNestedTensorsWithPartiallyKnownShape(self):
if tf.executing_eagerly():
self.skipTest('Do not check nest processing of data in eager mode. '
'Placeholders are not compatible with eager execution.')
shape = [2, 3]
specs = self.nest_spec(shape, include_sparse=False)
tensors = self.placeholders_from_spec(specs)

(batch_flattened_tensors,
_) = nest_utils.flatten_multi_batched_nested_tensors(
tensors, specs)

tf.nest.assert_same_structure(specs, batch_flattened_tensors)
assert_shapes = lambda t: self.assertEqual(t.shape.as_list(), [None, 2, 3])
tf.nest.map_structure(assert_shapes, batch_flattened_tensors)


class NestedArraysTest(tf.test.TestCase):
"""Tests functions related to nested arrays."""
Expand Down

0 comments on commit 08ba598

Please sign in to comment.