from tensor2tensor.rl.envs.simulated_batch_env import SimulatedBatchEnv
from tensor2tensor.rl.envs.simulated_batch_gym_env import SimulatedBatchGymEnv
from tensor2tensor.utils import registry
+from tensor2tensor.utils import t2t_model

import tensorflow as tf
import tensorflow_probability as tfp
@@ -60,12 +61,12 @@ def ppo_base_v1():
  return hparams


-@registry.register_hparams
-def ppo_continuous_action_base():
-  hparams = ppo_base_v1()
-  hparams.add_hparam("policy_network", feed_forward_gaussian_fun)
-  hparams.add_hparam("policy_network_params", "basic_policy_parameters")
-  return hparams
+# @registry.register_hparams
+# def ppo_continuous_action_base():
+#   hparams = ppo_base_v1()
+#   hparams.add_hparam("policy_network", feed_forward_gaussian_fun)
+#   hparams.add_hparam("policy_network_params", "basic_policy_parameters")
+#   return hparams


@registry.register_hparams
@@ -77,14 +78,14 @@ def basic_policy_parameters():
@registry.register_hparams
def ppo_discrete_action_base():
  hparams = ppo_base_v1()
-  hparams.add_hparam("policy_network", feed_forward_categorical_fun)
+  hparams.add_hparam("policy_network", "feed_forward_categorical_policy")
  return hparams


@registry.register_hparams
def discrete_random_action_base():
  hparams = common_hparams.basic_params1()
-  hparams.add_hparam("policy_network", random_policy_fun)
+  hparams.add_hparam("policy_network", "random_policy")
  return hparams


@@ -100,7 +101,7 @@ def ppo_atari_base():
  hparams.value_loss_coef = 1
  hparams.optimization_epochs = 3
  hparams.epochs_num = 1000
-  hparams.policy_network = feed_forward_cnn_small_categorical_fun
+  hparams.policy_network = "feed_forward_cnn_small_categorical_policy"
  hparams.clipping_coef = 0.2
  hparams.optimization_batch_size = 20
  hparams.max_gradients_norm = 0.5
@@ -157,23 +158,36 @@ def get_policy(observations, hparams, action_space):
  """Get a policy network.

  Args:
-    observations: Tensor with observations
+    observations
    hparams: parameters
    action_space: action space

  Returns:
-    Tensor with policy and value function output
+    Tuple (action logits, value).
  """
-  policy_network_lambda = hparams.policy_network
-  return policy_network_lambda(action_space, hparams, observations)
+  if not isinstance(action_space, gym.spaces.Discrete):
+    raise ValueError("Expecting discrete action space.")
+
+  model = registry.model(hparams.policy_network)(
+      hparams, tf.estimator.ModeKeys.TRAIN
+  )
+  obs_shape = common_layers.shape_list(observations)
+  features = {
+      "inputs": observations,
+      "target_action": tf.zeros(obs_shape[:2] + [action_space.n]),
+      "target_value": tf.zeros(obs_shape[:2])
+  }
+  with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
+    (targets, _) = model(features)
+  return (targets["target_action"], targets["target_value"])


@registry.register_hparams
def ppo_pong_ae_base():
  """Pong autoencoder base parameters."""
  hparams = ppo_original_params()
  hparams.learning_rate = 1e-4
-  hparams.network = dense_bitwise_categorical_fun
+  hparams.network = "dense_bitwise_categorical_policy"
  return hparams


@@ -225,6 +239,12 @@ def mfrl_original():
      batch_size=16,
      eval_batch_size=2,
      frame_stack_size=4,
+      eval_sampling_temps=[0.0, 0.2, 0.5, 0.8, 1.0, 2.0],
+      eval_max_num_noops=8,
+      resize_height_factor=2,
+      resize_width_factor=2,
+      grayscale=0,
+      env_timesteps_limit=-1,
  )


@@ -234,11 +254,6 @@ def mfrl_base():
  hparams = mfrl_original()
  hparams.add_hparam("ppo_epochs_num", 3000)
  hparams.add_hparam("ppo_eval_every_epochs", 100)
-  hparams.add_hparam("eval_max_num_noops", 8)
-  hparams.add_hparam("resize_height_factor", 2)
-  hparams.add_hparam("resize_width_factor", 2)
-  hparams.add_hparam("grayscale", 0)
-  hparams.add_hparam("env_timesteps_limit", -1)
  return hparams


@@ -250,10 +265,18 @@ def mfrl_tiny():
  return hparams


+class DiscretePolicyBase(t2t_model.T2TModel):
+
+  @staticmethod
+  def _get_num_actions(features):
+    return common_layers.shape_list(features["target_action"])[2]
+
+
NetworkOutput = collections.namedtuple(
    "NetworkOutput", "policy, value, action_postprocessing")


+# TODO(koz4k): Translate it to T2TModel or remove.
def feed_forward_gaussian_fun(action_space, config, observations):
  """Feed-forward Gaussian."""
  if not isinstance(action_space, gym.spaces.box.Box):
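Not part of the diff, for orientation: DiscretePolicyBase added above is the interface the conversions below follow. A subclass registers itself with @registry.register_model, implements body(), and returns a dict holding "target_action" logits and a "target_value" baseline; _get_num_actions() recovers the action count from the zero-filled "target_action" feature that get_policy feeds in. A minimal sketch of a hypothetical extra policy written against that interface, assuming it sits in this same module (the class name and layer width are illustrative, not from the change):

@registry.register_model
class TinyDensePolicy(DiscretePolicyBase):  # hypothetical example, not in the PR
  """Single hidden layer over flattened observations, for illustration only."""

  def body(self, features):
    observations = tf.to_float(features["inputs"])
    obs_shape = common_layers.shape_list(observations)
    # Keep the batch and time dimensions, flatten everything else.
    flat = tf.reshape(observations, obs_shape[:2] + [-1])
    x = tf.contrib.layers.fully_connected(flat, 64, tf.nn.relu)
    logits = tf.contrib.layers.fully_connected(
        x, self._get_num_actions(features), activation_fn=None)
    value = tf.contrib.layers.fully_connected(x, 1, activation_fn=None)[..., 0]
    # get_policy() reads these two keys back out of the model's targets.
    return {"target_action": logits, "target_value": value}

Registered under the snake_case name "tiny_dense_policy", it would be selected the same way as the converted policies, e.g. hparams.policy_network = "tiny_dense_policy".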
@@ -303,36 +326,40 @@ def clip_logits(logits, config):
  return logits


-def feed_forward_categorical_fun(action_space, config, observations):
+@registry.register_model
+class FeedForwardCategoricalPolicy(DiscretePolicyBase):
  """Feed-forward categorical."""
-  if not isinstance(action_space, gym.spaces.Discrete):
-    raise ValueError("Expecting discrete action space.")
-  flat_observations = tf.reshape(observations, [
-      tf.shape(observations)[0], tf.shape(observations)[1],
-      functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
-  with tf.variable_scope("network_parameters"):
+
+  def body(self, features):
+    observations = features["inputs"]
+    flat_observations = tf.reshape(observations, [
+        tf.shape(observations)[0], tf.shape(observations)[1],
+        functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
    with tf.variable_scope("policy"):
      x = flat_observations
-      for size in config.policy_layers:
+      for size in self.hparams.policy_layers:
        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
-      logits = tf.contrib.layers.fully_connected(x, action_space.n,
-                                                 activation_fn=None)
+      logits = tf.contrib.layers.fully_connected(
+          x, self._get_num_actions(features), activation_fn=None
+      )
    with tf.variable_scope("value"):
      x = flat_observations
-      for size in config.value_layers:
+      for size in self.hparams.value_layers:
        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
      value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
-  logits = clip_logits(logits, config)
-  policy = tfp.distributions.Categorical(logits=logits)
-  return NetworkOutput(policy, value, lambda a: a)
+    logits = clip_logits(logits, self.hparams)
+    return {"target_action": logits, "target_value": value}


-def feed_forward_cnn_small_categorical_fun(action_space, config, observations):
+@registry.register_model
+class FeedForwardCnnSmallCategoricalPolicy(DiscretePolicyBase):
  """Small cnn network with categorical output."""
-  obs_shape = common_layers.shape_list(observations)
-  x = tf.reshape(observations, [-1] + obs_shape[2:])
-  with tf.variable_scope("network_parameters"):
-    dropout = getattr(config, "dropout_ppo", 0.0)
+
+  def body(self, features):
+    observations = features["inputs"]
+    obs_shape = common_layers.shape_list(observations)
+    x = tf.reshape(observations, [-1] + obs_shape[2:])
+    dropout = getattr(self.hparams, "dropout_ppo", 0.0)
    with tf.variable_scope("feed_forward_cnn_small"):
      x = tf.to_float(x) / 255.0
      x = tf.contrib.layers.conv2d(x, 32, [5, 5], [2, 2],
@@ -346,23 +373,25 @@ def feed_forward_cnn_small_categorical_fun(action_space, config, observations):
      flat_x = tf.nn.dropout(flat_x, keep_prob=1.0 - dropout)
      x = tf.contrib.layers.fully_connected(flat_x, 128, tf.nn.relu)

-      logits = tf.contrib.layers.fully_connected(x, action_space.n,
-                                                 activation_fn=None)
-      logits = clip_logits(logits, config)
+      logits = tf.contrib.layers.fully_connected(
+          x, self._get_num_actions(features), activation_fn=None
+      )
+      logits = clip_logits(logits, self.hparams)

      value = tf.contrib.layers.fully_connected(
          x, 1, activation_fn=None)[..., 0]
-  policy = tfp.distributions.Categorical(logits=logits)
-  return NetworkOutput(policy, value, lambda a: a)
+    return {"target_action": logits, "target_value": value}


-def feed_forward_cnn_small_categorical_fun_new(
-    action_space, config, observations):
+@registry.register_model
+class FeedForwardCnnSmallCategoricalPolicyNew(DiscretePolicyBase):
  """Small cnn network with categorical output."""
-  obs_shape = common_layers.shape_list(observations)
-  x = tf.reshape(observations, [-1] + obs_shape[2:])
-  with tf.variable_scope("network_parameters"):
-    dropout = getattr(config, "dropout_ppo", 0.0)
+
+  def body(self, features):
+    observations = features["inputs"]
+    obs_shape = common_layers.shape_list(observations)
+    x = tf.reshape(observations, [-1] + obs_shape[2:])
+    dropout = getattr(self.hparams, "dropout_ppo", 0.0)
    with tf.variable_scope("feed_forward_cnn_small"):
      x = tf.to_float(x) / 255.0
      x = tf.nn.dropout(x, keep_prob=1.0 - dropout)
@@ -384,22 +413,23 @@ def feed_forward_cnn_small_categorical_fun_new(
      flat_x = tf.nn.dropout(flat_x, keep_prob=1.0 - dropout)
      x = tf.layers.dense(flat_x, 128, activation=tf.nn.relu, name="dense1")

-      logits = tf.layers.dense(x, action_space.n, name="dense2")
-      logits = clip_logits(logits, config)
+      logits = tf.layers.dense(
+          x, self._get_num_actions(features), name="dense2"
+      )
+      logits = clip_logits(logits, self.hparams)

      value = tf.layers.dense(x, 1, name="value")[..., 0]
-      policy = tfp.distributions.Categorical(logits=logits)
+      return {"target_action": logits, "target_value": value}

-      return NetworkOutput(policy, value, lambda a: a)

-
-def dense_bitwise_categorical_fun(action_space, config, observations):
+@registry.register_model
+class DenseBitwiseCategoricalPolicy(DiscretePolicyBase):
  """Dense network with bitwise input and categorical output."""
-  del config
-  obs_shape = common_layers.shape_list(observations)
-  x = tf.reshape(observations, [-1] + obs_shape[2:])

-  with tf.variable_scope("network_parameters"):
+  def body(self, features):
+    observations = features["inputs"]
+    obs_shape = common_layers.shape_list(observations)
+    x = tf.reshape(observations, [-1] + obs_shape[2:])
    with tf.variable_scope("dense_bitwise"):
      x = discretization.int_to_bit_embed(x, 8, 32)
      flat_x = tf.reshape(
@@ -409,22 +439,29 @@ def dense_bitwise_categorical_fun(action_space, config, observations):
      x = tf.contrib.layers.fully_connected(flat_x, 256, tf.nn.relu)
      x = tf.contrib.layers.fully_connected(flat_x, 128, tf.nn.relu)

-      logits = tf.contrib.layers.fully_connected(x, action_space.n,
-                                                 activation_fn=None)
+      logits = tf.contrib.layers.fully_connected(
+          x, self._get_num_actions(features), activation_fn=None
+      )

      value = tf.contrib.layers.fully_connected(
          x, 1, activation_fn=None)[..., 0]
-      policy = tfp.distributions.Categorical(logits=logits)

-  return NetworkOutput(policy, value, lambda a: a)
+    return {"target_action": logits, "target_value": value}


-def random_policy_fun(action_space, unused_config, observations):
+@registry.register_model
+class RandomPolicy(DiscretePolicyBase):
  """Random policy with categorical output."""
-  obs_shape = observations.shape.as_list()
-  with tf.variable_scope("network_parameters"):
+
+  def body(self, features):
+    observations = features["inputs"]
+    obs_shape = observations.shape.as_list()
+    # Just so Saver doesn't complain because of no variables.
+    tf.get_variable("dummy_var", initializer=0.0)
+    num_actions = self._get_num_actions(features)
+    logits = tf.constant(
+        1. / float(num_actions),
+        shape=(obs_shape[:2] + [num_actions])
+    )
    value = tf.zeros(obs_shape[:2])
-    policy = tfp.distributions.Categorical(
-        probs=[[[1. / float(action_space.n)] * action_space.n] *
-               (obs_shape[0] * obs_shape[1])])
-    return NetworkOutput(policy, value, lambda a: a)
+    return {"target_action": logits, "target_value": value}