Skip to content

Move common loss functions for PPO and POCA #5079

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 4 additions & 60 deletions ml-agents/mlagents/trainers/poca/optimizer_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,63 +201,6 @@ def create_reward_signals(
def critic(self):
return self._critic

def poca_value_loss(
self,
values: Dict[str, torch.Tensor],
old_values: Dict[str, torch.Tensor],
returns: Dict[str, torch.Tensor],
epsilon: float,
loss_masks: torch.Tensor,
) -> torch.Tensor:
"""
Evaluates value loss for POCA.
:param values: Value output of the current network.
:param old_values: Value stored with experiences in buffer.
:param returns: Computed returns.
:param epsilon: Clipping value for value estimate.
:param loss_mask: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
"""
value_losses = []
for name, head in values.items():
old_val_tensor = old_values[name]
returns_tensor = returns[name]
clipped_value_estimate = old_val_tensor + torch.clamp(
head - old_val_tensor, -1 * epsilon, epsilon
)
v_opt_a = (returns_tensor - head) ** 2
v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks)
value_losses.append(value_loss)
value_loss = torch.mean(torch.stack(value_losses))
return value_loss

def poca_policy_loss(
self,
advantages: torch.Tensor,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
loss_masks: torch.Tensor,
) -> torch.Tensor:
"""
Evaluate POCA policy loss.
:param advantages: Computed advantages.
:param log_probs: Current policy probabilities
:param old_log_probs: Past policy probabilities
:param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
"""
advantage = advantages.unsqueeze(-1)

decay_epsilon = self.hyperparameters.epsilon
r_theta = torch.exp(log_probs - old_log_probs)
p_opt_a = r_theta * advantage
p_opt_b = (
torch.clamp(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * advantage
)
policy_loss = -1 * ModelUtils.masked_mean(
torch.min(p_opt_a, p_opt_b), loss_masks
)
return policy_loss

@timed
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
"""
Expand Down Expand Up @@ -346,17 +289,18 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
log_probs = log_probs.flatten()
loss_masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool)

baseline_loss = self.poca_value_loss(
baseline_loss = ModelUtils.trust_region_value_loss(
baselines, old_baseline_values, returns, decay_eps, loss_masks
)
value_loss = self.poca_value_loss(
value_loss = ModelUtils.trust_region_value_loss(
values, old_values, returns, decay_eps, loss_masks
)
policy_loss = self.poca_policy_loss(
policy_loss = ModelUtils.trust_region_policy_loss(
ModelUtils.list_to_tensor(batch[BufferKey.ADVANTAGES]),
log_probs,
old_log_probs,
loss_masks,
decay_eps,
)
loss = (
policy_loss
Expand Down
62 changes: 3 additions & 59 deletions ml-agents/mlagents/trainers/ppo/optimizer_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,63 +76,6 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
def critic(self):
return self._critic

def ppo_value_loss(
self,
values: Dict[str, torch.Tensor],
old_values: Dict[str, torch.Tensor],
returns: Dict[str, torch.Tensor],
epsilon: float,
loss_masks: torch.Tensor,
) -> torch.Tensor:
"""
Evaluates value loss for PPO.
:param values: Value output of the current network.
:param old_values: Value stored with experiences in buffer.
:param returns: Computed returns.
:param epsilon: Clipping value for value estimate.
:param loss_mask: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
"""
value_losses = []
for name, head in values.items():
old_val_tensor = old_values[name]
returns_tensor = returns[name]
clipped_value_estimate = old_val_tensor + torch.clamp(
head - old_val_tensor, -1 * epsilon, epsilon
)
v_opt_a = (returns_tensor - head) ** 2
v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks)
value_losses.append(value_loss)
value_loss = torch.mean(torch.stack(value_losses))
return value_loss

def ppo_policy_loss(
self,
advantages: torch.Tensor,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
loss_masks: torch.Tensor,
) -> torch.Tensor:
"""
Evaluate PPO policy loss.
:param advantages: Computed advantages.
:param log_probs: Current policy probabilities
:param old_log_probs: Past policy probabilities
:param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
"""
advantage = advantages.unsqueeze(-1)

decay_epsilon = self.hyperparameters.epsilon
r_theta = torch.exp(log_probs - old_log_probs)
p_opt_a = r_theta * advantage
p_opt_b = (
torch.clamp(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * advantage
)
policy_loss = -1 * ModelUtils.masked_mean(
torch.min(p_opt_a, p_opt_b), loss_masks
)
return policy_loss

@timed
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
"""
Expand Down Expand Up @@ -195,14 +138,15 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
log_probs = log_probs.flatten()
loss_masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool)
value_loss = self.ppo_value_loss(
value_loss = ModelUtils.trust_region_value_loss(
values, old_values, returns, decay_eps, loss_masks
)
policy_loss = self.ppo_policy_loss(
policy_loss = ModelUtils.trust_region_policy_loss(
ModelUtils.list_to_tensor(batch[BufferKey.ADVANTAGES]),
log_probs,
old_log_probs,
loss_masks,
decay_eps,
)
loss = (
policy_loss
Expand Down
57 changes: 56 additions & 1 deletion ml-agents/mlagents/trainers/torch/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Dict
from mlagents.torch_utils import torch, nn
from mlagents.trainers.torch.layers import LinearEncoder, Initialization
import numpy as np
Expand Down Expand Up @@ -428,3 +428,58 @@ def encode_observations(
)

return encoded_self

@staticmethod
def trust_region_value_loss(
values: Dict[str, torch.Tensor],
old_values: Dict[str, torch.Tensor],
returns: Dict[str, torch.Tensor],
epsilon: float,
loss_masks: torch.Tensor,
) -> torch.Tensor:
"""
Evaluates value loss, clipping to stay within a trust region of old value estimates.
Used for PPO and POCA.
:param values: Value output of the current network.
:param old_values: Value stored with experiences in buffer.
:param returns: Computed returns.
:param epsilon: Clipping value for value estimate.
:param loss_mask: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
"""
value_losses = []
for name, head in values.items():
old_val_tensor = old_values[name]
returns_tensor = returns[name]
clipped_value_estimate = old_val_tensor + torch.clamp(
head - old_val_tensor, -1 * epsilon, epsilon
)
v_opt_a = (returns_tensor - head) ** 2
v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks)
value_losses.append(value_loss)
value_loss = torch.mean(torch.stack(value_losses))
return value_loss

@staticmethod
def trust_region_policy_loss(
advantages: torch.Tensor,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
loss_masks: torch.Tensor,
epsilon: float,
) -> torch.Tensor:
"""
Evaluate policy loss clipped to stay within a trust region. Used for PPO and POCA.
:param advantages: Computed advantages.
:param log_probs: Current policy probabilities
:param old_log_probs: Past policy probabilities
:param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
"""
advantage = advantages.unsqueeze(-1)
r_theta = torch.exp(log_probs - old_log_probs)
p_opt_a = r_theta * advantage
p_opt_b = torch.clamp(r_theta, 1.0 - epsilon, 1.0 + epsilon) * advantage
policy_loss = -1 * ModelUtils.masked_mean(
torch.min(p_opt_a, p_opt_b), loss_masks
)
return policy_loss