
Commit 5acd3e6

[rllib] Fix torch TD error, IMPALA LR updates (#9477)
* update
* add test
* lint
* fix super call
* speed es test up
1 parent ea4797b commit 5acd3e6

File tree: 7 files changed (+97, -54 lines)

rllib/agents/dqn/dqn_torch_policy.py

Lines changed: 2 additions & 5 deletions
@@ -52,7 +52,6 @@ def __init__(self,
             "mean_q": torch.mean(q_t_selected),
             "min_q": torch.min(q_t_selected),
             "max_q": torch.max(q_t_selected),
-            "td_error": self.td_error,
             "mean_td_error": torch.mean(self.td_error),
         }
 
@@ -250,10 +249,7 @@ def compute_q_values(policy, model, obs, explore, is_training=False):
 
 def grad_process_and_td_error_fn(policy, optimizer, loss):
     # Clip grads if configured.
-    info = apply_grad_clipping(policy, optimizer, loss)
-    # Add td-error to info dict.
-    info["td_error"] = policy.q_loss.td_error
-    return info
+    return apply_grad_clipping(policy, optimizer, loss)
 
 
 def extra_action_out_fn(policy, input_dict, state_batches, model, action_dist):
@@ -270,6 +266,7 @@ def extra_action_out_fn(policy, input_dict, state_batches, model, action_dist):
     postprocess_fn=postprocess_nstep_and_prio,
     optimizer_fn=adam_optimizer,
     extra_grad_process_fn=grad_process_and_td_error_fn,
+    extra_learn_fetches_fn=lambda policy: {"td_error": policy.q_loss.td_error},
     extra_action_out_fn=extra_action_out_fn,
     before_init=setup_early_mixins,
     after_init=after_init,
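With this change the TD error is no longer smuggled through the grad-process info or the stats dict; it is returned by extra_learn_fetches_fn and therefore surfaces at the top level of the fetch dict that learn_on_batch() returns, next to LEARNER_STATS_KEY. A minimal sketch of how a caller might read it after this commit; `policy` and `sample_batch` are assumed to be an already-built torch DQN policy and a postprocessed SampleBatch:

from ray.rllib.policy.policy import LEARNER_STATS_KEY

# Assumed: `policy` is a torch DQN policy built from the template above and
# `sample_batch` is a postprocessed SampleBatch.
fetches = policy.learn_on_batch(sample_batch)

stats = fetches[LEARNER_STATS_KEY]  # mean_q, mean_td_error, ... as before
td_error = fetches["td_error"]      # per-sample TD errors, now a top-level fetch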

rllib/agents/dqn/simple_q_torch_policy.py

Lines changed: 1 addition & 1 deletion
@@ -96,5 +96,5 @@ def setup_late_mixins(policy, obs_space, action_space, config):
     make_model_and_action_dist=build_q_model_and_distribution,
     mixins=[TargetNetworkMixin],
     action_distribution_fn=get_distribution_inputs_and_class,
-    stats_fn=lambda policy, config: {"td_error": policy.td_error},
+    extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
 )

rllib/agents/es/tests/test_es.py

Lines changed: 5 additions & 2 deletions
@@ -9,14 +9,17 @@
 class TestES(unittest.TestCase):
     def test_es_compilation(self):
         """Test whether an ESTrainer can be built on all frameworks."""
-        ray.init()
+        ray.init(num_cpus=2)
         config = es.DEFAULT_CONFIG.copy()
         # Keep it simple.
         config["model"]["fcnet_hiddens"] = [10]
         config["model"]["fcnet_activation"] = None
         config["noise_size"] = 2500000
+        config["num_workers"] = 1
+        config["episodes_per_batch"] = 10
+        config["train_batch_size"] = 100
 
-        num_iterations = 2
+        num_iterations = 1
 
         for _ in framework_iterator(config):
             plain_config = config.copy()
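The test now runs a single iteration with one worker and small batches purely to keep CI fast. For reference, a standalone sketch with the same slimmed-down settings; the env choice and the single train() call are illustrative, not part of the commit:

import ray
from ray.rllib.agents import es

ray.init(num_cpus=2)
config = es.DEFAULT_CONFIG.copy()
config["model"]["fcnet_hiddens"] = [10]
config["model"]["fcnet_activation"] = None
config["noise_size"] = 2500000
config["num_workers"] = 1
config["episodes_per_batch"] = 10
config["train_batch_size"] = 100

trainer = es.ESTrainer(config=config, env="CartPole-v0")
print(trainer.train())  # one fast iteration
trainer.stop()
ray.shutdown()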

rllib/agents/impala/impala.py

Lines changed: 2 additions & 0 deletions
@@ -194,6 +194,8 @@ def __call__(self, item):
         metrics = _get_shared_metrics()
         metrics.counters["num_weight_broadcasts"] += 1
         actor.set_weights.remote(self.weights, _get_global_vars())
+        # Also update global vars of the local worker.
+        self.workers.local_worker().set_global_vars(_get_global_vars())
 
 
 def record_steps_trained(item):
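The LR schedule is evaluated against the global timestep carried in the global vars, and the loss runs on the local (learner) worker; without this extra call the learner never sees an updated timestep and cur_lr stays frozen at its initial value. Roughly, a schedule such as [[0, 0.0005], [10000, 0.000001]] is linearly interpolated over that timestep; the helper below is an illustrative sketch, not RLlib's actual schedule class:

def piecewise_lr(schedule, timestep):
    """Linearly interpolate a [[timestep, lr], ...] schedule (illustrative only)."""
    if timestep <= schedule[0][0]:
        return schedule[0][1]
    if timestep >= schedule[-1][0]:
        return schedule[-1][1]
    # Find the surrounding breakpoints and interpolate between them.
    for (t0, lr0), (t1, lr1) in zip(schedule, schedule[1:]):
        if t0 <= timestep <= t1:
            frac = (timestep - t0) / (t1 - t0)
            return lr0 + frac * (lr1 - lr0)


print(piecewise_lr([[0, 0.0005], [10000, 0.000001]], 5000))  # ~0.00025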

rllib/agents/impala/tests/test_impala.py

Lines changed: 19 additions & 0 deletions
@@ -18,6 +18,25 @@ def setUpClass(cls) -> None:
     def tearDownClass(cls) -> None:
         ray.shutdown()
 
+    def test_impala_lr_schedule(self):
+        config = impala.DEFAULT_CONFIG.copy()
+        config["lr_schedule"] = [
+            [0, 0.0005],
+            [10000, 0.000001],
+        ]
+        local_cfg = config.copy()
+        trainer = impala.ImpalaTrainer(config=local_cfg, env="CartPole-v0")
+
+        def get_lr(result):
+            return result["info"]["learner"]["default_policy"]["cur_lr"]
+
+        try:
+            r1 = trainer.train()
+            r2 = trainer.train()
+            assert get_lr(r2) < get_lr(r1), (r1, r2)
+        finally:
+            trainer.stop()
+
     def test_impala_compilation(self):
         """Test whether an ImpalaTrainer can be built with both frameworks."""
         config = impala.DEFAULT_CONFIG.copy()
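The new test exercises the fix end to end: two consecutive train() calls must report a strictly decreasing cur_lr. Outside the unit test, the same schedule can be exercised through Tune; a hedged sketch (stopping criterion and worker count here are arbitrary), with the decaying learning rate visible under info/learner/default_policy/cur_lr in each result:

import ray
from ray import tune

ray.init()
tune.run(
    "IMPALA",
    stop={"training_iteration": 2},
    config={
        "env": "CartPole-v0",
        "num_workers": 1,
        "lr_schedule": [
            [0, 0.0005],
            [10000, 0.000001],
        ],
    },
)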

rllib/policy/torch_policy.py

Lines changed: 54 additions & 45 deletions
@@ -40,23 +40,22 @@ class TorchPolicy(Policy):
     """
 
     @DeveloperAPI
-    def __init__(self,
-                 observation_space: gym.spaces.Space,
-                 action_space: gym.spaces.Space,
-                 config: TrainerConfigDict,
-                 *,
-                 model: ModelV2,
-                 loss: Callable[
-                     [Policy, ModelV2, type, SampleBatch], TensorType],
-                 action_distribution_class: TorchDistributionWrapper,
-                 action_sampler_fn: Callable[
-                     [TensorType, List[TensorType]], Tuple[
-                         TensorType, TensorType]] = None,
-                 action_distribution_fn: Optional[Callable[
-                     [Policy, ModelV2, TensorType, TensorType, TensorType],
-                     Tuple[TensorType, type, List[TensorType]]]] = None,
-                 max_seq_len: int = 20,
-                 get_batch_divisibility_req: Optional[int] = None):
+    def __init__(
+            self,
+            observation_space: gym.spaces.Space,
+            action_space: gym.spaces.Space,
+            config: TrainerConfigDict,
+            *,
+            model: ModelV2,
+            loss: Callable[[Policy, ModelV2, type, SampleBatch], TensorType],
+            action_distribution_class: TorchDistributionWrapper,
+            action_sampler_fn: Callable[[TensorType, List[TensorType]], Tuple[
+                TensorType, TensorType]] = None,
+            action_distribution_fn: Optional[Callable[[
+                Policy, ModelV2, TensorType, TensorType, TensorType
+            ], Tuple[TensorType, type, List[TensorType]]]] = None,
+            max_seq_len: int = 20,
+            get_batch_divisibility_req: Optional[int] = None):
         """Build a policy from policy and loss torch modules.
 
         Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES
@@ -165,8 +164,8 @@ def compute_actions(
             extra_fetches[SampleBatch.ACTION_PROB] = np.exp(logp)
             extra_fetches[SampleBatch.ACTION_LOGP] = logp
 
-            return convert_to_non_torch_type(
-                (actions, state_out, extra_fetches))
+            return convert_to_non_torch_type((actions, state_out,
+                                               extra_fetches))
 
     @override(Policy)
     def compute_actions_from_trajectories(
@@ -183,8 +182,9 @@ def compute_actions_from_trajectories(
 
         with torch.no_grad():
             # Create a view and pass that to Model as `input_dict`.
-            input_dict = self._lazy_tensor_dict(get_trajectory_view(
-                self.model, trajectories, is_training=False))
+            input_dict = self._lazy_tensor_dict(
+                get_trajectory_view(
+                    self.model, trajectories, is_training=False))
             # TODO: (sven) support RNNs w/ fast sampling.
             state_batches = []
             seq_lens = None
@@ -232,8 +232,8 @@ def _compute_action_helper(self, input_dict, state_batches, seq_lens,
                     is_training=False)
             else:
                 dist_class = self.dist_class
-                dist_inputs, state_out = self.model(
-                    input_dict, state_batches, seq_lens)
+                dist_inputs, state_out = self.model(input_dict, state_batches,
+                                                    seq_lens)
 
             if not (isinstance(dist_class, functools.partial)
                     or issubclass(dist_class, TorchDistributionWrapper)):
@@ -270,10 +270,10 @@ def compute_log_likelihoods(
             actions: Union[List[TensorType], TensorType],
             obs_batch: Union[List[TensorType], TensorType],
             state_batches: Optional[List[TensorType]] = None,
-            prev_action_batch: Optional[
-                Union[List[TensorType], TensorType]] = None,
-            prev_reward_batch: Optional[
-                Union[List[TensorType], TensorType]] = None) -> TensorType:
+            prev_action_batch: Optional[Union[List[TensorType],
+                                              TensorType]] = None,
+            prev_reward_batch: Optional[Union[List[
+                TensorType], TensorType]] = None) -> TensorType:
 
         if self.action_sampler_fn and self.action_distribution_fn is None:
             raise ValueError("Cannot compute log-prob/likelihood w/o an "
@@ -314,8 +314,8 @@ def compute_log_likelihoods(
 
     @override(Policy)
     @DeveloperAPI
-    def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[
-            str, TensorType]:
+    def learn_on_batch(
+            self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
         # Get batch ready for RNNs, if applicable.
         pad_batch_to_sequences_of_same_size(
             postprocessed_batch,
@@ -331,6 +331,7 @@ def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[
             loss_out = self.model.custom_loss(loss_out, train_batch)
         assert len(loss_out) == len(self._optimizers)
         # assert not any(torch.isnan(l) for l in loss_out)
+        fetches = self.extra_compute_grad_fetches()
 
         # Loop through all optimizers.
         grad_info = {"allreduce_latency": 0.0}
@@ -370,7 +371,7 @@ def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[
 
         grad_info["allreduce_latency"] /= len(self._optimizers)
         grad_info.update(self.extra_grad_info(train_batch))
-        return {LEARNER_STATS_KEY: grad_info}
+        return dict(fetches, **{LEARNER_STATS_KEY: grad_info})
 
     @override(Policy)
     @DeveloperAPI
@@ -380,6 +381,7 @@ def compute_gradients(self,
         loss_out = force_list(
             self._loss(self, self.model, self.dist_class, train_batch))
         assert len(loss_out) == len(self._optimizers)
+        fetches = self.extra_compute_grad_fetches()
 
         grad_process_info = {}
         grads = []
@@ -399,7 +401,7 @@
 
         grad_info = self.extra_grad_info(train_batch)
         grad_info.update(grad_process_info)
-        return grads, {LEARNER_STATS_KEY: grad_info}
+        return grads, dict(fetches, **{LEARNER_STATS_KEY: grad_info})
 
     @override(Policy)
     @DeveloperAPI
@@ -466,10 +468,8 @@ def set_state(self, state: object) -> None:
         super().set_state(state)
 
     @DeveloperAPI
-    def extra_grad_process(
-            self,
-            optimizer: "torch.optim.Optimizer",
-            loss: TensorType):
+    def extra_grad_process(self, optimizer: "torch.optim.Optimizer",
+                           loss: TensorType):
         """Called after each optimizer.zero_grad() + loss.backward() call.
 
         Called for each self._optimizers/loss-value pair.
@@ -486,12 +486,20 @@ def extra_grad_process(
         """
         return {}
 
+    @DeveloperAPI
+    def extra_compute_grad_fetches(self) -> Dict[str, any]:
+        """Extra values to fetch and return from compute_gradients().
+
+        Returns:
+            Dict[str, any]: Extra fetch dict to be added to the fetch dict
+                of the compute_gradients call.
+        """
+        return {LEARNER_STATS_KEY: {}}  # e.g, stats, td error, etc.
+
     @DeveloperAPI
     def extra_action_out(
-            self,
-            input_dict: Dict[str, TensorType],
-            state_batches: List[TensorType],
-            model: TorchModelV2,
+            self, input_dict: Dict[str, TensorType],
+            state_batches: List[TensorType], model: TorchModelV2,
             action_dist: TorchDistributionWrapper) -> Dict[str, TensorType]:
         """Returns dict of extra info to include in experience batch.
 
@@ -509,8 +517,8 @@ def extra_action_out(
         return {}
 
     @DeveloperAPI
-    def extra_grad_info(self, train_batch: SampleBatch) -> Dict[
-            str, TensorType]:
+    def extra_grad_info(self,
+                        train_batch: SampleBatch) -> Dict[str, TensorType]:
        """Return dict of extra grad info.
 
         Args:
@@ -524,8 +532,9 @@ def extra_grad_info(self, train_batch: SampleBatch) -> Dict[
         return {}
 
     @DeveloperAPI
-    def optimizer(self) -> Union[
-            List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
+    def optimizer(
+            self
+    ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
         """Custom the local PyTorch optimizer(s) to use.
 
         Returns:
@@ -560,8 +569,8 @@ def import_model_from_h5(self, import_file: str) -> None:
 
     def _lazy_tensor_dict(self, postprocessed_batch):
         train_batch = UsageTrackingDict(postprocessed_batch)
-        train_batch.set_get_interceptor(functools.partial(
-            convert_to_torch_tensor, device=self.device))
+        train_batch.set_get_interceptor(
+            functools.partial(convert_to_torch_tensor, device=self.device))
         return train_batch
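The new extra_compute_grad_fetches() hook is the generic mechanism behind the DQN change: whatever it returns is merged into the fetch dict of learn_on_batch() and compute_gradients(), alongside (not inside) LEARNER_STATS_KEY. A minimal sketch of overriding it in a hand-written subclass; MyTorchPolicy and self.td_error are assumptions, and the convert_to_non_torch_type import path is taken from the surrounding code:

from ray.rllib.policy.policy import LEARNER_STATS_KEY
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils.torch_ops import convert_to_non_torch_type


class MyTorchPolicy(TorchPolicy):
    def extra_compute_grad_fetches(self):
        # Keep the (possibly empty) learner-stats entry and add our own
        # top-level fetch; torch tensors are converted to numpy first.
        return dict(
            {LEARNER_STATS_KEY: {}},
            **convert_to_non_torch_type({"td_error": self.td_error}))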

rllib/policy/torch_policy_template.py

Lines changed: 14 additions & 1 deletion
@@ -1,4 +1,4 @@
-from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
 from ray.rllib.policy.torch_policy import TorchPolicy
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
@@ -19,6 +19,7 @@ def build_torch_policy(name,
                        postprocess_fn=None,
                        extra_action_out_fn=None,
                        extra_grad_process_fn=None,
+                       extra_learn_fetches_fn=None,
                        optimizer_fn=None,
                        validate_spaces=None,
                        before_init=None,
@@ -47,6 +48,8 @@ def build_torch_policy(name,
             returns a dict of extra values to include in experiences.
         extra_grad_process_fn (Optional[callable]): Optional callable that is
             called after gradients are computed and returns processing info.
+        extra_learn_fetches_fn (func): optional function that returns a dict of
+            extra values to fetch from the policy after loss evaluation.
         optimizer_fn (Optional[callable]): Optional callable that returns a
             torch optimizer given the policy and config.
         validate_spaces (Optional[callable]): Optional callable that takes the
@@ -179,6 +182,16 @@ def extra_grad_process(self, optimizer, loss):
             else:
                 return TorchPolicy.extra_grad_process(self, optimizer, loss)
 
+        @override(TorchPolicy)
+        def extra_compute_grad_fetches(self):
+            if extra_learn_fetches_fn:
+                fetches = convert_to_non_torch_type(
+                    extra_learn_fetches_fn(self))
+                # Auto-add empty learner stats dict if needed.
+                return dict({LEARNER_STATS_KEY: {}}, **fetches)
+            else:
+                return TorchPolicy.extra_compute_grad_fetches(self)
+
         @override(TorchPolicy)
         def apply_gradients(self, gradients):
             if apply_gradients_fn:
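For template-built policies the hook is wired up via the new extra_learn_fetches_fn argument, exactly as the DQN and SimpleQ policies above do. A hedged sketch with a toy loss; my_loss_fn and its fake "TD error" are placeholders, not part of the commit:

import torch

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy


def my_loss_fn(policy, model, dist_class, train_batch):
    # Toy loss: treat the gap between the mean model output and the reward
    # as a stand-in "TD error" and stash it for extra_learn_fetches_fn.
    logits, _ = model.from_batch(train_batch)
    policy.td_error = torch.mean(logits, 1) - train_batch[SampleBatch.REWARDS]
    return torch.mean(policy.td_error ** 2)


MyTorchPolicy = build_torch_policy(
    name="MyTorchPolicy",
    loss_fn=my_loss_fn,
    extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error},
)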
