[RLlib] A2/3C policy sub-classing schema. (ray-project#25078)
sven1977 authored May 28, 2022
1 parent 009df65 commit ab6c302
Showing 7 changed files with 335 additions and 319 deletions.

Summary: the A2C/A3C policies move from the function-built pattern (build_tf_policy / build_policy_class) to explicit sub-classes. A shared get_a3c_tf_policy(base) builder produces A3CStaticGraphTFPolicy (from DynamicTFPolicyV2) and A3CEagerTFPolicy (from EagerTFPolicyV2), the Torch side becomes a directly sub-classed A3CTorchPolicy (from TorchPolicyV2), and a new TFPolicyV2Type alias in rllib/utils/typing.py types these builders, following the schema already used by PPO's get_ppo_tf_policy.
9 changes: 7 additions & 2 deletions rllib/agents/a3c/a3c.py
@@ -2,7 +2,6 @@
from typing import Any, Dict, List, Optional, Type, Union

from ray.actor import ActorHandle
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.evaluation.rollout_worker import RolloutWorker
@@ -182,8 +181,14 @@ def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
from ray.rllib.agents.a3c.a3c_torch_policy import A3CTorchPolicy

return A3CTorchPolicy
elif config["framework"] == "tf":
from ray.rllib.agents.a3c.a3c_tf_policy import A3CStaticGraphTFPolicy

return A3CStaticGraphTFPolicy
else:
return A3CTFPolicy
from ray.rllib.agents.a3c.a3c_tf_policy import A3CEagerTFPolicy

return A3CEagerTFPolicy

def training_iteration(self) -> ResultDict:
# Shortcut.
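
The net effect of this hunk: A3CTFPolicy is no longer imported at module load time; each framework branch now imports its policy class lazily. A condensed, standalone sketch of the new dispatch (the wrapper function is illustrative, not RLlib API; the imported class names are the ones this commit introduces):

# Illustrative standalone version of the per-framework dispatch above.
def resolve_a3c_policy_class(config: dict) -> type:
    if config["framework"] == "torch":
        from ray.rllib.agents.a3c.a3c_torch_policy import A3CTorchPolicy

        return A3CTorchPolicy
    elif config["framework"] == "tf":
        # Static-graph (TF1-style) execution.
        from ray.rllib.agents.a3c.a3c_tf_policy import A3CStaticGraphTFPolicy

        return A3CStaticGraphTFPolicy
    else:
        # Eager execution ("tf2"/"tfe").
        from ray.rllib.agents.a3c.a3c_tf_policy import A3CEagerTFPolicy

        return A3CEagerTFPolicy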
314 changes: 165 additions & 149 deletions rllib/agents/a3c/a3c_tf_policy.py
@@ -1,178 +1,194 @@
"""Note: Keep in sync with changes to VTraceTFPolicy."""
from typing import Optional, Dict
import gym
from typing import Dict, List, Optional, Type, Union

import ray
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import (
compute_gae_for_sample_batch,
Postprocessing,
)
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_mixins import (
compute_gradients,
EntropyCoeffSchedule,
LearningRateSchedule,
ValueNetworkMixin,
)
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_utils import explained_variance
from ray.rllib.utils.typing import (
TrainerConfigDict,
TensorType,
PolicyID,
AgentID,
LocalOptimizer,
ModelGradients,
TensorType,
TFPolicyV2Type,
)

tf1, tf, tfv = try_import_tf()


# We need this builder function because we want to share the same
# custom logic between TF1 dynamic and TF2 eager policies.
def get_a3c_tf_policy(base: TFPolicyV2Type) -> TFPolicyV2Type:
    """Construct an A3CTFPolicy inheriting either dynamic or eager base policies.

    Args:
        base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2.

    Returns:
        A TF Policy to be used with A3CTrainer.
    """

class A3CTFPolicy(
ValueNetworkMixin, LearningRateSchedule, EntropyCoeffSchedule, base
):
def __init__(
self,
obs_space,
action_space,
config,
existing_model=None,
existing_inputs=None,
):
# First things first: enable eager execution if necessary.
base.enable_eager_execution_if_necessary()

config = dict(ray.rllib.agents.a3c.a3c.A3CConfig().to_dict(), **config)

# Initialize base class.
base.__init__(
self,
obs_space,
action_space,
config,
existing_inputs=existing_inputs,
existing_model=existing_model,
)

ValueNetworkMixin.__init__(self, self.config)
LearningRateSchedule.__init__(
self, self.config["lr"], self.config["lr_schedule"]
)
EntropyCoeffSchedule.__init__(
self, config["entropy_coeff"], config["entropy_coeff_schedule"]
)

# Note: this is a bit ugly, but loss and optimizer initialization must
# happen after all the MixIns are initialized.
self.maybe_initialize_optimizer_and_loss()

@override(base)
def loss(
self,
model: Union[ModelV2, "tf.keras.Model"],
dist_class: Type[TFActionDistribution],
train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:

model_out, _ = model(train_batch)
action_dist = dist_class(model_out, model)
if self.is_recurrent():
max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
valid_mask = tf.sequence_mask(
train_batch[SampleBatch.SEQ_LENS], max_seq_len
)
valid_mask = tf.reshape(valid_mask, [-1])
else:
valid_mask = tf.ones_like(train_batch[SampleBatch.REWARDS])

log_prob = action_dist.logp(train_batch[SampleBatch.ACTIONS])
vf = model.value_function()

# The "policy gradients" loss
self.pi_loss = -tf.reduce_sum(
tf.boolean_mask(
log_prob * train_batch[Postprocessing.ADVANTAGES], valid_mask
)
)

delta = tf.boolean_mask(
vf - train_batch[Postprocessing.VALUE_TARGETS], valid_mask
)

# Compute a value function loss.
if self.config.get("use_critic", True):
self.vf_loss = 0.5 * tf.reduce_sum(tf.math.square(delta))
# Ignore the value function.
else:
self.vf_loss = tf.constant(0.0)

self.entropy_loss = tf.reduce_sum(
tf.boolean_mask(action_dist.entropy(), valid_mask)
)

self.total_loss = (
self.pi_loss
+ self.vf_loss * self.config["vf_loss_coeff"]
- self.entropy_loss * self.entropy_coeff
)

return self.total_loss

@override(base)
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
return {
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"entropy_coeff": tf.cast(self.entropy_coeff, tf.float64),
"policy_loss": self.pi_loss,
"policy_entropy": self.entropy_loss,
"var_gnorm": tf.linalg.global_norm(
list(self.model.trainable_variables())
),
"vf_loss": self.vf_loss,
}

@override(base)
def grad_stats_fn(
self, train_batch: SampleBatch, grads: ModelGradients
) -> Dict[str, TensorType]:
return {
"grad_gnorm": tf.linalg.global_norm(grads),
"vf_explained_var": explained_variance(
train_batch[Postprocessing.VALUE_TARGETS],
self.model.value_function(),
),
}

@override(base)
def postprocess_trajectory(
self,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
episode: Optional[Episode] = None,
):
sample_batch = super().postprocess_trajectory(sample_batch)
return compute_gae_for_sample_batch(
self, sample_batch, other_agent_batches, episode
)

@override(base)
def compute_gradients_fn(
self, optimizer: LocalOptimizer, loss: TensorType
) -> ModelGradients:
return compute_gradients(self, optimizer, loss)

return A3CTFPolicy


A3CStaticGraphTFPolicy = get_a3c_tf_policy(DynamicTFPolicyV2)
A3CEagerTFPolicy = get_a3c_tf_policy(EagerTFPolicyV2)


@Deprecated(
old="rllib.agents.a3c.a3c_tf_policy.postprocess_advantages",
new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
error=False,
)
def postprocess_advantages(
policy: Policy,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
episode: Optional[Episode] = None,
) -> SampleBatch:

return compute_gae_for_sample_batch(
policy, sample_batch, other_agent_batches, episode
)


class A3CLoss:
def __init__(
self,
action_dist: ActionDistribution,
actions: TensorType,
advantages: TensorType,
v_target: TensorType,
vf: TensorType,
valid_mask: TensorType,
vf_loss_coeff: float = 0.5,
entropy_coeff: float = 0.01,
use_critic: bool = True,
):
log_prob = action_dist.logp(actions)

# The "policy gradients" loss
self.pi_loss = -tf.reduce_sum(
tf.boolean_mask(log_prob * advantages, valid_mask)
)

delta = tf.boolean_mask(vf - v_target, valid_mask)

# Compute a value function loss.
if use_critic:
self.vf_loss = 0.5 * tf.reduce_sum(tf.math.square(delta))
# Ignore the value function.
else:
self.vf_loss = tf.constant(0.0)

self.entropy = tf.reduce_sum(tf.boolean_mask(action_dist.entropy(), valid_mask))

self.total_loss = (
self.pi_loss + self.vf_loss * vf_loss_coeff - self.entropy * entropy_coeff
)


def actor_critic_loss(
policy: Policy,
model: ModelV2,
dist_class: ActionDistribution,
train_batch: SampleBatch,
) -> TensorType:
model_out, _ = model(train_batch)
action_dist = dist_class(model_out, model)
if policy.is_recurrent():
max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
mask = tf.reshape(mask, [-1])
else:
mask = tf.ones_like(train_batch[SampleBatch.REWARDS])
policy.loss = A3CLoss(
action_dist,
train_batch[SampleBatch.ACTIONS],
train_batch[Postprocessing.ADVANTAGES],
train_batch[Postprocessing.VALUE_TARGETS],
model.value_function(),
mask,
policy.config["vf_loss_coeff"],
policy.entropy_coeff,
policy.config.get("use_critic", True),
)
return policy.loss.total_loss


def add_value_function_fetch(policy: Policy) -> Dict[str, TensorType]:
return {SampleBatch.VF_PREDS: policy.model.value_function()}


def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
return {
"cur_lr": tf.cast(policy.cur_lr, tf.float64),
"entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64),
"policy_loss": policy.loss.pi_loss,
"policy_entropy": policy.loss.entropy,
"var_gnorm": tf.linalg.global_norm(list(policy.model.trainable_variables())),
"vf_loss": policy.loss.vf_loss,
}


def grad_stats(
policy: Policy, train_batch: SampleBatch, grads: ModelGradients
) -> Dict[str, TensorType]:
return {
"grad_gnorm": tf.linalg.global_norm(grads),
"vf_explained_var": explained_variance(
train_batch[Postprocessing.VALUE_TARGETS], policy.model.value_function()
),
}


def clip_gradients(
policy: Policy, optimizer: LocalOptimizer, loss: TensorType
) -> ModelGradients:
grads_and_vars = optimizer.compute_gradients(
loss, policy.model.trainable_variables()
)
grads = [g for (g, v) in grads_and_vars]
grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"])
clipped_grads = list(zip(grads, policy.model.trainable_variables()))
return clipped_grads


def setup_mixins(
policy: Policy,
obs_space: gym.spaces.Space,
action_space: gym.spaces.Space,
config: TrainerConfigDict,
) -> None:
ValueNetworkMixin.__init__(policy, config)
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
EntropyCoeffSchedule.__init__(
policy, config["entropy_coeff"], config["entropy_coeff_schedule"]
)


A3CTFPolicy = build_tf_policy(
    name="A3CTFPolicy",
    get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
    loss_fn=actor_critic_loss,
    stats_fn=stats,
    grad_stats_fn=grad_stats,
    compute_gradients_fn=clip_gradients,
    postprocess_fn=compute_gae_for_sample_batch,
    extra_action_out_fn=add_value_function_fetch,
    before_loss_init=setup_mixins,
    mixins=[ValueNetworkMixin, LearningRateSchedule, EntropyCoeffSchedule],
)


@Deprecated(
    old="rllib.agents.a3c.a3c_tf_policy.postprocess_advantages",
    new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
    error=True,
)
def postprocess_advantages(*args, **kwargs):
    pass
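
The get_a3c_tf_policy builder above is invoked twice, once per TF execution mode, so the static-graph and eager classes share a single loss/stats implementation. A toy, RLlib-free sketch of that pattern (all names here are hypothetical):

# Toy illustration of the "builder returns a subclass" schema used above.
def make_policy(base: type) -> type:
    class SharedLogicPolicy(base):
        def loss(self):
            # Shared between both base classes.
            return "shared loss logic"

    return SharedLogicPolicy


class StaticGraphBase:
    pass


class EagerBase:
    pass


StaticPolicy = make_policy(StaticGraphBase)
EagerPolicy = make_policy(EagerBase)
assert StaticPolicy().loss() == EagerPolicy().loss()

The same schema is applied to PPO in this commit via the retyped get_ppo_tf_policy further below.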
301 changes: 139 additions & 162 deletions rllib/agents/a3c/a3c_torch_policy.py
@@ -1,188 +1,165 @@
import gym
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Type, Union

import ray
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import (
compute_gae_for_sample_batch,
Postprocessing,
)
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_mixins import (
EntropyCoeffSchedule,
LearningRateSchedule,
ValueNetworkMixin,
)
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import apply_grad_clipping, sequence_mask
from ray.rllib.utils.typing import (
TrainerConfigDict,
TensorType,
PolicyID,
LocalOptimizer,
)
from ray.rllib.utils.typing import AgentID, TensorType

torch, nn = try_import_torch()


Removed (old function-based implementation):

@Deprecated(
    old="rllib.agents.a3c.a3c_torch_policy.add_advantages",
    new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
    error=False,
)
def add_advantages(
    policy: Policy,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
    episode: Optional[Episode] = None,
) -> SampleBatch:
    return compute_gae_for_sample_batch(
        policy, sample_batch, other_agent_batches, episode
    )


def actor_critic_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: ActionDistribution,
    train_batch: SampleBatch,
) -> TensorType:
    logits, _ = model(train_batch)
    values = model.value_function()

    if policy.is_recurrent():
        B = len(train_batch[SampleBatch.SEQ_LENS])
        max_seq_len = logits.shape[0] // B
        mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
        valid_mask = torch.reshape(mask_orig, [-1])
    else:
        valid_mask = torch.ones_like(values, dtype=torch.bool)

    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)
    pi_err = -torch.sum(
        torch.masked_select(
            log_probs * train_batch[Postprocessing.ADVANTAGES], valid_mask
        )
    )

    # Compute a value function loss.
    if policy.config["use_critic"]:
        value_err = 0.5 * torch.sum(
            torch.pow(
                torch.masked_select(
                    values.reshape(-1) - train_batch[Postprocessing.VALUE_TARGETS],
                    valid_mask,
                ),
                2.0,
            )
        )
    # Ignore the value function.
    else:
        value_err = 0.0

    entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask))

    total_loss = (
        pi_err
        + value_err * policy.config["vf_loss_coeff"]
        - entropy * policy.entropy_coeff
    )

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["entropy"] = entropy
    model.tower_stats["pi_err"] = pi_err
    model.tower_stats["value_err"] = value_err

    return total_loss


def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    return {
        "cur_lr": policy.cur_lr,
        "entropy_coeff": policy.entropy_coeff,
        "policy_entropy": torch.mean(torch.stack(policy.get_tower_stats("entropy"))),
        "policy_loss": torch.mean(torch.stack(policy.get_tower_stats("pi_err"))),
        "vf_loss": torch.mean(torch.stack(policy.get_tower_stats("value_err"))),
    }


def vf_preds_fetches(
    policy: Policy,
    input_dict: Dict[str, TensorType],
    state_batches: List[TensorType],
    model: ModelV2,
    action_dist: TorchDistributionWrapper,
) -> Dict[str, TensorType]:
    """Defines extra fetches per action computation.

    Args:
        policy (Policy): The Policy to perform the extra action fetch on.
        input_dict (Dict[str, TensorType]): The input dict used for the action
            computing forward pass.
        state_batches (List[TensorType]): List of state tensors (empty for
            non-RNNs).
        model (ModelV2): The Model object of the Policy.
        action_dist (TorchDistributionWrapper): The instantiated distribution
            object, resulting from the model's outputs and the given
            distribution class.

    Returns:
        Dict[str, TensorType]: Dict with extra fetches to perform per
            action computation.
    """
    # Return value function outputs. VF estimates will hence be added to the
    # SampleBatches produced by the sampler(s) to generate the train batches
    # going into the loss function.
    return {
        SampleBatch.VF_PREDS: model.value_function(),
    }


def torch_optimizer(policy: Policy, config: TrainerConfigDict) -> LocalOptimizer:
    return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])


def setup_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> None:
    """Call all mixin classes' constructors before A3CTorchPolicy initialization.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    EntropyCoeffSchedule.__init__(
        policy, config["entropy_coeff"], config["entropy_coeff_schedule"]
    )
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
    ValueNetworkMixin.__init__(policy, config)


A3CTorchPolicy = build_policy_class(
    name="A3CTorchPolicy",
    framework="torch",
    get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
    loss_fn=actor_critic_loss,
    stats_fn=stats,
    postprocess_fn=compute_gae_for_sample_batch,
    extra_action_out_fn=vf_preds_fetches,
    extra_grad_process_fn=apply_grad_clipping,
    optimizer_fn=torch_optimizer,
    before_loss_init=setup_mixins,
    mixins=[ValueNetworkMixin, LearningRateSchedule, EntropyCoeffSchedule],
)

Added (new sub-classing implementation):

class A3CTorchPolicy(
    ValueNetworkMixin, LearningRateSchedule, EntropyCoeffSchedule, TorchPolicyV2
):
    """PyTorch Policy class used with A3CTrainer."""

    def __init__(self, observation_space, action_space, config):
        config = dict(ray.rllib.agents.a3c.a3c.A3CConfig().to_dict(), **config)

        TorchPolicyV2.__init__(
            self,
            observation_space,
            action_space,
            config,
            max_seq_len=config["model"]["max_seq_len"],
        )

        ValueNetworkMixin.__init__(self, config)
        LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"])
        EntropyCoeffSchedule.__init__(
            self, config["entropy_coeff"], config["entropy_coeff_schedule"]
        )

        # TODO: Don't require users to call this manually.
        self._initialize_loss_from_dummy_batch()

    @override(TorchPolicyV2)
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        """Constructs the loss function.

        Args:
            model: The Model to calculate the loss for.
            dist_class: The action distr. class.
            train_batch: The training data.

        Returns:
            The A3C loss tensor given the input batch.
        """
        logits, _ = model(train_batch)
        values = model.value_function()

        if self.is_recurrent():
            B = len(train_batch[SampleBatch.SEQ_LENS])
            max_seq_len = logits.shape[0] // B
            mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
            valid_mask = torch.reshape(mask_orig, [-1])
        else:
            valid_mask = torch.ones_like(values, dtype=torch.bool)

        dist = dist_class(logits, model)
        log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)
        pi_err = -torch.sum(
            torch.masked_select(
                log_probs * train_batch[Postprocessing.ADVANTAGES], valid_mask
            )
        )

        # Compute a value function loss.
        if self.config["use_critic"]:
            value_err = 0.5 * torch.sum(
                torch.pow(
                    torch.masked_select(
                        values.reshape(-1) - train_batch[Postprocessing.VALUE_TARGETS],
                        valid_mask,
                    ),
                    2.0,
                )
            )
        # Ignore the value function.
        else:
            value_err = 0.0

        entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask))

        total_loss = (
            pi_err
            + value_err * self.config["vf_loss_coeff"]
            - entropy * self.entropy_coeff
        )

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["entropy"] = entropy
        model.tower_stats["pi_err"] = pi_err
        model.tower_stats["value_err"] = value_err

        return total_loss

    @override(TorchPolicyV2)
    def optimizer(
        self,
    ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
        """Returns a torch optimizer (Adam) for A3C."""
        return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])

    @override(TorchPolicyV2)
    def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
        return convert_to_numpy(
            {
                "cur_lr": self.cur_lr,
                "entropy_coeff": self.entropy_coeff,
                "policy_entropy": torch.mean(
                    torch.stack(self.get_tower_stats("entropy"))
                ),
                "policy_loss": torch.mean(torch.stack(self.get_tower_stats("pi_err"))),
                "vf_loss": torch.mean(torch.stack(self.get_tower_stats("value_err"))),
            }
        )

    @override(TorchPolicyV2)
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
        episode: Optional[Episode] = None,
    ):
        sample_batch = super().postprocess_trajectory(sample_batch)
        return compute_gae_for_sample_batch(
            self, sample_batch, other_agent_batches, episode
        )

    @override(TorchPolicyV2)
    def extra_grad_process(
        self, optimizer: "torch.optim.Optimizer", loss: TensorType
    ) -> Dict[str, TensorType]:
        return apply_grad_clipping(self, optimizer, loss)


@Deprecated(
    old="rllib.agents.a3c.a3c_torch_policy.add_advantages",
    new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
    error=True,
)
def add_advantages(*args, **kwargs):
    pass
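
One subtlety the new __init__ encodes: the schedule mixins must be initialized before _initialize_loss_from_dummy_batch(), because the loss reads self.entropy_coeff (and stats read self.cur_lr). A toy illustration of that ordering constraint (the classes here are hypothetical stand-ins, not RLlib code):

# Hypothetical stand-ins showing why mixin __init__ calls precede loss init.
class EntropyCoeffScheduleStub:
    def __init__(self, entropy_coeff: float):
        self.entropy_coeff = entropy_coeff


class PolicyStub(EntropyCoeffScheduleStub):
    def __init__(self, config: dict):
        # Mixin first: this creates self.entropy_coeff ...
        EntropyCoeffScheduleStub.__init__(self, config["entropy_coeff"])
        # ... which the dummy-batch loss initialization then reads.
        self._initialize_loss_from_dummy_batch()

    def _initialize_loss_from_dummy_batch(self):
        # Would raise AttributeError if called before the mixin __init__.
        self.total_loss = -1.0 * self.entropy_coeff


PolicyStub({"entropy_coeff": 0.01})

The TF variant encodes the same constraint by calling maybe_initialize_optimizer_and_loss() only after all mixin constructors have run.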
12 changes: 9 additions & 3 deletions rllib/agents/a3c/tests/test_a2c.py
@@ -36,9 +36,14 @@ def test_a2c_compilation(self):
trainer.stop()

def test_a2c_exec_impl(self):
config = a3c.A2CConfig().reporting(min_time_s_per_reporting=0)
config = (
a3c.A2CConfig()
.environment(env="CartPole-v0")
.reporting(min_time_s_per_reporting=0)
)

for _ in framework_iterator(config):
trainer = a3c.A2CTrainer(env="CartPole-v0", config=config)
trainer = config.build()
results = trainer.train()
check_train_results(results)
print(results)
@@ -48,12 +53,13 @@ def test_a2c_exec_impl(self):
def test_a2c_exec_impl_microbatch(self):
config = (
a3c.A2CConfig()
.environment(env="CartPole-v0")
.reporting(min_time_s_per_reporting=0)
.training(microbatch_size=10)
)

for _ in framework_iterator(config):
trainer = config.build(env="CartPole-v0")
trainer = config.build()
results = trainer.train()
check_train_results(results)
print(results)
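
Both tests now pin the environment on the config object and build the trainer via config.build() instead of passing env= to the trainer constructor. A hedged end-to-end sketch of that usage (RLlib 1.x API as exercised by these tests; the ray.init() call is added for completeness):

import ray
import ray.rllib.agents.a3c as a3c

ray.init()
config = (
    a3c.A2CConfig()
    .environment(env="CartPole-v0")
    .reporting(min_time_s_per_reporting=0)
    .training(microbatch_size=10)  # micro-batch variant, as in the second test
)
trainer = config.build()
results = trainer.train()
trainer.stop()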
3 changes: 2 additions & 1 deletion rllib/agents/ppo/ppo_tf_policy.py
@@ -29,6 +29,7 @@
LocalOptimizer,
ModelGradients,
TensorType,
TFPolicyV2Type,
TrainerConfigDict,
)

@@ -52,7 +53,7 @@ def validate_config(config: TrainerConfigDict) -> None:

# We need this builder function because we want to share the same
# custom logic between TF1 dynamic and TF2 eager policies.
def get_ppo_tf_policy(base: type) -> type:
def get_ppo_tf_policy(base: TFPolicyV2Type) -> TFPolicyV2Type:
"""Construct a PPOTFPolicy inheriting either dynamic or eager base policies.
Args:
7 changes: 6 additions & 1 deletion rllib/policy/tf_mixins.py
@@ -342,7 +342,12 @@ def compute_gradients(
# If the global_norm is inf -> All grads will be NaN. Stabilize this
# here by setting them to 0.0. This will simply ignore destructive loss
# calculations.
policy.grads = [tf.where(tf.math.is_nan(g), tf.zeros_like(g), g) for g in grads]
policy.grads = []
for g in grads:
if g is not None:
policy.grads.append(tf.where(tf.math.is_nan(g), tf.zeros_like(g), g))
else:
policy.grads.append(None)
clipped_grads_and_vars = list(zip(policy.grads, variables))
return clipped_grads_and_vars
else:
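
Context for this guard: tf.GradientTape.gradient() (and tf.gradients) returns None for variables the loss does not depend on, and the old list comprehension would then crash inside tf.math.is_nan(None). A minimal standalone illustration in plain TF2 (not RLlib code):

import tensorflow as tf

x = tf.Variable(1.0)
unused = tf.Variable(2.0)

with tf.GradientTape() as tape:
    loss = 3.0 * x

grads = tape.gradient(loss, [x, unused])  # -> [tf.Tensor(3.0), None]

# Same NaN-stabilization as the diff above, but None-safe:
safe_grads = [
    tf.where(tf.math.is_nan(g), tf.zeros_like(g), g) if g is not None else None
    for g in grads
]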
8 changes: 7 additions & 1 deletion rllib/utils/typing.py
@@ -6,13 +6,16 @@
List,
Optional,
Tuple,
Union,
Type,
TypeVar,
TYPE_CHECKING,
Union,
)

if TYPE_CHECKING:
from ray.rllib.env.env_context import EnvContext
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.policy.view_requirement import ViewRequirement
@@ -80,6 +83,9 @@
# data (TensorStructType).
PolicyState = Dict[str, TensorStructType]

# Any tf Policy type (static-graph or eager Policy).
TFPolicyV2Type = Type[Union["DynamicTFPolicyV2", "EagerTFPolicyV2"]]

# Represents an episode id.
EpisodeID = int
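
The new alias gives builder functions such as get_ppo_tf_policy and get_a3c_tf_policy a precise signature: take a TF policy base class, return a subclass of it. A hedged usage sketch (the builder and class names below are hypothetical):

# Sketch: annotating a policy-builder function with the new alias.
from ray.rllib.utils.typing import TFPolicyV2Type


def get_my_tf_policy(base: TFPolicyV2Type) -> TFPolicyV2Type:
    class MyTFPolicy(base):  # shared logic for static-graph and eager modes
        pass

    return MyTFPolicy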
