This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 2663805

koz4k authored and lukaszkaiser committed
Temperature (#1277)
* Use T2TModel for policies
* Implement sampling with temperature from policy
* Fixes
1 parent 7de6344 commit 2663805

10 files changed: +210 −121 lines
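The second bullet of the commit message, sampling with temperature from the policy, amounts to scaling the categorical action logits by a temperature before drawing a sample. The snippet below is only an illustrative sketch of that idea in TF1-style code; the helper name sample_with_temperature and its standalone use are assumptions, not code from this commit.

import tensorflow as tf

def sample_with_temperature(logits, temperature):
  """Illustrative only: sample action ids from temperature-scaled logits.

  temperature == 0.0 degenerates to greedy argmax, 1.0 samples from the
  policy's own distribution, larger values flatten it towards uniform.
  """
  if temperature == 0.0:
    return tf.argmax(logits, axis=-1)
  scaled = logits / temperature
  # tf.multinomial expects a 2-D [batch, num_actions] tensor of logits.
  flat = tf.reshape(scaled, [-1, tf.shape(scaled)[-1]])
  samples = tf.multinomial(flat, num_samples=1)[:, 0]
  return tf.reshape(samples, tf.shape(logits)[:-1])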

tensor2tensor/models/research/rl.py

Lines changed: 107 additions & 70 deletions
@@ -27,6 +27,7 @@
 from tensor2tensor.rl.envs.simulated_batch_env import SimulatedBatchEnv
 from tensor2tensor.rl.envs.simulated_batch_gym_env import SimulatedBatchGymEnv
 from tensor2tensor.utils import registry
+from tensor2tensor.utils import t2t_model

 import tensorflow as tf
 import tensorflow_probability as tfp
@@ -60,12 +61,12 @@ def ppo_base_v1():
   return hparams


-@registry.register_hparams
-def ppo_continuous_action_base():
-  hparams = ppo_base_v1()
-  hparams.add_hparam("policy_network", feed_forward_gaussian_fun)
-  hparams.add_hparam("policy_network_params", "basic_policy_parameters")
-  return hparams
+#@registry.register_hparams
+#def ppo_continuous_action_base():
+#  hparams = ppo_base_v1()
+#  hparams.add_hparam("policy_network", feed_forward_gaussian_fun)
+#  hparams.add_hparam("policy_network_params", "basic_policy_parameters")
+#  return hparams


 @registry.register_hparams
@@ -77,14 +78,14 @@ def basic_policy_parameters():
 @registry.register_hparams
 def ppo_discrete_action_base():
   hparams = ppo_base_v1()
-  hparams.add_hparam("policy_network", feed_forward_categorical_fun)
+  hparams.add_hparam("policy_network", "feed_forward_categorical_policy")
   return hparams


 @registry.register_hparams
 def discrete_random_action_base():
   hparams = common_hparams.basic_params1()
-  hparams.add_hparam("policy_network", random_policy_fun)
+  hparams.add_hparam("policy_network", "random_policy")
   return hparams

@@ -100,7 +101,7 @@ def ppo_atari_base():
   hparams.value_loss_coef = 1
   hparams.optimization_epochs = 3
   hparams.epochs_num = 1000
-  hparams.policy_network = feed_forward_cnn_small_categorical_fun
+  hparams.policy_network = "feed_forward_cnn_small_categorical_policy"
   hparams.clipping_coef = 0.2
   hparams.optimization_batch_size = 20
   hparams.max_gradients_norm = 0.5
@@ -157,23 +158,36 @@ def get_policy(observations, hparams, action_space):
   """Get a policy network.

   Args:
-    observations: Tensor with observations
+    observations
     hparams: parameters
     action_space: action space

   Returns:
-    Tensor with policy and value function output
+    Tuple (action logits, value).
   """
-  policy_network_lambda = hparams.policy_network
-  return policy_network_lambda(action_space, hparams, observations)
+  if not isinstance(action_space, gym.spaces.Discrete):
+    raise ValueError("Expecting discrete action space.")
+
+  model = registry.model(hparams.policy_network)(
+      hparams, tf.estimator.ModeKeys.TRAIN
+  )
+  obs_shape = common_layers.shape_list(observations)
+  features = {
+      "inputs": observations,
+      "target_action": tf.zeros(obs_shape[:2] + [action_space.n]),
+      "target_value": tf.zeros(obs_shape[:2])
+  }
+  with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
+    (targets, _) = model(features)
+  return (targets["target_action"], targets["target_value"])


 @registry.register_hparams
 def ppo_pong_ae_base():
   """Pong autoencoder base parameters."""
   hparams = ppo_original_params()
   hparams.learning_rate = 1e-4
-  hparams.network = dense_bitwise_categorical_fun
+  hparams.network = "dense_bitwise_categorical_policy"
   return hparams

@@ -225,6 +239,12 @@ def mfrl_original():
       batch_size=16,
       eval_batch_size=2,
       frame_stack_size=4,
+      eval_sampling_temps=[0.0, 0.2, 0.5, 0.8, 1.0, 2.0],
+      eval_max_num_noops=8,
+      resize_height_factor=2,
+      resize_width_factor=2,
+      grayscale=0,
+      env_timesteps_limit=-1,
   )

@@ -234,11 +254,6 @@ def mfrl_base():
   hparams = mfrl_original()
   hparams.add_hparam("ppo_epochs_num", 3000)
   hparams.add_hparam("ppo_eval_every_epochs", 100)
-  hparams.add_hparam("eval_max_num_noops", 8)
-  hparams.add_hparam("resize_height_factor", 2)
-  hparams.add_hparam("resize_width_factor", 2)
-  hparams.add_hparam("grayscale", 0)
-  hparams.add_hparam("env_timesteps_limit", -1)
   return hparams

@@ -250,10 +265,18 @@ def mfrl_tiny():
   return hparams


+class DiscretePolicyBase(t2t_model.T2TModel):
+
+  @staticmethod
+  def _get_num_actions(features):
+    return common_layers.shape_list(features["target_action"])[2]
+
+
 NetworkOutput = collections.namedtuple(
     "NetworkOutput", "policy, value, action_postprocessing")


+# TODO(koz4k): Translate it to T2TModel or remove.
 def feed_forward_gaussian_fun(action_space, config, observations):
   """Feed-forward Gaussian."""
   if not isinstance(action_space, gym.spaces.box.Box):
@@ -303,36 +326,40 @@ def clip_logits(logits, config):
   return logits


-def feed_forward_categorical_fun(action_space, config, observations):
+@registry.register_model
+class FeedForwardCategoricalPolicy(DiscretePolicyBase):
   """Feed-forward categorical."""
-  if not isinstance(action_space, gym.spaces.Discrete):
-    raise ValueError("Expecting discrete action space.")
-  flat_observations = tf.reshape(observations, [
-      tf.shape(observations)[0], tf.shape(observations)[1],
-      functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
-  with tf.variable_scope("network_parameters"):
+
+  def body(self, features):
+    observations = features["inputs"]
+    flat_observations = tf.reshape(observations, [
+        tf.shape(observations)[0], tf.shape(observations)[1],
+        functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
     with tf.variable_scope("policy"):
       x = flat_observations
-      for size in config.policy_layers:
+      for size in self.hparams.policy_layers:
         x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
-      logits = tf.contrib.layers.fully_connected(x, action_space.n,
-                                                 activation_fn=None)
+      logits = tf.contrib.layers.fully_connected(
+          x, self._get_num_actions(features), activation_fn=None
+      )
     with tf.variable_scope("value"):
       x = flat_observations
-      for size in config.value_layers:
+      for size in self.hparams.value_layers:
        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
      value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
-  logits = clip_logits(logits, config)
-  policy = tfp.distributions.Categorical(logits=logits)
-  return NetworkOutput(policy, value, lambda a: a)
+    logits = clip_logits(logits, self.hparams)
+    return {"target_action": logits, "target_value": value}


-def feed_forward_cnn_small_categorical_fun(action_space, config, observations):
+@registry.register_model
+class FeedForwardCnnSmallCategoricalPolicy(DiscretePolicyBase):
   """Small cnn network with categorical output."""
-  obs_shape = common_layers.shape_list(observations)
-  x = tf.reshape(observations, [-1] + obs_shape[2:])
-  with tf.variable_scope("network_parameters"):
-    dropout = getattr(config, "dropout_ppo", 0.0)
+
+  def body(self, features):
+    observations = features["inputs"]
+    obs_shape = common_layers.shape_list(observations)
+    x = tf.reshape(observations, [-1] + obs_shape[2:])
+    dropout = getattr(self.hparams, "dropout_ppo", 0.0)
     with tf.variable_scope("feed_forward_cnn_small"):
       x = tf.to_float(x) / 255.0
       x = tf.contrib.layers.conv2d(x, 32, [5, 5], [2, 2],
@@ -346,23 +373,25 @@ def feed_forward_cnn_small_categorical_fun(action_space, config, observations):
       flat_x = tf.nn.dropout(flat_x, keep_prob=1.0 - dropout)
       x = tf.contrib.layers.fully_connected(flat_x, 128, tf.nn.relu)

-      logits = tf.contrib.layers.fully_connected(x, action_space.n,
-                                                 activation_fn=None)
-      logits = clip_logits(logits, config)
+      logits = tf.contrib.layers.fully_connected(
+          x, self._get_num_actions(features), activation_fn=None
+      )
+      logits = clip_logits(logits, self.hparams)

       value = tf.contrib.layers.fully_connected(
           x, 1, activation_fn=None)[..., 0]
-      policy = tfp.distributions.Categorical(logits=logits)
-      return NetworkOutput(policy, value, lambda a: a)
+      return {"target_action": logits, "target_value": value}


-def feed_forward_cnn_small_categorical_fun_new(
-    action_space, config, observations):
+@registry.register_model
+class FeedForwardCnnSmallCategoricalPolicyNew(DiscretePolicyBase):
   """Small cnn network with categorical output."""
-  obs_shape = common_layers.shape_list(observations)
-  x = tf.reshape(observations, [-1] + obs_shape[2:])
-  with tf.variable_scope("network_parameters"):
-    dropout = getattr(config, "dropout_ppo", 0.0)
+
+  def body(self, features):
+    observations = features["inputs"]
+    obs_shape = common_layers.shape_list(observations)
+    x = tf.reshape(observations, [-1] + obs_shape[2:])
+    dropout = getattr(self.hparams, "dropout_ppo", 0.0)
     with tf.variable_scope("feed_forward_cnn_small"):
       x = tf.to_float(x) / 255.0
       x = tf.nn.dropout(x, keep_prob=1.0 - dropout)
@@ -384,22 +413,23 @@ def feed_forward_cnn_small_categorical_fun_new(
       flat_x = tf.nn.dropout(flat_x, keep_prob=1.0 - dropout)
       x = tf.layers.dense(flat_x, 128, activation=tf.nn.relu, name="dense1")

-      logits = tf.layers.dense(x, action_space.n, name="dense2")
-      logits = clip_logits(logits, config)
+      logits = tf.layers.dense(
+          x, self._get_num_actions(features), name="dense2"
+      )
+      logits = clip_logits(logits, self.hparams)

       value = tf.layers.dense(x, 1, name="value")[..., 0]
-      policy = tfp.distributions.Categorical(logits=logits)
+      return {"target_action": logits, "target_value": value}

-      return NetworkOutput(policy, value, lambda a: a)

-
-def dense_bitwise_categorical_fun(action_space, config, observations):
+@registry.register_model
+class DenseBitwiseCategoricalPolicy(DiscretePolicyBase):
   """Dense network with bitwise input and categorical output."""
-  del config
-  obs_shape = common_layers.shape_list(observations)
-  x = tf.reshape(observations, [-1] + obs_shape[2:])

-  with tf.variable_scope("network_parameters"):
+
+  def body(self, features):
+    observations = features["inputs"]
+    obs_shape = common_layers.shape_list(observations)
+    x = tf.reshape(observations, [-1] + obs_shape[2:])
     with tf.variable_scope("dense_bitwise"):
       x = discretization.int_to_bit_embed(x, 8, 32)
       flat_x = tf.reshape(
@@ -409,22 +439,29 @@ def dense_bitwise_categorical_fun(action_space, config, observations):
       x = tf.contrib.layers.fully_connected(flat_x, 256, tf.nn.relu)
       x = tf.contrib.layers.fully_connected(flat_x, 128, tf.nn.relu)

-      logits = tf.contrib.layers.fully_connected(x, action_space.n,
-                                                 activation_fn=None)
+      logits = tf.contrib.layers.fully_connected(
+          x, self._get_num_actions(features), activation_fn=None
+      )

       value = tf.contrib.layers.fully_connected(
           x, 1, activation_fn=None)[..., 0]
-      policy = tfp.distributions.Categorical(logits=logits)

-      return NetworkOutput(policy, value, lambda a: a)
+      return {"target_action": logits, "target_value": value}


-def random_policy_fun(action_space, unused_config, observations):
+@registry.register_model
+class RandomPolicy(DiscretePolicyBase):
   """Random policy with categorical output."""
-  obs_shape = observations.shape.as_list()
-  with tf.variable_scope("network_parameters"):
+
+  def body(self, features):
+    observations = features["inputs"]
+    obs_shape = observations.shape.as_list()
+    # Just so Saver doesn't complain because of no variables.
+    tf.get_variable("dummy_var", initializer=0.0)
+    num_actions = self._get_num_actions(features)
+    logits = tf.constant(
+        1. / float(num_actions),
+        shape=(obs_shape[:2] + [num_actions])
+    )
     value = tf.zeros(obs_shape[:2])
-    policy = tfp.distributions.Categorical(
-        probs=[[[1. / float(action_space.n)] * action_space.n] *
-               (obs_shape[0] * obs_shape[1])])
-    return NetworkOutput(policy, value, lambda a: a)
+    return {"target_action": logits, "target_value": value}

tensor2tensor/rl/dopamine_connector.py

Lines changed: 13 additions & 3 deletions
@@ -19,6 +19,7 @@
 from __future__ import division
 from __future__ import print_function

+from copy import copy
 from dopamine.agents.dqn import dqn_agent
 from dopamine.agents.dqn.dqn_agent import NATURE_DQN_OBSERVATION_SHAPE
 from dopamine.agents.dqn.dqn_agent import NATURE_DQN_STACK_SIZE
@@ -285,6 +286,7 @@ def train(self,
             simulated,
             save_continuously,
             epoch,
+            sampling_temp=1.0,
             num_env_steps=None,
             env_step_multiplier=1,
             eval_env_fn=None,
@@ -294,6 +296,11 @@
     if num_env_steps is None:
       num_env_steps = hparams.num_frames

+    hparams = copy(hparams)
+    hparams.set_hparam(
+        "agent_epsilon_eval", min(hparams.agent_epsilon_eval * sampling_temp, 1)
+    )
+
     target_iterations, training_steps_per_iteration = \
       self._target_iteractions_and_steps(
           num_env_steps=num_env_steps * env_step_multiplier,
@@ -307,11 +314,14 @@

     self.completed_iterations = target_iterations

-  def evaluate(self, env_fn, hparams, stochastic):
+  def evaluate(self, env_fn, hparams, sampling_temp):
     target_iterations = 0
     training_steps_per_iteration = 0
-    if not stochastic:
-      hparams.set_hparam("agent_epsilon_eval", 0.)
+
+    hparams = copy(hparams)
+    hparams.set_hparam(
+        "agent_epsilon_eval", min(hparams.agent_epsilon_eval * sampling_temp, 1)
+    )

     create_environment_fn = get_create_env_fun(
         env_fn, time_limit=hparams.time_limit)
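
The DQN connector has no policy logits to rescale, so the sampling temperature is folded into epsilon-greedy exploration instead: the evaluation epsilon becomes min(agent_epsilon_eval * sampling_temp, 1). A tiny sketch of what that mapping does (the base epsilon value below is assumed for illustration, not taken from this diff):

agent_epsilon_eval = 0.001  # assumed base evaluation epsilon
for sampling_temp in [0.0, 0.5, 1.0, 2.0]:
  epsilon = min(agent_epsilon_eval * sampling_temp, 1)
  # 0.0 -> 0.0 (fully greedy), 0.5 -> 0.0005, 1.0 -> 0.001, 2.0 -> 0.002
  print(sampling_temp, epsilon)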

tensor2tensor/rl/policy_learner.py

Lines changed: 11 additions & 3 deletions
@@ -29,13 +29,21 @@ def __init__(self, frame_stack_size, base_event_dir, agent_model_dir):
     self.agent_model_dir = agent_model_dir

   def train(
-      self, env_fn, hparams, simulated, save_continuously, epoch,
-      num_env_steps=None, env_step_multiplier=1, eval_env_fn=None,
+      self,
+      env_fn,
+      hparams,
+      simulated,
+      save_continuously,
+      epoch,
+      sampling_temp=1.0,
+      num_env_steps=None,
+      env_step_multiplier=1,
+      eval_env_fn=None,
       report_fn=None
   ):
     # TODO(konradczechowski): pass name_scope instead of epoch?
     # TODO(konradczechowski): move 'simulated' to batch_env
     raise NotImplementedError()

-  def evaluate(self, env_fn, hparams, stochastic):
+  def evaluate(self, env_fn, hparams, sampling_temp):
     raise NotImplementedError()
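
PolicyLearner.evaluate now takes a sampling_temp instead of a boolean stochastic flag, which pairs with the eval_sampling_temps list added to mfrl_original in rl.py. A hedged sketch of an evaluation sweep using that interface (the learner and env_fn construction are assumed context, not shown in this commit):

for sampling_temp in hparams.eval_sampling_temps:
  # One evaluation run per temperature, from greedy (0.0) to very
  # exploratory (2.0), using the same trained agent.
  learner.evaluate(env_fn, hparams, sampling_temp)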

0 commit comments
