@@ -287,12 +287,12 @@ def new_episode():
 
         # Do batched policy eval
         eval_results = _do_policy_eval(tf_sess, to_eval, policies,
-                                       active_episodes, clip_actions)
+                                       active_episodes)
 
         # Process results and update episode state
         actions_to_send = _process_policy_eval_results(
             to_eval, eval_results, active_episodes, active_envs,
-            off_policy_actions)
+            off_policy_actions, policies, clip_actions)
 
         # Return computed actions to ready envs. We also send to envs that have
         # taken off-policy actions; those envs are free to ignore the action.
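The hunk above only moves the clip_actions flag from the batched eval call to the result-processing call. A minimal standalone sketch of the resulting flow (hypothetical names, not RLlib's actual helpers): batched eval returns raw actions, and clipping is applied per action only when filling the per-env send dictionary.

import gym
import numpy as np

# Hypothetical stand-ins for a policy's action space and a batched eval output.
action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
batched_actions = np.array([[1.7, -0.2], [-3.0, 0.4]], dtype=np.float32)
clip_actions = True

actions_to_send = {}
for env_id, action in enumerate(batched_actions):
    # Clipping happens per action at send time; the raw eval output above
    # is left untouched.
    if clip_actions:
        actions_to_send[env_id] = np.clip(action, action_space.low,
                                          action_space.high)
    else:
        actions_to_send[env_id] = action

print(actions_to_send)  # first action clipped to [1.0, -0.2], second to [-1.0, 0.4]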
@@ -448,7 +448,7 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
     return active_envs, to_eval, outputs
 
 
-def _do_policy_eval(tf_sess, to_eval, policies, active_episodes, clip_actions):
+def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
     """Call compute actions on observation batches to get next actions.
 
     Returns:
@@ -483,18 +483,12 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes, clip_actions):
         for k, v in pending_fetches.items():
             eval_results[k] = builder.get(v)
 
-    if clip_actions:
-        for policy_id, results in eval_results.items():
-            policy = _get_or_raise(policies, policy_id)
-            actions, rnn_out_cols, pi_info_cols = results
-            eval_results[policy_id] = (_clip_actions(
-                actions, policy.action_space), rnn_out_cols, pi_info_cols)
-
     return eval_results
 
 
 def _process_policy_eval_results(to_eval, eval_results, active_episodes,
-                                 active_envs, off_policy_actions):
+                                 active_envs, off_policy_actions, policies,
+                                 clip_actions):
     """Process the output of policy neural network evaluation.
 
     Records policy evaluation results into the given episode objects and
@@ -521,10 +515,15 @@ def _process_policy_eval_results(to_eval, eval_results, active_episodes,
             pi_info_cols["state_out_{}".format(f_i)] = column
         # Save output rows
         actions = _unbatch_tuple_actions(actions)
+        policy = _get_or_raise(policies, policy_id)
         for i, action in enumerate(actions):
             env_id = eval_data[i].env_id
             agent_id = eval_data[i].agent_id
-            actions_to_send[env_id][agent_id] = action
+            if clip_actions:
+                actions_to_send[env_id][agent_id] = _clip_actions(
+                    action, policy.action_space)
+            else:
+                actions_to_send[env_id][agent_id] = action
             episode = active_episodes[env_id]
             episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
             episode._set_last_pi_info(
@@ -562,7 +561,7 @@ def _clip_actions(actions, space):
     """Called to clip actions to the specified range of this policy.
 
     Arguments:
-        actions: Batch of actions or TupleActions.
+        actions: Single action.
         space: Action space the actions should be present in.
 
     Returns:
@@ -572,13 +571,13 @@ def _clip_actions(actions, space):
     if isinstance(space, gym.spaces.Box):
         return np.clip(actions, space.low, space.high)
     elif isinstance(space, gym.spaces.Tuple):
-        if not isinstance(actions, TupleActions):
+        if type(actions) not in (tuple, list):
             raise ValueError("Expected tuple space for actions {}: {}".format(
                 actions, space))
         out = []
-        for a, s in zip(actions.batches, space.spaces):
+        for a, s in zip(actions, space.spaces):
             out.append(_clip_actions(a, s))
-        return TupleActions(out)
+        return out
     else:
         return actions
 
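For the Tuple-space branch changed above, a standalone sketch (made-up space and values, with a local clip_action helper mirroring the post-change _clip_actions logic) showing how a single nested action is clipped recursively:

import gym
import numpy as np


def clip_action(action, space):
    # Mirrors the post-change _clip_actions: operates on a single action and
    # expects a plain tuple/list (not TupleActions) for Tuple spaces.
    if isinstance(space, gym.spaces.Box):
        return np.clip(action, space.low, space.high)
    elif isinstance(space, gym.spaces.Tuple):
        if type(action) not in (tuple, list):
            raise ValueError("Expected tuple space for actions {}: {}".format(
                action, space))
        return [clip_action(a, s) for a, s in zip(action, space.spaces)]
    else:
        # Discrete and other spaces pass through unchanged.
        return action


space = gym.spaces.Tuple([
    gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),
    gym.spaces.Discrete(4),
])
action = [np.array([2.5, -0.3], dtype=np.float32), 3]
print(clip_action(action, space))  # Box part clipped to [-1, 1]; the Discrete value is untouched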