Description
Describe the bug
When training multi-agent Soft Actor-Critic on PettingZoo/MultiWalker-v9, all losses (loss_actor, loss_qvalue, loss_alpha) explode after at most ~1M environment steps. This happens regardless of (reasonable) hyperparameter choices and gradient-clipping thresholds.
To Reproduce
from copy import deepcopy
import tqdm
import numpy as np
from gymnasium.spaces import Box
import logging
import math
import torch
from torch import nn
import torch.distributions as D
from torchrl.data.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.envs import (
check_env_specs,
PettingZooEnv,
ParallelEnv,
GymEnv
)
from torchrl.modules import AdditiveGaussianWrapper, ProbabilisticActor, TanhNormal
from torchrl.modules.models import MLP
from torchrl.modules.models.multiagent import (
MultiAgentMLP,
MultiAgentNetBase
)
from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector, RandomPolicy
from torchrl.objectives import SACLoss, SoftUpdate
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential, NormalParamExtractor
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.envs import EnvCreator, TransformedEnv, Compose, Transform, RewardSum, ObservationNorm, StepCounter
from torchrl.record import CSVLogger, VideoRecorder, PixelRenderTransform
import multiprocessing as mp
EPS = 1e-7
class SMACCNet(MultiAgentNetBase):
'''
This is an MLP policy network for MultiAgent SAC.
This is just a more limited version of MultiAgentMLP.
(https://pytorch.org/rl/main/_modules/torchrl/modules/models/multiagent.html)
'''
def __init__(self,
n_agent_inputs: int | None,
n_agent_outputs: int,
n_agents: int,
centralised: bool,
share_params: bool,
device = 'cpu',
activation_class = nn.Tanh,
**kwargs):
self.n_agents = n_agents
self.n_agent_inputs = n_agent_inputs
self.n_agent_outputs = n_agent_outputs
self.share_params = share_params
self.centralised = centralised
self.activation_class = activation_class
self.device = device
super().__init__(
n_agents=n_agents,
centralised=centralised,
share_params=share_params,
agent_dim=-2,
device = device,
**kwargs,
)
# Copied over from MultiAgentMLP.
def _pre_forward_check(self, inputs):
if inputs.shape[-2] != self.n_agents:
raise ValueError(
f"Multi-agent network expected input with shape[-2]={self.n_agents},"
f" but got {inputs.shape}"
)
# If the model is centralized, agents have full observability
if self.centralised:
inputs = inputs.flatten(-2, -1)
return inputs
def init_net_params(self, net):
def init_layer_params(layer):
if isinstance(layer, nn.Linear):
weight_gain = 1. / (100 if layer.out_features == self.n_agent_outputs else 1)
torch.nn.init.xavier_uniform_(layer.weight, gain = weight_gain)
if 'bias' in layer.state_dict():
torch.nn.init.zeros_(layer.bias)
net.apply(init_layer_params)
return net
def _build_single_net(self, *, device, **kwargs):
n_agent_inputs = self.n_agent_inputs
if self.centralised and n_agent_inputs is not None:
n_agent_inputs = self.n_agent_inputs * self.n_agents
model = nn.Sequential(
nn.Linear(n_agent_inputs, 400),
self.activation_class(),
nn.Linear(400, 300),
self.activation_class(),
nn.Linear(300, self.n_agent_outputs)
).to(self.device) # Bandaid fix to use MultiSyncDataCollector
model = self.init_net_params(model)
return model
class TqdmLoggingHandler(logging.Handler):
def __init__(self, level=logging.NOTSET):
super().__init__(level)
def emit(self, record):
try:
msg = self.format(record)
tqdm.tqdm.write(msg)
self.flush()
except Exception:
self.handleError(record)
# Main Function
if __name__ == "__main__":
logging.basicConfig(level = logging.INFO)
logger = logging.getLogger(__name__)
logger.propagate = False
logger.addHandler(TqdmLoggingHandler())
mp.set_start_method("spawn", force = True)
NUM_AGENTS = 3
NUM_CRITICS = 2
NUM_EXPLORE_WORKERS = 8
EXPLORATION_STEPS = 30000
MAX_EPISODE_STEPS = 1000
DEVICE = "cuda"
REPLAY_BUFFER_SIZE = int(1e6)
VALUE_GAMMA = 0.99
MAX_GRAD_NORM = 1.0
BATCH_SIZE = 256
LR = 1e-4
UPDATE_STEPS_PER_EXPLORATION = 1500
WARMUP_STEPS = 0 #int(2e5)
TRAIN_TIMESTEPS = int(1e7)
EVAL_INTERVAL = 1 #int(9e4 // EXPLORATION_STEPS) # Every 500k steps or so, evaluate once.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
# https://pytorch.org/rl/stable/tutorials/multiagent_competitive_ddpg.html
# More tutorials: https://pytorch.org/tutorials/advanced/pendulum.html
# Toy test: https://pettingzoo.farama.org/environments/sisl/multiwalker/
def env_fn(mode, parallel = True):
def base_env_fn():
return PettingZooEnv(task = "multiwalker_v9",
parallel = True,
seed = 42,
n_walkers = NUM_AGENTS,
terminate_reward = -5.0,
forward_reward = 1.0,
fall_reward = -1.0,
shared_reward = False,
max_cycles = MAX_EPISODE_STEPS,
render_mode = mode,
device = "cpu"
)
if parallel:
# Don't use.
# https://discuss.pytorch.org/t/pettingzoo-trouble-running-multiple-marl-environments-in-parallel/203706/
env = lambda: ParallelEnv(num_workers = 4, # noqa: E731
create_env_fn = base_env_fn,
device = "cpu",
mp_start_method = "spawn",
serial_for_single = True
)
else:
env = base_env_fn # noqa: E731
def env_with_transforms():
# dummy_env = base_env_fn()
# dummy_obs_transform = ObservationNorm(in_keys = [("walker", "observation")], standard_normal = True)
# dummy_env = TransformedEnv(dummy_env, dummy_obs_transform)
# dummy_obs_transform.init_stats(10000)
init_env = env()
# obs_transform = ObservationNorm(loc = dummy_obs_transform.loc + EPS,
# scale = dummy_obs_transform.scale + EPS,
# in_keys = [("walker", "observation")],
# standard_normal = True
# )
init_env = TransformedEnv(init_env, Compose(
StepCounter(max_steps = MAX_EPISODE_STEPS),
RewardSum(
in_keys = [init_env.reward_key for _ in range(NUM_AGENTS)],
out_keys = [("walker", "episode_reward")] * NUM_AGENTS,
reset_keys = ["_reset"] * NUM_AGENTS
),
# obs_transform
)
)
# del dummy_env, dummy_obs_transform
return init_env
return env_with_transforms
train_env = env_fn(None, parallel = False)()
if train_env.is_closed:
train_env.start()
eval_env = env_fn("rgb_array", parallel = False)()
video_recorder = VideoRecorder(
CSVLogger("multiwalker-toy-test", video_format = "mp4"),
tag = "rendered",
in_keys = ["pixels_record"]
)
# Call the parent's render function
eval_env.append_transform(PixelRenderTransform(out_keys = ["pixels_record"]))
eval_env.append_transform(video_recorder)
if eval_env.is_closed:
eval_env.start()
check_env_specs(train_env)
check_env_specs(eval_env)
print(f"Action: {train_env.full_action_spec}, Reward: {train_env.full_reward_spec}, Done: {train_env.full_done_spec}, Observation: {train_env.full_observation_spec}")
print(f"group_map: {train_env.group_map}")
print(f"Action: {train_env.action_keys}, Reward: {train_env.reward_keys}, Done: {train_env.done_keys}")
# NOTE: The input and output spaces to be fed in are on a PER-AGENT basis.
# Basically, if you have 16 agents observing 3D velocity and outputting speed (the magnitude),
# n_agent_inputs = 3, n_agent_outputs = 1.
obs_dim = train_env.full_observation_spec["walker", "observation"].shape[-1]
action_dim = train_env.full_action_spec["walker", "action"].shape[-1]
policy_net = nn.Sequential(
SMACCNet(n_agent_inputs = obs_dim,
n_agent_outputs = 2 * action_dim,
n_agents = NUM_AGENTS,
centralised = False,
share_params = True,
device = "cpu",
activation_class = nn.LeakyReLU,
),
NormalParamExtractor(),
)
critic_net = SMACCNet(n_agent_inputs = obs_dim + action_dim,
n_agent_outputs = 1,
n_agents = NUM_AGENTS,
centralised = True,
share_params = True,
device = "cpu",
activation_class = nn.LeakyReLU,
)
# Hook our networks to TensorDictModules so they can be a part of the TensorDict pipeline...
policy_net_td_module = TensorDictModule(module = policy_net,
in_keys = [("walker", "observation")],
# NOTE: These outputs must match with the parameter names of the
# distribution you are using!
out_keys = [("walker", "loc"), ("walker", "scale")]
)
obs_act_module = TensorDictModule(lambda obs, act: torch.cat([obs, act], dim = -1),
in_keys = [("walker", "observation"), ("walker", "action")],
out_keys = [("walker", "obs_act")]
)
critic_net_td_module = TensorDictModule(module = critic_net,
in_keys = [("walker", "obs_act")],
out_keys = [("walker", "state_action_value")]
)
# Attach our raw policy network to a probabilistic actor
policy_actor = ProbabilisticActor(
module = policy_net_td_module,
spec = train_env.full_action_spec["walker", "action"],
in_keys = [("walker", "loc"), ("walker", "scale")],
out_keys = [("walker", "action")],
# TanhNormal is based off of pytorch, which as far as we know,
# implements a numerically stable log det jacobian.
distribution_class = TanhNormal,
distribution_kwargs = {
"min": train_env.full_action_spec["walker", "action"].space.low,
"max": train_env.full_action_spec["walker", "action"].space.high,
},
return_log_prob = True,
)
with torch.no_grad():
fake_td = train_env.fake_tensordict()
policy_actor(fake_td)
dora = AdditiveGaussianWrapper(
policy = policy_actor,
action_key = ("walker", "action"),
sigma_init = 0.3,
sigma_end = 0.1,
annealing_num_steps = TRAIN_TIMESTEPS // 2
)
critic_actor = TensorDictSequential(
obs_act_module, critic_net_td_module
)
collector = MultiSyncDataCollector(
[env_fn(None, parallel = False) for _ in range(NUM_EXPLORE_WORKERS)],
policy = dora,
frames_per_batch = BATCH_SIZE,
max_frames_per_traj = 0,
total_frames = TRAIN_TIMESTEPS,
device = "cpu",
reset_at_each_iter = False
)
replay_buffer = TensorDictReplayBuffer(
storage = LazyMemmapStorage(
REPLAY_BUFFER_SIZE, device = "cpu",
), # We will store up to memory_size multi-agent transitions
sampler = RandomSampler(),
batch_size = BATCH_SIZE, # We will sample batches of this size
)
sac_loss = SACLoss(policy_actor.to(DEVICE),
qvalue_network = critic_actor.to(DEVICE),
num_qvalue_nets = 2,
loss_function = "l2",
delay_qvalue = True,
alpha_init = 0.1
)
sac_loss.set_keys(
action = ("walker", "action"),
state_action_value = ("walker", "state_action_value"),
reward = ("walker", "reward"),
done = ("walker", "done"),
terminated = ("walker", "terminated"),
)
sac_loss.make_value_estimator(gamma = VALUE_GAMMA)
polyak_updater = SoftUpdate(sac_loss, tau = 0.005)
critic_params = list(sac_loss.qvalue_network_params.flatten_keys().values())
actor_params = list(sac_loss.actor_network_params.flatten_keys().values())
optimizer_actor = torch.optim.Adam(
actor_params,
lr = LR,
weight_decay = 5e-4,
eps = EPS,
betas = (0.9, 0.98)
)
optimizer_critic = torch.optim.Adam(
critic_params,
lr = LR,
weight_decay = 5e-4,
eps = EPS,
betas = (0.9, 0.98)
)
optimizer_alpha = torch.optim.Adam(
[sac_loss.log_alpha],
lr = LR,
eps = EPS,
betas = (0.9, 0.98)
)
# breakpoint()
num_frames = 0
pbar = tqdm.tqdm(total = TRAIN_TIMESTEPS)
total_frames = 0
backprop_ctr = 0
train_rews, ep_lengths = [], []
EXPLORATION_BATCHES = EXPLORATION_STEPS // BATCH_SIZE
for i, tensordict in enumerate(collector):
collector.update_policy_weights_()
pbar.update(tensordict.numel())
tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
total_frames += current_frames
backprop_ctr += 1
# Optimization steps
if total_frames >= WARMUP_STEPS and backprop_ctr > EXPLORATION_BATCHES:
backprop_ctr = 0
losses = TensorDict({}, batch_size = [UPDATE_STEPS_PER_EXPLORATION])
alphas = TensorDict({}, batch_size = [UPDATE_STEPS_PER_EXPLORATION])
for j in range(UPDATE_STEPS_PER_EXPLORATION):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if str(sampled_tensordict.device) != DEVICE:
sampled_tensordict = sampled_tensordict.to(DEVICE, non_blocking = False)
else:
sampled_tensordict = sampled_tensordict.clone()
try:
# Compute loss
loss_td = sac_loss(sampled_tensordict)
except KeyError:
raise Exception(f"wtf {sampled_tensordict}\n{obs_act_module(sampled_tensordict)['walker', 'obs_act']}")
actor_loss = loss_td["loss_actor"]
q_loss = loss_td["loss_qvalue"]
alpha_loss = loss_td["loss_alpha"]
# Update actor
optimizer_actor.zero_grad()
actor_loss.backward()
actor_grad_norm = torch.nn.utils.clip_grad_norm_(actor_params, max_norm = MAX_GRAD_NORM)
optimizer_actor.step()
# Update critic
optimizer_critic.zero_grad()
q_loss.backward()
q_grad_norm = torch.nn.utils.clip_grad_norm_(critic_params, max_norm = MAX_GRAD_NORM)
optimizer_critic.step()
# Update alpha
optimizer_alpha.zero_grad()
alpha_loss.backward()
alpha_grad_norm = torch.nn.utils.clip_grad_norm_([sac_loss.log_alpha], max_norm = MAX_GRAD_NORM)
optimizer_alpha.step()
losses[j] = loss_td.select(
"loss_actor", "loss_qvalue", "loss_alpha"
).detach()
alphas[j] = loss_td.select("alpha")
# Update qnet_target params
polyak_updater.step()
# Some other stuff I ripped out from https://github.com/pytorch/rl/blob/main/sota-implementations/sac/sac.py
episode_end = (
tensordict["next", "done"]
if tensordict["next", "done"].any()
else tensordict["next", "truncated"]
)
opening_banner = "-" * 10 + f" Batch {i + 1} " + "-" * 10
def get_mean(src, key):
return src.get(key).mean().item()
logger.info(opening_banner)
logger.info(f"Average Actor Loss: {get_mean(losses, 'loss_actor')}")
logger.info(f"Average Q Loss: {get_mean(losses, 'loss_qvalue')}")
logger.info(f"Average Alpha: {get_mean(alphas, 'alpha')} (Loss: {get_mean(losses, 'loss_alpha')})")
logger.info("-" * len(opening_banner))
ep_length = tensordict['next', 'step_count'][episode_end].to(dtype = torch.float64)
if ep_length.numel():
ep_lengths.append(ep_length.mean().item())
agent_terminated = torch.stack(
[
tensordict["next", "walker", "done"][:, agent_id, 0]
if tensordict["next", "walker", "done"][:, agent_id, 0].any()
else tensordict["next", "walker", "truncated"][:, agent_id, 0]
for agent_id in range(NUM_AGENTS)
],
dim = 1
)
train_reward = tensordict['next', 'walker', 'episode_reward'][agent_terminated]
if train_reward.numel():
train_rews.append(train_reward.mean().item())
if not ((i + 1) % (EVAL_INTERVAL * EXPLORATION_BATCHES)):
logger.info(
f"Mean Train Reward Across Past {EVAL_INTERVAL} Collections: " +
(
f"{sum(train_rews) / len(train_rews)}"
if len(train_rews)
else f"NA (Training starts @ {WARMUP_STEPS} steps)"
)
)
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
eval_rollout = eval_env.rollout(
MAX_EPISODE_STEPS,
policy_actor,
auto_cast_to_device=True,
break_when_any_done=True,
)
mean_eval_length = eval_rollout["next", "step_count"][-1].to(dtype = torch.float64).mean().item()
logger.info(f"Mean Eval Reward: {eval_rollout['next', 'walker', 'episode_reward'][-1].mean().item()}")
logger.info(f"Eval Length: {mean_eval_length}")
ep_reward_list = []
train_rews = []
eval_env.transform.dump()
collector.shutdown()
train_env.close()
Expected behavior
Loss values stay within roughly +/- 10^2 throughout training and do not grow to +/- 10^x with x >> 1.
System info
>>> import torchrl, numpy, sys
>>> print(f"TorchRL: {torchrl.__version__}\nNumPy: {numpy.__version__}\nPython3 Ver: {sys.version}\nPlatform: {sys.platform}")
TorchRL: 0.4.0
NumPy: 1.25.0
Python3 Ver: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
Platform: linux
> lsb_release -a
Distributor ID: Ubuntu
Description: Ubuntu 22.04.3 LTS
Release: 22.04
Codename: jammy
Reason and Possible fixes
Although the environment's observation space is not normalized and contains unbounded entries, the issue does not appear to stem solely from poor observation scaling, since adding a torchrl.envs.ObservationNorm does not mitigate it.
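For reference, the normalization that was tried corresponds to the commented-out lines in env_with_transforms() above; a rough sketch of that attempt, reusing the names from the repro script:

dummy_env = base_env_fn()
dummy_obs_norm = ObservationNorm(in_keys = [("walker", "observation")], standard_normal = True)
dummy_env = TransformedEnv(dummy_env, dummy_obs_norm)
dummy_obs_norm.init_stats(10000)  # estimate per-feature loc/scale from 10k environment steps
obs_transform = ObservationNorm(
    loc = dummy_obs_norm.loc + EPS,
    scale = dummy_obs_norm.scale + EPS,
    in_keys = [("walker", "observation")],
    standard_normal = True,
)
# obs_transform was then appended to the Compose(...) in env_with_transforms();
# the losses still exploded.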
Debugging reveals that unusually large, negative values for log_prob are somehow being fed into the SACLoss calculations from TorchRL's reimplementation of torch.distributions.transforms.TanhTransform (rl/torchrl/modules/distributions/continuous.py, lines 289 to 382 at 3e6cb84).
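A hypothetical probe (an illustration under assumptions, not part of the original script) shows how a nearly saturated action already yields an extreme log-prob under torchrl's TanhNormal:

import torch
from torchrl.modules.distributions import TanhNormal

# Nearly saturated action under a sharply peaked base Normal: the atanh inverse
# lands far in the tail, so the log-prob comes out as a very large negative number.
dist = TanhNormal(loc = torch.zeros(4), scale = torch.full((4,), 1e-3), min = -1.0, max = 1.0)
saturated_action = torch.full((4,), 1.0 - 1e-7)
print(dist.log_prob(saturated_action))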
Since this reimplementation changes little from the original TanhTransform, it is plausible that the reimplementation is not the root cause of the error. Nevertheless, replacing it with the following variant gets rid of the issue altogether:
class CustomTanhTransform(D.transforms.TanhTransform):
def _inverse(self, y):
# from stable_baselines3's `common.distributions.TanhBijector`
"""
Inverse of Tanh
Taken from Pyro: https://github.com/pyro-ppl/pyro
0.5 * torch.log((1 + x ) / (1 - x))
"""
y = y.clamp(-1. + EPS, 1. - EPS)
return 0.5 * (y.log1p() - (-y).log1p())
def log_abs_det_jacobian(self, x, y):
# From PyTorch `TanhTransform`
'''
tl;dr log(1-tanh^2(x)) = log(sech^2(x))
= 2log(2/(e^x + e^(-x)))
= 2(log2 - log(e^x/(1 + e^(-2x)))
= 2(log2 - x - log(1 + e^(-2x)))
= 2(log2 - x - softplus(-2x))
'''
return 2.0 * (math.log(2.0) - x - nn.functional.softplus(-2.0 * x))
class TanhNormalStable(D.TransformedDistribution):
def __init__(self, loc, scale, event_dims = 1):
self._event_dims = event_dims
self._t = [
CustomTanhTransform()
]
self.update(loc, scale)
def log_prob(self, value):
"""
Scores the sample by inverting the transform(s) and computing the score
using the score of the base distribution and the log abs det jacobian.
"""
if self._validate_args:
self._validate_sample(value)
event_dim = len(self.event_shape)
log_prob = 0.0
y = value
for transform in reversed(self.transforms):
x = transform.inv(y)
event_dim += transform.domain.event_dim - transform.codomain.event_dim
log_prob = log_prob - D.utils._sum_rightmost(
transform.log_abs_det_jacobian(x, y),
event_dim - transform.domain.event_dim,
)
y = x
log_prob = log_prob + D.utils._sum_rightmost(
self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape)
)
log_prob = torch.clamp(log_prob, min = math.log10(EPS)) # <- **CLAMPING THIS SEEMS TO RESOLVE THE ISSUE**
return log_prob
def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
self.loc = loc
self.scale = scale
if (
hasattr(self, "base_dist")
and (self.base_dist.base_dist.loc.shape == self.loc.shape)
and (self.base_dist.base_dist.scale.shape == self.scale.shape)
):
self.base_dist.base_dist.loc = self.loc
self.base_dist.base_dist.scale = self.scale
else:
base = D.Independent(D.Normal(self.loc, self.scale), self._event_dims)
super().__init__(base, self._t)
@property
def mode(self):
m = self.base_dist.base_dist.mean
for t in self.transforms:
m = t(m)
return m
Such a fix, however, flies in the face of this comment from the PyTorch devs.
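For completeness, a sketch of how the patched distribution can be wired into the actor from the repro script above (the min/max distribution kwargs are dropped since TanhNormalStable as written only squashes to tanh's [-1, 1] range, which already matches MultiWalker's action bounds):

policy_actor = ProbabilisticActor(
    module = policy_net_td_module,
    spec = train_env.full_action_spec["walker", "action"],
    in_keys = [("walker", "loc"), ("walker", "scale")],
    out_keys = [("walker", "action")],
    distribution_class = TanhNormalStable,  # the clamped variant defined above
    return_log_prob = True,
)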
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)