[RLlib] A2/3C policy sub-classing schema. (ray-project#25078)
Showing 7 changed files with 335 additions and 319 deletions.
rllib/agents/a3c/a3c_tf_policy.py
@@ -1,178 +1,194 @@
"""Note: Keep in sync with changes to VTraceTFPolicy."""
from typing import Optional, Dict
import gym
from typing import Dict, List, Optional, Type, Union

import ray
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import (
    compute_gae_for_sample_batch,
    Postprocessing,
)
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_mixins import (
    compute_gradients,
    EntropyCoeffSchedule,
    LearningRateSchedule,
    ValueNetworkMixin,
)
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_utils import explained_variance
from ray.rllib.utils.typing import (
    TrainerConfigDict,
    TensorType,
    PolicyID,
    AgentID,
    LocalOptimizer,
    ModelGradients,
    TensorType,
    TFPolicyV2Type,
)

tf1, tf, tfv = try_import_tf()

# We need this builder function because we want to share the same
# custom logics between TF1 dynamic and TF2 eager policies.
def get_a3c_tf_policy(base: TFPolicyV2Type) -> TFPolicyV2Type:
    """Construct an A3CTFPolicy inheriting either dynamic or eager base policies.

    Args:
        base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2.

    Returns:
        A TF Policy to be used with A3CTrainer.
    """

    class A3CTFPolicy(
        ValueNetworkMixin, LearningRateSchedule, EntropyCoeffSchedule, base
    ):
        def __init__(
            self,
            obs_space,
            action_space,
            config,
            existing_model=None,
            existing_inputs=None,
        ):
            # First thing first, enable eager execution if necessary.
            base.enable_eager_execution_if_necessary()

            config = dict(ray.rllib.agents.a3c.a3c.A3CConfig().to_dict(), **config)

            # Initialize base class.
            base.__init__(
                self,
                obs_space,
                action_space,
                config,
                existing_inputs=existing_inputs,
                existing_model=existing_model,
            )

            ValueNetworkMixin.__init__(self, self.config)
            LearningRateSchedule.__init__(
                self, self.config["lr"], self.config["lr_schedule"]
            )
            EntropyCoeffSchedule.__init__(
                self, config["entropy_coeff"], config["entropy_coeff_schedule"]
            )

            # Note: this is a bit ugly, but loss and optimizer initialization must
            # happen after all the MixIns are initialized.
            self.maybe_initialize_optimizer_and_loss()

        @override(base)
        def loss(
            self,
            model: Union[ModelV2, "tf.keras.Model"],
            dist_class: Type[TFActionDistribution],
            train_batch: SampleBatch,
        ) -> Union[TensorType, List[TensorType]]:

            model_out, _ = model(train_batch)
            action_dist = dist_class(model_out, model)
            if self.is_recurrent():
                max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
                valid_mask = tf.sequence_mask(
                    train_batch[SampleBatch.SEQ_LENS], max_seq_len
                )
                valid_mask = tf.reshape(valid_mask, [-1])
            else:
                valid_mask = tf.ones_like(train_batch[SampleBatch.REWARDS])

            log_prob = action_dist.logp(train_batch[SampleBatch.ACTIONS])
            vf = model.value_function()

            # The "policy gradients" loss.
            self.pi_loss = -tf.reduce_sum(
                tf.boolean_mask(
                    log_prob * train_batch[Postprocessing.ADVANTAGES], valid_mask
                )
            )

            delta = tf.boolean_mask(
                vf - train_batch[Postprocessing.VALUE_TARGETS], valid_mask
            )

            # Compute a value function loss.
            if self.config.get("use_critic", True):
                self.vf_loss = 0.5 * tf.reduce_sum(tf.math.square(delta))
            # Ignore the value function.
            else:
                self.vf_loss = tf.constant(0.0)

            self.entropy_loss = tf.reduce_sum(
                tf.boolean_mask(action_dist.entropy(), valid_mask)
            )

            self.total_loss = (
                self.pi_loss
                + self.vf_loss * self.config["vf_loss_coeff"]
                - self.entropy_loss * self.entropy_coeff
            )

            return self.total_loss

        @override(base)
        def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
            return {
                "cur_lr": tf.cast(self.cur_lr, tf.float64),
                "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64),
                "policy_loss": self.pi_loss,
                "policy_entropy": self.entropy_loss,
                "var_gnorm": tf.linalg.global_norm(
                    list(self.model.trainable_variables())
                ),
                "vf_loss": self.vf_loss,
            }

        @override(base)
        def grad_stats_fn(
            self, train_batch: SampleBatch, grads: ModelGradients
        ) -> Dict[str, TensorType]:
            return {
                "grad_gnorm": tf.linalg.global_norm(grads),
                "vf_explained_var": explained_variance(
                    train_batch[Postprocessing.VALUE_TARGETS],
                    self.model.value_function(),
                ),
            }

        @override(base)
        def postprocess_trajectory(
            self,
            sample_batch: SampleBatch,
            other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
            episode: Optional[Episode] = None,
        ):
            sample_batch = super().postprocess_trajectory(sample_batch)
            return compute_gae_for_sample_batch(
                self, sample_batch, other_agent_batches, episode
            )

        @override(base)
        def compute_gradients_fn(
            self, optimizer: LocalOptimizer, loss: TensorType
        ) -> ModelGradients:
            return compute_gradients(self, optimizer, loss)

    return A3CTFPolicy


A3CStaticGraphTFPolicy = get_a3c_tf_policy(DynamicTFPolicyV2)
A3CEagerTFPolicy = get_a3c_tf_policy(EagerTFPolicyV2)

@Deprecated(
    old="rllib.agents.a3c.a3c_tf_policy.postprocess_advantages",
    new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
    error=False,
)
def postprocess_advantages(
    policy: Policy,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
    episode: Optional[Episode] = None,
) -> SampleBatch:

    return compute_gae_for_sample_batch(
        policy, sample_batch, other_agent_batches, episode
    )

class A3CLoss:
    def __init__(
        self,
        action_dist: ActionDistribution,
        actions: TensorType,
        advantages: TensorType,
        v_target: TensorType,
        vf: TensorType,
        valid_mask: TensorType,
        vf_loss_coeff: float = 0.5,
        entropy_coeff: float = 0.01,
        use_critic: bool = True,
    ):
        log_prob = action_dist.logp(actions)

        # The "policy gradients" loss.
        self.pi_loss = -tf.reduce_sum(
            tf.boolean_mask(log_prob * advantages, valid_mask)
        )

        delta = tf.boolean_mask(vf - v_target, valid_mask)

        # Compute a value function loss.
        if use_critic:
            self.vf_loss = 0.5 * tf.reduce_sum(tf.math.square(delta))
        # Ignore the value function.
        else:
            self.vf_loss = tf.constant(0.0)

        self.entropy = tf.reduce_sum(tf.boolean_mask(action_dist.entropy(), valid_mask))

        self.total_loss = (
            self.pi_loss + self.vf_loss * vf_loss_coeff - self.entropy * entropy_coeff
        )


def actor_critic_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: ActionDistribution,
    train_batch: SampleBatch,
) -> TensorType:
    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)
    if policy.is_recurrent():
        max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
        mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
        mask = tf.reshape(mask, [-1])
    else:
        mask = tf.ones_like(train_batch[SampleBatch.REWARDS])
    policy.loss = A3CLoss(
        action_dist,
        train_batch[SampleBatch.ACTIONS],
        train_batch[Postprocessing.ADVANTAGES],
        train_batch[Postprocessing.VALUE_TARGETS],
        model.value_function(),
        mask,
        policy.config["vf_loss_coeff"],
        policy.entropy_coeff,
        policy.config.get("use_critic", True),
    )
    return policy.loss.total_loss


def add_value_function_fetch(policy: Policy) -> Dict[str, TensorType]:
    return {SampleBatch.VF_PREDS: policy.model.value_function()}


def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    return {
        "cur_lr": tf.cast(policy.cur_lr, tf.float64),
        "entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64),
        "policy_loss": policy.loss.pi_loss,
        "policy_entropy": policy.loss.entropy,
        "var_gnorm": tf.linalg.global_norm(list(policy.model.trainable_variables())),
        "vf_loss": policy.loss.vf_loss,
    }


def grad_stats(
    policy: Policy, train_batch: SampleBatch, grads: ModelGradients
) -> Dict[str, TensorType]:
    return {
        "grad_gnorm": tf.linalg.global_norm(grads),
        "vf_explained_var": explained_variance(
            train_batch[Postprocessing.VALUE_TARGETS], policy.model.value_function()
        ),
    }


def clip_gradients(
    policy: Policy, optimizer: LocalOptimizer, loss: TensorType
) -> ModelGradients:
    grads_and_vars = optimizer.compute_gradients(
        loss, policy.model.trainable_variables()
    )
    grads = [g for (g, v) in grads_and_vars]
    grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"])
    clipped_grads = list(zip(grads, policy.model.trainable_variables()))
    return clipped_grads


def setup_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> None:
    ValueNetworkMixin.__init__(policy, config)
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
    EntropyCoeffSchedule.__init__(
        policy, config["entropy_coeff"], config["entropy_coeff_schedule"]
    )


A3CTFPolicy = build_tf_policy(
    name="A3CTFPolicy",
    get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
    loss_fn=actor_critic_loss,
    stats_fn=stats,
    grad_stats_fn=grad_stats,
    compute_gradients_fn=clip_gradients,
    postprocess_fn=compute_gae_for_sample_batch,
    extra_action_out_fn=add_value_function_fetch,
    before_loss_init=setup_mixins,
    mixins=[ValueNetworkMixin, LearningRateSchedule, EntropyCoeffSchedule],
)


@Deprecated(
    old="rllib.agents.a3c.a3c_tf_policy.postprocess_advantages",
    new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
    error=True,
)
def postprocess_advantages(*args, **kwargs):
    pass
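
The practical effect of the builder above is that the static-graph and eager A3C policies now share one class body, and further customization happens through ordinary subclassing instead of build_tf_policy callbacks. Below is a minimal, hypothetical sketch of that pattern; the import path is inferred from the deprecation string in this file, and the subclass, the L2 penalty, and its coefficient are illustrative assumptions, not part of this commit.

from ray.rllib.agents.a3c.a3c_tf_policy import A3CEagerTFPolicy
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()


class MyA3CTFPolicy(A3CEagerTFPolicy):
    """Example subclass adding an (illustrative) L2 penalty to the A3C loss."""

    def loss(self, model, dist_class, train_batch):
        # Reuse the standard A3C loss from the parent class.
        total_loss = super().loss(model, dist_class, train_batch)
        # Add a small L2 regularization term over the model's trainable
        # variables; the 1e-4 coefficient is arbitrary, for illustration only.
        l2 = 1e-4 * tf.add_n(
            [tf.reduce_sum(tf.square(v)) for v in model.trainable_variables()]
        )
        return total_loss + l2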
rllib/agents/a3c/a3c_torch_policy.py
@@ -1,188 +1,165 @@
import gym
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Type, Union

import ray
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.postprocessing import (
    compute_gae_for_sample_batch,
    Postprocessing,
)
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_mixins import (
    EntropyCoeffSchedule,
    LearningRateSchedule,
    ValueNetworkMixin,
)
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import apply_grad_clipping, sequence_mask
from ray.rllib.utils.typing import (
    TrainerConfigDict,
    TensorType,
    PolicyID,
    LocalOptimizer,
)
from ray.rllib.utils.typing import AgentID, TensorType

torch, nn = try_import_torch()

@Deprecated(
    old="rllib.agents.a3c.a3c_torch_policy.add_advantages",
    new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
    error=False,
)
def add_advantages(
    policy: Policy,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
    episode: Optional[Episode] = None,
) -> SampleBatch:

    return compute_gae_for_sample_batch(
        policy, sample_batch, other_agent_batches, episode
    )

def actor_critic_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: ActionDistribution,
    train_batch: SampleBatch,
) -> TensorType:
    logits, _ = model(train_batch)
    values = model.value_function()

    if policy.is_recurrent():
        B = len(train_batch[SampleBatch.SEQ_LENS])
        max_seq_len = logits.shape[0] // B
        mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
        valid_mask = torch.reshape(mask_orig, [-1])
    else:
        valid_mask = torch.ones_like(values, dtype=torch.bool)

    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)
    pi_err = -torch.sum(
        torch.masked_select(
            log_probs * train_batch[Postprocessing.ADVANTAGES], valid_mask
        )
    )

    # Compute a value function loss.
    if policy.config["use_critic"]:
        value_err = 0.5 * torch.sum(
            torch.pow(
                torch.masked_select(
                    values.reshape(-1) - train_batch[Postprocessing.VALUE_TARGETS],
                    valid_mask,
                ),
                2.0,
            )
        )
    # Ignore the value function.
    else:
        value_err = 0.0

    entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask))

    total_loss = (
        pi_err
        + value_err * policy.config["vf_loss_coeff"]
        - entropy * policy.entropy_coeff
    )

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["entropy"] = entropy
    model.tower_stats["pi_err"] = pi_err
    model.tower_stats["value_err"] = value_err

    return total_loss


def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:

    return {
        "cur_lr": policy.cur_lr,
        "entropy_coeff": policy.entropy_coeff,
        "policy_entropy": torch.mean(torch.stack(policy.get_tower_stats("entropy"))),
        "policy_loss": torch.mean(torch.stack(policy.get_tower_stats("pi_err"))),
        "vf_loss": torch.mean(torch.stack(policy.get_tower_stats("value_err"))),
    }


def vf_preds_fetches(
    policy: Policy,
    input_dict: Dict[str, TensorType],
    state_batches: List[TensorType],
    model: ModelV2,
    action_dist: TorchDistributionWrapper,
) -> Dict[str, TensorType]:
    """Defines extra fetches per action computation.

    Args:
        policy (Policy): The Policy to perform the extra action fetch on.
        input_dict (Dict[str, TensorType]): The input dict used for the action
            computing forward pass.
        state_batches (List[TensorType]): List of state tensors (empty for
            non-RNNs).
        model (ModelV2): The Model object of the Policy.
        action_dist (TorchDistributionWrapper): The instantiated distribution
            object, resulting from the model's outputs and the given
            distribution class.

    Returns:
        Dict[str, TensorType]: Dict with extra tf fetches to perform per
            action computation.
    """
    # Return value function outputs. VF estimates will hence be added to the
    # SampleBatches produced by the sampler(s) to generate the train batches
    # going into the loss function.
    return {
        SampleBatch.VF_PREDS: model.value_function(),
    }


def torch_optimizer(policy: Policy, config: TrainerConfigDict) -> LocalOptimizer:
    return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])


def setup_mixins(
    policy: Policy,
    obs_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: TrainerConfigDict,
) -> None:
    """Call all mixin classes' constructors before PPOPolicy initialization.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    EntropyCoeffSchedule.__init__(
        policy, config["entropy_coeff"], config["entropy_coeff_schedule"]
    )
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
    ValueNetworkMixin.__init__(policy, config)


A3CTorchPolicy = build_policy_class(
    name="A3CTorchPolicy",
    framework="torch",
    get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
    loss_fn=actor_critic_loss,
    stats_fn=stats,
    postprocess_fn=compute_gae_for_sample_batch,
    extra_action_out_fn=vf_preds_fetches,
    extra_grad_process_fn=apply_grad_clipping,
    optimizer_fn=torch_optimizer,
    before_loss_init=setup_mixins,
    mixins=[ValueNetworkMixin, LearningRateSchedule, EntropyCoeffSchedule],
)


class A3CTorchPolicy(
    ValueNetworkMixin, LearningRateSchedule, EntropyCoeffSchedule, TorchPolicyV2
):
    """PyTorch Policy class used with A3CTrainer."""

    def __init__(self, observation_space, action_space, config):
        config = dict(ray.rllib.agents.a3c.a3c.A3CConfig().to_dict(), **config)

        TorchPolicyV2.__init__(
            self,
            observation_space,
            action_space,
            config,
            max_seq_len=config["model"]["max_seq_len"],
        )

        ValueNetworkMixin.__init__(self, config)
        LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"])
        EntropyCoeffSchedule.__init__(
            self, config["entropy_coeff"], config["entropy_coeff_schedule"]
        )

        # TODO: Don't require users to call this manually.
        self._initialize_loss_from_dummy_batch()

    @override(TorchPolicyV2)
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        """Constructs the loss function.

        Args:
            model: The Model to calculate the loss for.
            dist_class: The action distr. class.
            train_batch: The training data.

        Returns:
            The A3C loss tensor given the input batch.
        """
        logits, _ = model(train_batch)
        values = model.value_function()

        if self.is_recurrent():
            B = len(train_batch[SampleBatch.SEQ_LENS])
            max_seq_len = logits.shape[0] // B
            mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
            valid_mask = torch.reshape(mask_orig, [-1])
        else:
            valid_mask = torch.ones_like(values, dtype=torch.bool)

        dist = dist_class(logits, model)
        log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)
        pi_err = -torch.sum(
            torch.masked_select(
                log_probs * train_batch[Postprocessing.ADVANTAGES], valid_mask
            )
        )

        # Compute a value function loss.
        if self.config["use_critic"]:
            value_err = 0.5 * torch.sum(
                torch.pow(
                    torch.masked_select(
                        values.reshape(-1) - train_batch[Postprocessing.VALUE_TARGETS],
                        valid_mask,
                    ),
                    2.0,
                )
            )
        # Ignore the value function.
        else:
            value_err = 0.0

        entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask))

        total_loss = (
            pi_err
            + value_err * self.config["vf_loss_coeff"]
            - entropy * self.entropy_coeff
        )

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["entropy"] = entropy
        model.tower_stats["pi_err"] = pi_err
        model.tower_stats["value_err"] = value_err

        return total_loss

    @override(TorchPolicyV2)
    def optimizer(
        self,
    ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
        """Returns a torch optimizer (Adam) for A3C."""
        return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])

    @override(TorchPolicyV2)
    def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
        return convert_to_numpy(
            {
                "cur_lr": self.cur_lr,
                "entropy_coeff": self.entropy_coeff,
                "policy_entropy": torch.mean(
                    torch.stack(self.get_tower_stats("entropy"))
                ),
                "policy_loss": torch.mean(torch.stack(self.get_tower_stats("pi_err"))),
                "vf_loss": torch.mean(torch.stack(self.get_tower_stats("value_err"))),
            }
        )

    @override(TorchPolicyV2)
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
        episode: Optional[Episode] = None,
    ):
        sample_batch = super().postprocess_trajectory(sample_batch)
        return compute_gae_for_sample_batch(
            self, sample_batch, other_agent_batches, episode
        )

    @override(TorchPolicyV2)
    def extra_grad_process(
        self, optimizer: "torch.optim.Optimizer", loss: TensorType
    ) -> Dict[str, TensorType]:
        return apply_grad_clipping(self, optimizer, loss)


@Deprecated(
    old="rllib.agents.a3c.a3c_torch_policy.add_advantages",
    new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
    error=True,
)
def add_advantages(*args, **kwargs):
    pass
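
On the torch side, hooks that were previously wired through build_policy_class callbacks (loss, optimizer, stats_fn, postprocess_trajectory, extra_grad_process) are now overridable methods on A3CTorchPolicy. A minimal, hypothetical sketch of a subclass swapping the optimizer; the import path is inferred from the deprecation string above, and the subclass name, the RMSprop choice, and its hyperparameters are illustrative assumptions, not part of this commit.

from ray.rllib.agents.a3c.a3c_torch_policy import A3CTorchPolicy
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()


class RMSpropA3CTorchPolicy(A3CTorchPolicy):
    """Example subclass replacing the default Adam optimizer."""

    def optimizer(self):
        # Use RMSprop instead of Adam; alpha/eps values are illustrative only.
        return torch.optim.RMSprop(
            self.model.parameters(), lr=self.config["lr"], alpha=0.99, eps=1e-5
        )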