[RLlib] Add grad_clip config option to MARWIL and stabilize grad clipping against inf global_norms. (ray-project#13634)
sven1977 authored Jan 22, 2021
1 parent da59283 commit d629292
Showing 4 changed files with 15 additions and 4 deletions.
2 changes: 2 additions & 0 deletions rllib/agents/marwil/marwil.py
@@ -21,6 +21,8 @@
     "beta": 1.0,
     # Balancing value estimation loss and policy optimization loss.
     "vf_coeff": 1.0,
+    # If specified, clip the global norm of gradients by this amount.
+    "grad_clip": None,
     # Whether to calculate cumulative rewards.
     "postprocess_inputs": True,
     # Whether to rollout "complete_episodes" or "truncate_episodes".
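With the option in DEFAULT_CONFIG, gradient clipping can be switched on from the trainer config. A minimal usage sketch, assuming the Ray 1.x-era MARWILTrainer API; the offline-data path and the clip value of 40.0 below are placeholders, not part of this commit:

# Minimal sketch (not part of this commit): enabling grad_clip for MARWIL.
# Assumes the Ray 1.x MARWILTrainer API; the "input" path is a placeholder.
import ray
from ray.rllib.agents.marwil import MARWILTrainer

ray.init()
trainer = MARWILTrainer(config={
    "env": "CartPole-v0",
    "input": "/tmp/cartpole-out",  # placeholder path to offline JSON data
    "beta": 1.0,                   # weighting of the exponential advantage term
    "grad_clip": 40.0,             # new: clip the global gradient norm at 40
})
print(trainer.train())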
4 changes: 3 additions & 1 deletion rllib/agents/marwil/marwil_tf_policy.py
@@ -1,6 +1,7 @@
 import logging
 
 import ray
+from ray.rllib.agents.ppo.ppo_tf_policy import compute_and_clip_gradients
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.evaluation.postprocessing import compute_advantages, \
     Postprocessing
@@ -133,7 +134,7 @@ def __init__(self, policy, value_estimates, action_dist, actions,
 
             # Exponentially weighted advantages.
             c = tf.math.sqrt(policy._moving_average_sqd_adv_norm)
-            exp_advs = tf.math.exp(beta * (adv / c))
+            exp_advs = tf.math.exp(beta * (adv / (1e-8 + c)))
         # Static graph.
         else:
             update_adv_norm = tf1.assign_add(
@@ -200,4 +201,5 @@ def setup_mixins(policy, obs_space, action_space, config):
     stats_fn=stats,
     postprocess_fn=postprocess_advantages,
     before_loss_init=setup_mixins,
+    gradients_fn=compute_and_clip_gradients,
     mixins=[ValueNetworkMixin])
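Two things change in the TF policy: the exponentially weighted advantages gain a 1e-8 term that guards the division against a zero (or collapsed) moving-average advantage norm, and the policy now routes its gradients through PPO's compute_and_clip_gradients. A standalone toy sketch of the failure mode the epsilon prevents (TF2 eager, made-up numbers, not RLlib code):

# Toy sketch (not RLlib code): with c == 0, a zero advantage yields 0/0 = NaN,
# which poisons exp_advs and therefore the whole loss; the 1e-8 term avoids it.
import tensorflow as tf

adv = tf.constant(0.0)
c = tf.constant(0.0)  # e.g. sqrt of a collapsed moving-average sqd adv norm
beta = 1.0

print(tf.math.exp(beta * (adv / c)).numpy())            # nan
print(tf.math.exp(beta * (adv / (1e-8 + c))).numpy())   # 1.0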
3 changes: 2 additions & 1 deletion rllib/agents/marwil/marwil_torch_policy.py
@@ -4,7 +4,7 @@
 from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.utils.framework import try_import_torch
-from ray.rllib.utils.torch_ops import explained_variance
+from ray.rllib.utils.torch_ops import apply_grad_clipping, explained_variance
 
 torch, _ = try_import_torch()
 
@@ -98,5 +98,6 @@ def setup_mixins(policy, obs_space, action_space, config):
     get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG,
     stats_fn=stats,
     postprocess_fn=postprocess_advantages,
+    extra_grad_process_fn=apply_grad_clipping,
     before_loss_init=setup_mixins,
     mixins=[ValueNetworkMixin])
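On the torch side, grad_clip takes effect through RLlib's apply_grad_clipping helper, registered here as extra_grad_process_fn. A standalone sketch of the underlying idea in plain PyTorch (not the RLlib helper itself; the model, data, and 40.0 value are made up):

# Plain-PyTorch sketch (not the RLlib helper): global-norm gradient clipping
# of the kind enabled here via extra_grad_process_fn=apply_grad_clipping.
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()

grad_clip = 40.0  # would come from config["grad_clip"]
total_norm = nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
optimizer.step()
print("global grad norm before clipping:", float(total_norm))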
10 changes: 8 additions & 2 deletions rllib/agents/ppo/ppo_tf_policy.py
@@ -182,9 +182,15 @@ def compute_and_clip_gradients(policy: Policy, optimizer: LocalOptimizer,
 
     # Clip by global norm, if necessary.
     if policy.config["grad_clip"] is not None:
+        # Defuse inf gradients (due to super large losses).
         grads = [g for (g, v) in grads_and_vars]
-        policy.grads, _ = tf.clip_by_global_norm(grads,
-                                                 policy.config["grad_clip"])
+        grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"])
+        # If the global_norm is inf -> All grads will be NaN. Stabilize this
+        # here by setting them to 0.0. This will simply ignore destructive loss
+        # calculations.
+        policy.grads = [
+            tf.where(tf.math.is_nan(g), tf.zeros_like(g), g) for g in grads
+        ]
         clipped_grads_and_vars = list(zip(policy.grads, variables))
         return clipped_grads_and_vars
     else:
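The tf.where guard matters because a single inf gradient makes the global norm inf, and tf.clip_by_global_norm is documented to return all entries as NaN in that case; the guard resets them to 0.0 so the destructive update is simply skipped. A standalone toy sketch of the defusing step (TF2 eager, made-up tensors, not RLlib code):

# Toy sketch (not RLlib code): one inf gradient -> inf global norm -> all
# clipped gradients come back NaN; the tf.where pass resets them to zero.
import tensorflow as tf

grads = [tf.constant([float("inf"), 1.0]), tf.constant([2.0, 3.0])]
clipped, global_norm = tf.clip_by_global_norm(grads, 40.0)
print(global_norm.numpy())            # inf
print([g.numpy() for g in clipped])   # all NaN

defused = [tf.where(tf.math.is_nan(g), tf.zeros_like(g), g) for g in clipped]
print([g.numpy() for g in defused])   # all zeros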
