@@ -40,23 +40,22 @@ class TorchPolicy(Policy):
     """

     @DeveloperAPI
-    def __init__(self,
-                 observation_space: gym.spaces.Space,
-                 action_space: gym.spaces.Space,
-                 config: TrainerConfigDict,
-                 *,
-                 model: ModelV2,
-                 loss: Callable[
-                     [Policy, ModelV2, type, SampleBatch], TensorType],
-                 action_distribution_class: TorchDistributionWrapper,
-                 action_sampler_fn: Callable[
-                     [TensorType, List[TensorType]], Tuple[
-                         TensorType, TensorType]] = None,
-                 action_distribution_fn: Optional[Callable[
-                     [Policy, ModelV2, TensorType, TensorType, TensorType],
-                     Tuple[TensorType, type, List[TensorType]]]] = None,
-                 max_seq_len: int = 20,
-                 get_batch_divisibility_req: Optional[int] = None):
+    def __init__(
+            self,
+            observation_space: gym.spaces.Space,
+            action_space: gym.spaces.Space,
+            config: TrainerConfigDict,
+            *,
+            model: ModelV2,
+            loss: Callable[[Policy, ModelV2, type, SampleBatch], TensorType],
+            action_distribution_class: TorchDistributionWrapper,
+            action_sampler_fn: Callable[[TensorType, List[TensorType]], Tuple[
+                TensorType, TensorType]] = None,
+            action_distribution_fn: Optional[Callable[[
+                Policy, ModelV2, TensorType, TensorType, TensorType
+            ], Tuple[TensorType, type, List[TensorType]]]] = None,
+            max_seq_len: int = 20,
+            get_batch_divisibility_req: Optional[int] = None):
         """Build a policy from policy and loss torch modules.

         Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES
@@ -165,8 +164,8 @@ def compute_actions(
             extra_fetches[SampleBatch.ACTION_PROB] = np.exp(logp)
             extra_fetches[SampleBatch.ACTION_LOGP] = logp

-            return convert_to_non_torch_type(
-                (actions, state_out, extra_fetches))
+            return convert_to_non_torch_type((actions, state_out,
+                                              extra_fetches))

     @override(Policy)
     def compute_actions_from_trajectories(
@@ -183,8 +182,9 @@ def compute_actions_from_trajectories(

         with torch.no_grad():
             # Create a view and pass that to Model as `input_dict`.
-            input_dict = self._lazy_tensor_dict(get_trajectory_view(
-                self.model, trajectories, is_training=False))
+            input_dict = self._lazy_tensor_dict(
+                get_trajectory_view(
+                    self.model, trajectories, is_training=False))
             # TODO: (sven) support RNNs w/ fast sampling.
             state_batches = []
             seq_lens = None
@@ -232,8 +232,8 @@ def _compute_action_helper(self, input_dict, state_batches, seq_lens,
                    is_training=False)
        else:
            dist_class = self.dist_class
-            dist_inputs, state_out = self.model(
-                input_dict, state_batches, seq_lens)
+            dist_inputs, state_out = self.model(input_dict, state_batches,
+                                                seq_lens)

        if not (isinstance(dist_class, functools.partial)
                or issubclass(dist_class, TorchDistributionWrapper)):
@@ -270,10 +270,10 @@ def compute_log_likelihoods(
             actions: Union[List[TensorType], TensorType],
             obs_batch: Union[List[TensorType], TensorType],
             state_batches: Optional[List[TensorType]] = None,
-            prev_action_batch: Optional[
-                Union[List[TensorType], TensorType]] = None,
-            prev_reward_batch: Optional[
-                Union[List[TensorType], TensorType]] = None) -> TensorType:
+            prev_action_batch: Optional[Union[List[TensorType],
+                                              TensorType]] = None,
+            prev_reward_batch: Optional[Union[List[
+                TensorType], TensorType]] = None) -> TensorType:

         if self.action_sampler_fn and self.action_distribution_fn is None:
             raise ValueError("Cannot compute log-prob/likelihood w/o an "
@@ -314,8 +314,8 @@ def compute_log_likelihoods(

     @override(Policy)
     @DeveloperAPI
-    def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[
-            str, TensorType]:
+    def learn_on_batch(
+            self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
         # Get batch ready for RNNs, if applicable.
         pad_batch_to_sequences_of_same_size(
             postprocessed_batch,
@@ -331,6 +331,7 @@ def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[
             loss_out = self.model.custom_loss(loss_out, train_batch)
         assert len(loss_out) == len(self._optimizers)
         # assert not any(torch.isnan(l) for l in loss_out)
+        fetches = self.extra_compute_grad_fetches()

         # Loop through all optimizers.
         grad_info = {"allreduce_latency": 0.0}
@@ -370,7 +371,7 @@ def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[

         grad_info["allreduce_latency"] /= len(self._optimizers)
         grad_info.update(self.extra_grad_info(train_batch))
-        return {LEARNER_STATS_KEY: grad_info}
+        return dict(fetches, **{LEARNER_STATS_KEY: grad_info})

     @override(Policy)
     @DeveloperAPI
@@ -380,6 +381,7 @@ def compute_gradients(self,
         loss_out = force_list(
             self._loss(self, self.model, self.dist_class, train_batch))
         assert len(loss_out) == len(self._optimizers)
+        fetches = self.extra_compute_grad_fetches()

         grad_process_info = {}
         grads = []
@@ -399,7 +401,7 @@ def compute_gradients(self,

         grad_info = self.extra_grad_info(train_batch)
         grad_info.update(grad_process_info)
-        return grads, {LEARNER_STATS_KEY: grad_info}
+        return grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})

     @override(Policy)
     @DeveloperAPI
@@ -466,10 +468,8 @@ def set_state(self, state: object) -> None:
         super().set_state(state)

     @DeveloperAPI
-    def extra_grad_process(
-            self,
-            optimizer: "torch.optim.Optimizer",
-            loss: TensorType):
+    def extra_grad_process(self, optimizer: "torch.optim.Optimizer",
+                           loss: TensorType):
         """Called after each optimizer.zero_grad() + loss.backward() call.

         Called for each self._optimizers/loss-value pair.
@@ -486,12 +486,20 @@ def extra_grad_process(
         """
         return {}

+    @DeveloperAPI
+    def extra_compute_grad_fetches(self) -> Dict[str, any]:
+        """Extra values to fetch and return from compute_gradients().
+
+        Returns:
+            Dict[str, any]: Extra fetch dict to be added to the fetch dict
+                of the compute_gradients call.
+        """
+        return {LEARNER_STATS_KEY: {}}  # e.g., stats, td error, etc.
+
     @DeveloperAPI
     def extra_action_out(
-            self,
-            input_dict: Dict[str, TensorType],
-            state_batches: List[TensorType],
-            model: TorchModelV2,
+            self, input_dict: Dict[str, TensorType],
+            state_batches: List[TensorType], model: TorchModelV2,
             action_dist: TorchDistributionWrapper) -> Dict[str, TensorType]:
         """Returns dict of extra info to include in experience batch.

@@ -509,8 +517,8 @@ def extra_action_out(
         return {}

     @DeveloperAPI
-    def extra_grad_info(self, train_batch: SampleBatch) -> Dict[
-            str, TensorType]:
+    def extra_grad_info(self,
+                        train_batch: SampleBatch) -> Dict[str, TensorType]:
         """Return dict of extra grad info.

         Args:
@@ -524,8 +532,9 @@ def extra_grad_info(self, train_batch: SampleBatch) -> Dict[
         return {}

     @DeveloperAPI
-    def optimizer(self) -> Union[
-            List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
+    def optimizer(
+            self
+    ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
         """Custom the local PyTorch optimizer(s) to use.

         Returns:
@@ -560,8 +569,8 @@ def import_model_from_h5(self, import_file: str) -> None:

     def _lazy_tensor_dict(self, postprocessed_batch):
         train_batch = UsageTrackingDict(postprocessed_batch)
-        train_batch.set_get_interceptor(functools.partial(
-            convert_to_torch_tensor, device=self.device))
+        train_batch.set_get_interceptor(
+            functools.partial(convert_to_torch_tensor, device=self.device))
         return train_batch

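For readers of this change, here is a minimal, hypothetical sketch of how the new extra_compute_grad_fetches() hook can be used; the subclass name MyTorchPolicy and the td_error attribute are illustrative assumptions, not part of this commit. Because learn_on_batch() and compute_gradients() now merge grad_info on top of the fetch dict via dict(fetches, **{LEARNER_STATS_KEY: grad_info}), anything a subclass returns under LEARNER_STATS_KEY is overwritten, so custom values should live under their own top-level keys.

from typing import Dict

from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils.annotations import override


class MyTorchPolicy(TorchPolicy):
    """Hypothetical subclass showing the extra_compute_grad_fetches() hook."""

    @override(TorchPolicy)
    def extra_compute_grad_fetches(self) -> Dict[str, any]:
        # Start from the base dict ({LEARNER_STATS_KEY: {}}) and add a
        # custom, top-level key. `self.td_error` is assumed to have been
        # stored by the loss function; it is not an RLlib attribute.
        fetches = super().extra_compute_grad_fetches()
        fetches["td_error"] = getattr(self, "td_error", None)
        return fetches

The "td_error" key then appears alongside LEARNER_STATS_KEY in the info dict returned by learn_on_batch() and compute_gradients().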