diff --git a/tf_agents/utils/nest_utils_test.py b/tf_agents/utils/nest_utils_test.py index c7ede7f09..adcb91f7a 100644 --- a/tf_agents/utils/nest_utils_test.py +++ b/tf_agents/utils/nest_utils_test.py @@ -97,6 +97,28 @@ def zeros_from_spec(self, spec, batch_size=None, extra_sizes=None): return tf.nest.pack_sequence_as(spec, tensors) + def placeholders_from_spec(self, spec): + """Return tensors matching spec with an added unknown batch dimension. + + Args: + spec: A `tf.TypeSpec`, e.g. `tf.TensorSpec` or `tf.SparseTensorSpec`. + + Returns: + A possibly nested tuple of Tensors matching the spec. + """ + tensors = [] + for s in tf.nest.flatten(spec): + if isinstance(s, tf.SparseTensorSpec): + raise NotImplementedError( + 'Support for SparseTensor placeholders not implemented.') + elif isinstance(s, tf.TensorSpec): + shape = tf.TensorShape([None]).concatenate(s.shape) + tensors.append(tf.placeholder(dtype=s.dtype, shape=shape)) + else: + raise TypeError('Unexpected spec type: {}'.format(s)) + + return tf.nest.pack_sequence_as(spec, tensors) + def testGetOuterShapeNotBatched(self): tensor = tf.zeros([2, 3], dtype=tf.float32) spec = tensor_spec.TensorSpec([2, 3], dtype=tf.float32) @@ -522,6 +544,22 @@ def testFlattenMultiBatchedNestedTensors(self): batch_dims_ = self.evaluate(batch_dims) self.assertAllEqual(batch_dims_, [7, 5]) + def testFlattenMultiBatchedNestedTensorsWithPartiallyKnownShape(self): + if tf.executing_eagerly(): + self.skipTest('Do not check nest processing of data in eager mode. ' + 'Placeholders are not compatible with eager execution.') + shape = [2, 3] + specs = self.nest_spec(shape, include_sparse=False) + tensors = self.placeholders_from_spec(specs) + + (batch_flattened_tensors, + _) = nest_utils.flatten_multi_batched_nested_tensors( + tensors, specs) + + tf.nest.assert_same_structure(specs, batch_flattened_tensors) + assert_shapes = lambda t: self.assertEqual(t.shape.as_list(), [None, 2, 3]) + tf.nest.map_structure(assert_shapes, batch_flattened_tensors) + class NestedArraysTest(tf.test.TestCase): """Tests functions related to nested arrays."""