
Commit

extend ValueNetwork to handle complex observations with preprocessing_layer and preprocessing_combiner

PiperOrigin-RevId: 259810453
Change-Id: If72917d84513c3ae703ccf6d9efad8c9ff6c8baa
TF-Agents Team authored and copybara-github committed Jul 24, 2019
1 parent e59f6bf commit ed17272
Showing 2 changed files with 73 additions and 37 deletions.
82 changes: 47 additions & 35 deletions tf_agents/networks/value_network.py
@@ -17,6 +17,8 @@
 Implements a network that will generate the following layers:
+  [optional]: preprocessing_layers  # preprocessing_layers
+  [optional]: (Add | Concat(axis=-1) | ...)  # preprocessing_combiner
   [optional]: Conv2D  # conv_layer_params
   Flatten
   [optional]: Dense  # fc_layer_params
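For orientation, here is a minimal Keras-only sketch of the layer stack this listing describes. It is illustrative only, not part of the diff: the real stack is assembled by networks.EncodingNetwork, and the input shapes and layer sizes below are assumptions.

import tensorflow as tf

# Two 8x8x3 image observation components, combined with Add (illustrative).
frame_a = tf.keras.Input(shape=(8, 8, 3))
frame_b = tf.keras.Input(shape=(8, 8, 3))

# [optional] preprocessing_layers: one layer per observation component.
a = tf.keras.layers.Lambda(lambda t: t / 255.0)(frame_a)
b = tf.keras.layers.Lambda(lambda t: t / 255.0)(frame_b)

# [optional] preprocessing_combiner: merges the flat list of tensors.
x = tf.keras.layers.Add()([a, b])

# [optional] conv_layer_params, then the mandatory Flatten.
x = tf.keras.layers.Conv2D(filters=16, kernel_size=3, strides=1)(x)
x = tf.keras.layers.Flatten()(x)

# [optional] fc_layer_params, then a final Dense(1) value head.
x = tf.keras.layers.Dense(75, activation='relu')(x)
x = tf.keras.layers.Dense(40, activation='relu')(x)
value = tf.keras.layers.Dense(1)(x)

model = tf.keras.Model(inputs=[frame_a, frame_b], outputs=value)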
@@ -30,9 +32,8 @@
 import gin
 import tensorflow as tf

+from tf_agents.networks import encoding_network
 from tf_agents.networks import network
-from tf_agents.networks import utils
-from tf_agents.utils import nest_utils


@gin.configurable
@@ -41,10 +42,15 @@ class ValueNetwork(network.Network):

   def __init__(self,
                input_tensor_spec,
+               preprocessing_layers=None,
+               preprocessing_combiner=None,
+               conv_layer_params=None,
                fc_layer_params=(75, 40),
                dropout_layer_params=None,
-               conv_layer_params=None,
                activation_fn=tf.keras.activations.relu,
+               kernel_initializer=None,
+               batch_squash=True,
+               dtype=tf.float32,
                name='ValueNetwork'):
     """Creates an instance of `ValueNetwork`.
Expand All @@ -54,6 +60,18 @@ def __init__(self,
     Args:
       input_tensor_spec: A `tensor_spec.TensorSpec` or a tuple of specs
         representing the input observations.
+      preprocessing_layers: (Optional.) A nest of `tf.keras.layers.Layer`
+        representing preprocessing for the different observations. None of
+        these layers may be built yet. For more details see the
+        documentation of `networks.EncodingNetwork`.
+      preprocessing_combiner: (Optional.) A keras layer that takes a flat
+        list of tensors and combines them. Good options include
+        `tf.keras.layers.Add` and `tf.keras.layers.Concatenate(axis=-1)`.
+        This layer must not be built yet. For more details see the
+        documentation of `networks.EncodingNetwork`.
+      conv_layer_params: Optional list of convolution layer parameters, where
+        each item is a length-three tuple indicating (filters, kernel_size,
+        stride).
       fc_layer_params: Optional list of fully_connected parameters, where
         each item is the number of units in the layer.
       dropout_layer_params: Optional list of dropout layer parameters, each item
@@ -64,52 +82,46 @@
         the fully connected layers; there is a dropout layer after each fully
         connected layer, except if the entry in the list is None. This list
         must have the same length as fc_layer_params, or be None.
-      conv_layer_params: Optional list of convolution layer parameters, where
-        each item is a length-three tuple indicating (filters, kernel_size,
-        stride).
       activation_fn: Activation function, e.g. tf.keras.activations.relu.
+      kernel_initializer: Initializer to use for the kernels of the conv and
+        dense layers. If none is provided, a default glorot_uniform
+        initializer is used.
+      batch_squash: If True, the outer ranks of the observation are squashed
+        into the batch dimension. This allows encoding networks to be used
+        with observations of shape [BxTx...].
+      dtype: The dtype to use for the convolution and fully connected layers.
       name: A string representing the name of the network.

     Raises:
-      ValueError: If `input_tensor_spec.observations` contains more than one
-        observation.
+      ValueError: If `input_tensor_spec` is not an instance of
+        `network.InputSpec`.
     """
     super(ValueNetwork, self).__init__(
         input_tensor_spec=input_tensor_spec,
         state_spec=(),
         name=name)

-    if len(tf.nest.flatten(input_tensor_spec)) > 1:
-      raise ValueError(
-          'Network only supports observation specs with a single observation.')
+    if not kernel_initializer:
+      kernel_initializer = tf.compat.v1.keras.initializers.glorot_uniform()

-    self._postprocessing_layers = utils.mlp_layers(
-        conv_layer_params,
-        fc_layer_params,
+    self._encoder = encoding_network.EncodingNetwork(
+        input_tensor_spec,
+        preprocessing_layers=preprocessing_layers,
+        preprocessing_combiner=preprocessing_combiner,
+        conv_layer_params=conv_layer_params,
+        fc_layer_params=fc_layer_params,
+        dropout_layer_params=dropout_layer_params,
         activation_fn=activation_fn,
-        kernel_initializer=tf.compat.v1.keras.initializers.glorot_uniform(),
-        name='input_mlp')
+        kernel_initializer=kernel_initializer,
+        batch_squash=batch_squash,
+        dtype=dtype)

-    self._postprocessing_layers.append(
-        tf.keras.layers.Dense(
-            1,
-            activation=None,
-            kernel_initializer=tf.compat.v1.initializers.random_uniform(
-                minval=-0.03, maxval=0.03),
-        ))
+    self._postprocessing_layers = tf.keras.layers.Dense(
+        1,
+        activation=None,
+        kernel_initializer=tf.compat.v1.initializers.random_uniform(
+            minval=-0.03, maxval=0.03))

   def call(self, observation, step_type=None, network_state=()):
-    outer_rank = nest_utils.get_outer_rank(observation,
-                                           self.input_tensor_spec)
-    batch_squash = utils.BatchSquash(outer_rank)
-
-    states = tf.cast(tf.nest.flatten(observation)[0], tf.float32)
-    states = batch_squash.flatten(states)
-    for layer in self._postprocessing_layers:
-      states = layer(states)
-
-    value = tf.reshape(states, [-1])
-    value = batch_squash.unflatten(value)
-    return value, network_state
+    state, network_state = self._encoder(
+        observation, step_type=step_type, network_state=network_state)
+    value = self._postprocessing_layers(state)
+    return tf.squeeze(value, -1), network_state
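As a usage sketch of the extended API (not part of the commit; the dict spec, key names, and layer sizes below are illustrative assumptions), a multi-component observation can now be fed through per-component preprocessing and a combiner:

import tensorflow as tf

from tf_agents.networks import value_network
from tf_agents.specs import tensor_spec

# Hypothetical two-component observation: a feature vector and an image.
obs_spec = {
    'state': tensor_spec.TensorSpec([4], tf.float32),
    'pixels': tensor_spec.BoundedTensorSpec((8, 8, 3), tf.float32, 0, 1),
}

# One preprocessing layer per component, mirroring the structure of obs_spec.
preprocessing_layers = {
    'state': tf.keras.layers.Dense(16),
    'pixels': tf.keras.Sequential([tf.keras.layers.Flatten(),
                                   tf.keras.layers.Dense(16)]),
}

net = value_network.ValueNetwork(
    obs_spec,
    preprocessing_layers=preprocessing_layers,
    preprocessing_combiner=tf.keras.layers.Concatenate(axis=-1))

observation = tensor_spec.sample_spec_nest(obs_spec, outer_dims=(5,))
value, _ = net(observation)  # value has shape [5]: one scalar per batch item.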
28 changes: 26 additions & 2 deletions tf_agents/networks/value_network_test.py
@@ -31,7 +31,8 @@ class ValueNetworkTest(tf.test.TestCase):

   @test_util.run_in_graph_and_eager_modes()
   def testBuilds(self):
-    observation_spec = tensor_spec.BoundedTensorSpec((8, 8, 3), tf.int32, 0, 1)
+    observation_spec = tensor_spec.BoundedTensorSpec((8, 8, 3), tf.float32, 0,
+                                                     1)
     observation = tensor_spec.sample_spec_nest(
         observation_spec, outer_dims=(1,))

@@ -59,7 +60,8 @@ def testBuilds(self):

   @test_util.run_in_graph_and_eager_modes()
   def testHandlesExtraOuterDims(self):
-    observation_spec = tensor_spec.BoundedTensorSpec((8, 8, 3), tf.int32, 0, 1)
+    observation_spec = tensor_spec.BoundedTensorSpec((8, 8, 3), tf.float32, 0,
+                                                     1)
     observation = tensor_spec.sample_spec_nest(
         observation_spec, outer_dims=(3, 3, 2))

@@ -69,6 +71,28 @@ def testHandlesExtraOuterDims(self):
     value, _ = net(observation)
     self.assertEqual([3, 3, 2], value.shape.as_list())

+  @test_util.run_in_graph_and_eager_modes()
+  def testHandlePreprocessingLayers(self):
+    observation_spec = (tensor_spec.TensorSpec([1], tf.float32),
+                        tensor_spec.TensorSpec([], tf.float32))
+    observation = tensor_spec.sample_spec_nest(
+        observation_spec, outer_dims=(3,))
+
+    preprocessing_layers = (tf.keras.layers.Dense(4),
+                            tf.keras.Sequential([
+                                tf.keras.layers.Reshape((1,)),
+                                tf.keras.layers.Dense(4)
+                            ]))
+
+    net = value_network.ValueNetwork(
+        observation_spec,
+        preprocessing_layers=preprocessing_layers,
+        preprocessing_combiner=tf.keras.layers.Add())
+
+    value, _ = net(observation)
+    self.assertEqual([3], value.shape.as_list())
+    self.assertGreater(len(net.trainable_variables), 4)
+

 if __name__ == '__main__':
   tf.test.main()
