1. Extend ActorDistributionRnnNetwork to handle complex observations with preprocessing_layers and preprocessing_combiner

2. Rename the ActorDistributionRnnNetwork constructor parameters categorical_projection_net to discrete_projection_net and normal_projection_net to continuous_projection_net, for consistency with ActorDistributionNetwork. A brief usage sketch of both changes follows the changed-files summary below.

PiperOrigin-RevId: 260602283
Change-Id: Iebb1eaab73bafdcb775623c62fd2b35dc599fbd4
TF-Agents Team authored and copybara-github committed Jul 29, 2019
1 parent 3072d16 commit 6e7ef80
Showing 3 changed files with 103 additions and 112 deletions.
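
A minimal, hypothetical usage sketch of the extended constructor (the specs, layer sizes, and the Add combiner below are illustrative assumptions, and passing NormalProjectionNetwork for the renamed continuous_projection_net keyword is optional — the default already builds that network):

import tensorflow as tf

from tf_agents.networks import actor_distribution_rnn_network
from tf_agents.networks import normal_projection_network
from tf_agents.specs import tensor_spec

# Hypothetical "complex" observation: a tuple of two float vectors.
observation_spec = (tensor_spec.TensorSpec([3], tf.float32),
                    tensor_spec.TensorSpec([2], tf.float32))
action_spec = tensor_spec.BoundedTensorSpec([2], tf.float32, -1, 1)

# One (unbuilt) Keras layer per observation entry; both emit 4 units so the
# outputs can be merged by an Add combiner.
preprocessing_layers = (tf.keras.layers.Dense(4), tf.keras.layers.Dense(4))

net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
    observation_spec,
    action_spec,
    preprocessing_layers=preprocessing_layers,
    preprocessing_combiner=tf.keras.layers.Add(),
    lstm_size=(40,),
    # Renamed keyword (formerly normal_projection_net).
    continuous_projection_net=normal_projection_network.NormalProjectionNetwork)

The new testHandlePreprocessingLayers test at the bottom of this commit exercises the same pattern against the real API.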
20 changes: 12 additions & 8 deletions tf_agents/agents/sac/sac_agent_test.py
@@ -100,7 +100,7 @@ class SacAgentTest(tf.test.TestCase):
def setUp(self):
super(SacAgentTest, self).setUp()
tf.compat.v1.enable_resource_variables()
self._obs_spec = [tensor_spec.TensorSpec([2], tf.float32)]
self._obs_spec = tensor_spec.TensorSpec([2], tf.float32)
self._time_step_spec = ts.time_step_spec(self._obs_spec)
self._action_spec = tensor_spec.BoundedTensorSpec([1], tf.float32, -1, 1)

@@ -126,13 +126,13 @@ def testCriticLoss(self):
alpha_optimizer=None,
actor_policy_ctor=DummyActorPolicy)

observations = [tf.constant([[1, 2], [3, 4]], dtype=tf.float32)]
observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
time_steps = ts.restart(observations)
actions = tf.constant([[5], [6]], dtype=tf.float32)

rewards = tf.constant([10, 20], dtype=tf.float32)
discounts = tf.constant([0.9, 0.9], dtype=tf.float32)
next_observations = [tf.constant([[5, 6], [7, 8]], dtype=tf.float32)]
next_observations = tf.constant([[5, 6], [7, 8]], dtype=tf.float32)
next_time_steps = ts.transition(next_observations, rewards, discounts)

td_targets = [7.3, 19.1]
@@ -165,7 +165,7 @@ def testActorLoss(self):
alpha_optimizer=None,
actor_policy_ctor=DummyActorPolicy)

observations = [tf.constant([[1, 2], [3, 4]], dtype=tf.float32)]
observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
time_steps = ts.restart(observations, batch_size=2)

expected_loss = (2 * 10 - (2 + 1) - (4 + 1)) / 2
@@ -187,7 +187,7 @@ def testAlphaLoss(self):
target_entropy=3.0,
initial_log_alpha=4.0,
actor_policy_ctor=DummyActorPolicy)
observations = [tf.constant([[1, 2], [3, 4]], dtype=tf.float32)]
observations = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
time_steps = ts.restart(observations, batch_size=2)

expected_loss = 4.0 * (-10 - 3)
@@ -208,7 +208,7 @@ def testPolicy(self):
alpha_optimizer=None,
actor_policy_ctor=DummyActorPolicy)

observations = [tf.constant([1, 2], dtype=tf.float32)]
observations = tf.constant([1, 2], dtype=tf.float32)
time_steps = ts.restart(observations)
action_step = agent.policy.action(time_steps)

@@ -259,10 +259,10 @@ def testTrainWithRnn(self):
step_type=tf.constant([[1] * 3] * batch_size, dtype=tf.int32),
reward=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
discount=tf.constant([[1] * 3] * batch_size, dtype=tf.float32),
observation=[observations])
observation=observations)

experience = trajectory.Trajectory(
time_steps.step_type, [observations], actions, (),
time_steps.step_type, observations, actions, (),
time_steps.step_type, time_steps.reward, time_steps.discount)

# Force variable creation.
@@ -276,3 +276,7 @@ def testTrainWithRnn(self):
self.assertEqual(self.evaluate(counter), 0)
self.evaluate(loss)
self.assertEqual(self.evaluate(counter), 1)


if __name__ == '__main__':
tf.test.main()
157 changes: 56 additions & 101 deletions tf_agents/networks/actor_distribution_rnn_network.py
@@ -19,19 +19,16 @@
from __future__ import division
from __future__ import print_function

import functools

import gin
import numpy as np
import tensorflow as tf

from tf_agents.networks import categorical_projection_network
from tf_agents.networks import dynamic_unroll_layer
from tf_agents.networks import lstm_encoding_network
from tf_agents.networks import network
from tf_agents.networks import normal_projection_network
from tf_agents.networks import utils
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step
from tf_agents.utils import nest_utils


@@ -64,14 +64,17 @@ class ActorDistributionRnnNetwork(network.DistributionNetwork):
def __init__(self,
input_tensor_spec,
output_tensor_spec,
preprocessing_layers=None,
preprocessing_combiner=None,
conv_layer_params=None,
input_fc_layer_params=(200, 100),
input_dropout_layer_params=None,
output_fc_layer_params=(200, 100),
conv_layer_params=None,
lstm_size=(40,),
output_fc_layer_params=(200, 100),
activation_fn=tf.keras.activations.relu,
categorical_projection_net=_categorical_projection_net,
normal_projection_net=_normal_projection_net,
dtype=tf.float32,
discrete_projection_net=_categorical_projection_net,
continuous_projection_net=_normal_projection_net,
name='ActorDistributionRnnNetwork'):
"""Creates an instance of `ActorDistributionRnnNetwork`.
@@ -80,6 +80,18 @@ def __init__(self,
input.
output_tensor_spec: A nest of `tensor_spec.BoundedTensorSpec` representing
the output.
preprocessing_layers: (Optional.) A nest of `tf.keras.layers.Layer`
representing preprocessing for the different observations.
All of these layers must not be already built. For more details see
the documentation of `networks.EncodingNetwork`.
preprocessing_combiner: (Optional.) A keras layer that takes a flat list
of tensors and combines them. Good options include
`tf.keras.layers.Add` and `tf.keras.layers.Concatenate(axis=-1)`.
This layer must not be already built. For more details see
the documentation of `networks.EncodingNetwork`.
conv_layer_params: Optional list of convolution layers parameters, where
each item is a length-three tuple indicating (filters, kernel_size,
stride).
input_fc_layer_params: Optional list of fully_connected parameters, where
each item is the number of units in the layer. This is applied before
the LSTM cell.
@@ -92,74 +92,55 @@
after each fully connected layer, except if the entry in the list is
None. This list must have the same length of input_fc_layer_params, or
be None.
lstm_size: An iterable of ints specifying the LSTM cell sizes to use.
output_fc_layer_params: Optional list of fully_connected parameters, where
each item is the number of units in the layer. This is applied after the
LSTM cell.
conv_layer_params: Optional list of convolution layers parameters, where
each item is a length-three tuple indicating (filters, kernel_size,
stride).
lstm_size: An iterable of ints specifying the LSTM cell sizes to use.
activation_fn: Activation function, e.g. tf.nn.relu, slim.leaky_relu, ...
categorical_projection_net: Callable that generates a categorical
projection network to be called with some hidden state and the
outer_rank of the state.
normal_projection_net: Callable that generates a normal projection network
to be called with some hidden state and the outer_rank of the state.
dtype: The dtype to use by the convolution and fully connected layers.
discrete_projection_net: Callable that generates a discrete projection
network to be called with some hidden state and the outer_rank of the
state.
continuous_projection_net: Callable that generates a continuous projection
network to be called with some hidden state and the outer_rank of the
state.
name: A string representing name of the network.
Raises:
ValueError: If `input_tensor_spec` contains more than one observation.
ValueError: If 'input_dropout_layer_params' is not None.
"""
if len(tf.nest.flatten(input_tensor_spec)) > 1:
raise ValueError('Only a single observation is supported by this network')
if input_dropout_layer_params:
raise ValueError('Dropout layer is not supported.')

input_layers = utils.mlp_layers(
conv_layer_params,
input_fc_layer_params,
input_dropout_layer_params,
lstm_encoder = lstm_encoding_network.LSTMEncodingNetwork(
input_tensor_spec=input_tensor_spec,
preprocessing_layers=preprocessing_layers,
preprocessing_combiner=preprocessing_combiner,
conv_layer_params=conv_layer_params,
input_fc_layer_params=input_fc_layer_params,
lstm_size=lstm_size,
output_fc_layer_params=output_fc_layer_params,
activation_fn=activation_fn,
kernel_initializer=tf.compat.v1.keras.initializers.glorot_uniform(),
name='input_mlp')

# Create RNN cell
if len(lstm_size) == 1:
cell = tf.keras.layers.LSTMCell(lstm_size[0])
else:
cell = tf.keras.layers.StackedRNNCells(
[tf.keras.layers.LSTMCell(size) for size in lstm_size])

state_spec = tf.nest.map_structure(
functools.partial(
tensor_spec.TensorSpec, dtype=tf.float32,
name='network_state_spec'), cell.state_size)

output_layers = utils.mlp_layers(
fc_layer_params=output_fc_layer_params, name='output')

projection_networks = []
for single_output_spec in tf.nest.flatten(output_tensor_spec):
if tensor_spec.is_discrete(single_output_spec):
projection_networks.append(
categorical_projection_net(single_output_spec))
dtype=dtype,
name=name)

def map_proj(spec):
if tensor_spec.is_discrete(spec):
return discrete_projection_net(spec)
else:
projection_networks.append(normal_projection_net(single_output_spec))
return continuous_projection_net(spec)

projection_distribution_specs = [
proj_net.output_spec for proj_net in projection_networks
]
output_spec = tf.nest.pack_sequence_as(output_tensor_spec,
projection_distribution_specs)
projection_networks = tf.nest.map_structure(map_proj, output_tensor_spec)
output_spec = tf.nest.map_structure(lambda proj_net: proj_net.output_spec,
projection_networks)

super(ActorDistributionRnnNetwork, self).__init__(
input_tensor_spec=input_tensor_spec,
state_spec=state_spec,
state_spec=lstm_encoder.state_spec,
output_spec=output_spec,
name=name)

self._conv_layer_params = conv_layer_params
self._input_layers = input_layers
self._dynamic_unroll = dynamic_unroll_layer.DynamicUnroll(cell)
self._output_layers = output_layers
self._lstm_encoder = lstm_encoder
self._projection_networks = projection_networks
self._output_tensor_spec = output_tensor_spec

@@ -168,47 +161,9 @@ def output_tensor_spec(self):
return self._output_tensor_spec

def call(self, observation, step_type, network_state=None):
num_outer_dims = nest_utils.get_outer_rank(observation,
self.input_tensor_spec)
if num_outer_dims not in (1, 2):
raise ValueError(
'Input observation must have a batch or batch x time outer shape.')

has_time_dim = num_outer_dims == 2
if not has_time_dim:
# Add a time dimension to the inputs.
observation = tf.nest.map_structure(lambda t: tf.expand_dims(t, 1),
observation)
step_type = tf.nest.map_structure(lambda t: tf.expand_dims(t, 1),
step_type)

states = tf.cast(tf.nest.flatten(observation)[0], tf.float32)
batch_squash = utils.BatchSquash(2) # Squash B, and T dims.
states = batch_squash.flatten(states)

for layer in self._input_layers:
states = layer(states)

states = batch_squash.unflatten(states)

with tf.name_scope('reset_mask'):
reset_mask = tf.equal(step_type, time_step.StepType.FIRST)
# Unroll over the time sequence.
states, network_state = self._dynamic_unroll(
states,
reset_mask,
initial_state=network_state)

states = batch_squash.flatten(states)

for layer in self._output_layers:
states = layer(states)

states = batch_squash.unflatten(states)
outputs = [
projection(states, num_outer_dims)
for projection in self._projection_networks
]

output_actions = tf.nest.pack_sequence_as(self._output_tensor_spec, outputs)
state, network_state = self._lstm_encoder(
observation, step_type=step_type, network_state=network_state)
outer_rank = nest_utils.get_outer_rank(observation, self.input_tensor_spec)
output_actions = tf.nest.map_structure(
lambda proj_net: proj_net(state, outer_rank), self._projection_networks)
return output_actions, network_state
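
As a side note, the rewritten constructor above swaps the old flatten/pack_sequence_as bookkeeping for a tf.nest.map_structure over the action spec. A standalone sketch of that dispatch pattern, using a made-up mixed action spec, might look like:

import tensorflow as tf

from tf_agents.networks import categorical_projection_network
from tf_agents.networks import normal_projection_network
from tf_agents.specs import tensor_spec

# Hypothetical nest of action specs: one continuous, one discrete.
action_spec = {
    'move': tensor_spec.BoundedTensorSpec((2,), tf.float32, -1, 1),
    'mode': tensor_spec.BoundedTensorSpec((3,), tf.int32, 0, 3),
}

def map_proj(spec):
  # Same dispatch as the diff: discrete specs get a categorical projection,
  # continuous specs a normal (Gaussian) projection.
  if tensor_spec.is_discrete(spec):
    return categorical_projection_network.CategoricalProjectionNetwork(spec)
  return normal_projection_network.NormalProjectionNetwork(spec)

# map_structure preserves the nest, so the networks and their output specs
# stay aligned with action_spec without any flatten/pack_sequence_as.
projection_networks = tf.nest.map_structure(map_proj, action_spec)
output_spec = tf.nest.map_structure(lambda p: p.output_spec, projection_networks)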
38 changes: 35 additions & 3 deletions tf_agents/networks/actor_distribution_rnn_network_test.py
@@ -20,15 +20,14 @@

import tensorflow as tf
from tf_agents.networks import actor_distribution_rnn_network
from tf_agents.networks import sequential_layer
from tf_agents.policies import actor_policy
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts
from tensorflow.python.framework import test_util # TF internal


class ActorDistributionNetworkTest(tf.test.TestCase):

@test_util.run_in_graph_and_eager_modes()
def testBuilds(self):
observation_spec = tensor_spec.BoundedTensorSpec((8, 8, 3), tf.float32, 0,
1)
@@ -88,7 +87,6 @@ def testBuilds(self):
self.assertEqual((1, 3), network_state[0].shape)
self.assertEqual((1, 3), network_state[1].shape)

@test_util.run_in_graph_and_eager_modes()
def testRunsWithLstmStack(self):
observation_spec = tensor_spec.BoundedTensorSpec((8, 8, 3), tf.float32, 0,
1)
@@ -117,6 +115,40 @@ def testRunsWithLstmStack(self):
self.evaluate(tf.compat.v1.global_variables_initializer())
self.evaluate(tf.nest.map_structure(lambda d: d.sample(), net_call[0]))

def testHandlePreprocessingLayers(self):
observation_spec = (tensor_spec.TensorSpec([1], tf.float32),
tensor_spec.TensorSpec([], tf.float32))
time_step_spec = ts.time_step_spec(observation_spec)
time_step = tensor_spec.sample_spec_nest(time_step_spec, outer_dims=(3, 4))

action_spec = [
tensor_spec.BoundedTensorSpec((2,), tf.float32, 2, 3),
tensor_spec.BoundedTensorSpec((3,), tf.int32, 0, 3)
]

preprocessing_layers = (tf.keras.layers.Dense(4),
sequential_layer.SequentialLayer([
tf.keras.layers.Reshape((1,)),
tf.keras.layers.Dense(4)
]))

net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
observation_spec,
action_spec,
preprocessing_layers=preprocessing_layers,
preprocessing_combiner=tf.keras.layers.Add())

initial_state = actor_policy.ActorPolicy(time_step_spec, action_spec,
net).get_initial_state(3)

action_distributions, _ = net(time_step.observation, time_step.step_type,
initial_state)

self.evaluate(tf.compat.v1.global_variables_initializer())
self.assertEqual([3, 4, 2], action_distributions[0].mode().shape.as_list())
self.assertEqual([3, 4, 3], action_distributions[1].mode().shape.as_list())
self.assertGreater(len(net.trainable_variables), 4)


if __name__ == '__main__':
tf.test.main()
