[Algorithm] QGNN mixer #1420

Draft · wants to merge 4 commits into main
26 changes: 25 additions & 1 deletion examples/multiagent/qmix_vdn.py
@@ -18,7 +18,7 @@
from torchrl.envs.libs.vmas import VmasEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import EGreedyModule, QValueModule, SafeSequential
from torchrl.modules.models.multiagent import MultiAgentMLP, QMixer, VDNMixer
from torchrl.modules.models.multiagent import MultiAgentMLP, QGNNMixer, QMixer, VDNMixer
from torchrl.objectives import SoftUpdate, ValueEstimators
from torchrl.objectives.multiagent.qmixer import QMixerLoss
from utils.logging import init_logging, log_evaluation, log_training
@@ -129,6 +129,30 @@ def train(cfg: "DictConfig"):  # noqa: F821
            in_keys=[("agents", "chosen_action_value")],
            out_keys=["chosen_action_value"],
        )
    elif cfg.loss.mixer_type == "qgnn":
        mixer = TensorDictModule(
            module=QGNNMixer(
                use_state=False,
                n_agents=env.n_agents,
                device=cfg.train.device,
            ),
            in_keys=[("agents", "chosen_action_value")],
            out_keys=["chosen_action_value"],
        )
    elif cfg.loss.mixer_type == "qgnn-state":
        mixer = TensorDictModule(
            module=QGNNMixer(
                state_shape=env.unbatched_observation_spec[
                    "agents", "observation"
                ].shape,
                mixing_embed_dim=8,
                use_state=True,
                n_agents=env.n_agents,
                device=cfg.train.device,
            ),
            in_keys=[("agents", "chosen_action_value"), ("agents", "observation")],
            out_keys=["chosen_action_value"],
        )
    else:
        raise ValueError("Mixer type not in the example")

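For context, the selected mixer is then combined with the per-agent value network inside the QMixer loss further down in this example. A minimal sketch of that step, assuming the value network variable qnet and the cfg object from the unchanged part of the script (not shown in this diff):

# Sketch (assumed variable names): the mixer maps the per-agent
# ("agents", "chosen_action_value") entries into a single global
# "chosen_action_value" that QMixerLoss regresses against a TD(0) target.
loss_module = QMixerLoss(qnet, mixer, delay_value=True)
loss_module.make_value_estimator(ValueEstimators.TD0, gamma=cfg.loss.gamma)
target_net_updater = SoftUpdate(loss_module, eps=1 - cfg.loss.tau)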
2 changes: 1 addition & 1 deletion examples/multiagent/qmix_vdn.yaml
@@ -20,7 +20,7 @@ buffer:
  memory_size: ???

loss:
  mixer_type: "qmix" # or "vdn"
  mixer_type: "qmix" # choose from "qmix", "vdn", "qgnn", "qgnn-state"
  gamma: 0.9
  tau: 0.005 # For target net

4 changes: 3 additions & 1 deletion torchrl/modules/models/__init__.py
@@ -8,6 +8,7 @@
from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise
from .model_based import DreamerActor, ObsDecoder, ObsEncoder, RSSMPosterior, RSSMPrior
from .models import (
    AbsLinear,
    Conv2dNet,
    Conv3dNet,
    ConvNet,
@@ -18,9 +19,10 @@
    DistributionalDQNnet,
    DTActor,
    DuelingCnnDQNet,
    HyperLinear,
    LSTMNet,
    MLP,
    OnlineDTActor,
)
from .multiagent import MultiAgentConvNet, MultiAgentMLP, QMixer, VDNMixer
from .multiagent import MultiAgentConvNet, MultiAgentMLP, QGNNMixer, QMixer, VDNMixer
from .utils import Squeeze2dLayer, SqueezeLayer
54 changes: 54 additions & 0 deletions torchrl/modules/models/models.py
@@ -1366,6 +1366,60 @@ def forward(
        return self._lstm(input, hidden0_in, hidden1_in)


class HyperLinear(nn.Module):
    """Linear layer whose weights and biases are produced by a hypernetwork.

    The layer holds no parameters of its own: :meth:`update_params` receives a
    flat parameter vector per batch element and reshapes it into per-sample
    weights and biases, which :meth:`forward` applies through a batched matrix
    multiplication. With ``pos=True`` the absolute value of the weights is
    used, keeping the mixing function monotonic in its inputs.
    """

    def __init__(self, in_dim, out_dim, pos=True, **kwargs):
        # Extra kwargs (e.g. those forwarded when MLP builds its layers) are
        # accepted for compatibility but ignored: there are no parameters to
        # allocate here.
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.pos = pos
        self.w = None
        self.b = None

    def num_params(self):
        """Number of externally generated parameters this layer expects."""
        return self.in_dim * self.out_dim + self.out_dim

    def update_params(self, params):
        # params: (batch, in_dim * out_dim + out_dim)
        assert params.shape[1] == self.in_dim * self.out_dim + self.out_dim
        batch = params.shape[0]
        self.w = params[:, : self.in_dim * self.out_dim].view(
            batch, self.in_dim, self.out_dim
        )
        self.b = params[:, self.in_dim * self.out_dim :].view(batch, self.out_dim)
        if self.pos:
            self.w = torch.abs(self.w)

    def forward(self, x):
        # x: (batch, in_dim) or (batch, n, in_dim)
        w = self.w
        b = self.b
        assert x.shape[0] == w.shape[0]
        assert x.shape[-1] == w.shape[1]
        squeeze_output = False
        if x.dim() == 2:
            squeeze_output = True
            x = x.unsqueeze(1)
        if b.dim() == 3:
            b = b.squeeze(1)
        xw = torch.bmm(x, w)
        out = xw + b[:, None]
        if squeeze_output:
            out = out.squeeze(1)
        return out


class AbsLinear(nn.Linear):
    """Linear layer that applies the absolute value of its weights.

    Keeps each input's contribution non-negative, enforcing a monotonic
    mixing function when no state is available to drive a hypernetwork.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        return nn.functional.linear(input, torch.abs(self.weight), self.bias)


class OnlineDTActor(nn.Module):
"""Online Decision Transformer Actor class.

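To illustrate the hypernetwork contract used by HyperLinear: an external network emits one flat parameter vector per batch element, update_params reshapes it into per-sample weights and biases, and forward applies them with a batched matmul. A minimal standalone sketch; the shapes below are illustrative only:

import torch

from torchrl.modules.models import HyperLinear

layer = HyperLinear(in_dim=1, out_dim=8)      # e.g. scalar Q-value -> 8 features
params = torch.randn(4, layer.num_params())   # e.g. output of a state-conditioned MLP
layer.update_params(params)                   # per-sample weights (4, 1, 8), biases (4, 8)

q_values = torch.randn(4, 3, 1)               # batch of 4, 3 agents, one scalar each
out = layer(q_values)                         # -> (4, 3, 8); weights are |w| since pos=True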
101 changes: 100 additions & 1 deletion torchrl/modules/models/multiagent.py
@@ -12,7 +12,7 @@

from torchrl.data.utils import DEVICE_TYPING

from torchrl.modules.models import ConvNet, MLP
from torchrl.modules.models import AbsLinear, ConvNet, HyperLinear, MLP


class MultiAgentMLP(nn.Module):
@@ -820,3 +820,102 @@ def mix(self, chosen_action_value: torch.Tensor, state: torch.Tensor):
        # Reshape and return
        q_tot = y.view(*bs, 1)
        return q_tot


class QGNNMixer(Mixer):
    """Permutation-invariant Q-value mixer.

    Each agent's chosen action value is embedded by a network ``psi``, the
    embeddings are summed over agents, and a network ``phi`` maps the pooled
    embedding to the global value. With ``use_state=True`` the weights of
    ``psi`` and ``phi`` are generated by state-conditioned hypernetworks;
    otherwise absolute-value linear layers keep the mixing monotonic.

    From https://arxiv.org/abs/2205.13005
    """

    def __init__(
        self,
        n_agents: int,
        device,
        mixing_embed_dim=8,
        state_shape=None,
        use_state=False,
    ):
        super().__init__(
            needs_state=use_state,
            state_shape=state_shape if use_state else torch.Size([]),
            n_agents=n_agents,
            device=device,
        )

        self.use_state = use_state
        self.embed_dim = mixing_embed_dim
        self.state_dim = int(np.prod(state_shape)) if self.use_state else None

        # psi embeds each agent's value; phi maps the pooled embedding to the
        # global value. With state, their weights come from hypernetworks
        # (HyperLinear); without state, AbsLinear keeps the weights positive.
        self.psi_hyper = MLP(
            in_features=1,
            out_features=self.embed_dim,
            depth=3,
            num_cells=self.embed_dim,
            activation_class=nn.ReLU,
            activate_last_layer=False,
            layer_class=HyperLinear if self.use_state else AbsLinear,
            layer_kwargs={"pos": True} if self.use_state else {},
            device=device,
        )

        self.phi_hyper = MLP(
            in_features=self.embed_dim,
            out_features=1,
            depth=3,
            num_cells=self.embed_dim,
            activation_class=nn.ReLU,
            activate_last_layer=False,
            layer_class=HyperLinear if self.use_state else AbsLinear,
            layer_kwargs={"pos": True} if self.use_state else {},
            device=device,
        )

        if self.use_state:
            # Hypernetworks that emit the flat parameter vectors consumed by
            # the HyperLinear layers of psi and phi.
            self.psi_param_net = MLP(
                in_features=self.state_dim,
                out_features=self.num_params(self.psi_hyper),
                depth=2,
                num_cells=self.state_dim,
                activation_class=nn.Mish,
                activate_last_layer=False,
                device=device,
            )
            self.phi_param_net = MLP(
                in_features=self.state_dim,
                out_features=self.num_params(self.phi_hyper),
                depth=2,
                num_cells=self.state_dim,
                activation_class=nn.Mish,
                activate_last_layer=False,
                device=device,
            )

    def mix(self, chosen_action_value: torch.Tensor, state: torch.Tensor):
        # chosen_action_value: (*B, n_agents, 1); state (if used): (*B, *state_shape).
        bs = chosen_action_value.shape[:-2]
        chosen_action_value = chosen_action_value.reshape(-1, self.n_agents, 1)
        if self.use_state:
            # Generate the parameters of the psi/phi HyperLinear layers from the state.
            state = state.reshape(-1, self.state_dim)
            psi_params = self.psi_param_net(state)
            phi_params = self.phi_param_net(state)
            self.update_params(self.psi_hyper, psi_params)
            self.update_params(self.phi_hyper, phi_params)
        # Embed each agent's value, pool with a permutation-invariant sum,
        # then map the pooled embedding to the global value.
        psi_out = self.psi_hyper(chosen_action_value)
        summed = psi_out.sum(dim=-2)
        phi_out = self.phi_hyper(summed)
        # Reshape back to the leading batch dims, as the other mixers do.
        return phi_out.view(*bs, 1)

    def num_params(self, net):
        # Total number of hypernetwork-generated parameters required by ``net``.
        num_params = 0
        for layer in net:
            if isinstance(layer, HyperLinear):
                num_params += layer.num_params()
        return num_params

    def update_params(self, net, params):
        # Slice the flat parameter vector and hand each HyperLinear its chunk.
        i = 0
        for layer in net:
            if isinstance(layer, HyperLinear):
                layer_num_params = layer.num_params()
                layer_params = params[:, i : i + layer_num_params]
                i += layer_num_params
                layer.update_params(layer_params)
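For completeness, a standalone smoke test of the new mixer outside the VMAS example; the batch size, agent count, and state shape below are illustrative only:

import torch

from torchrl.modules.models.multiagent import QGNNMixer

n_agents, batch = 3, 32
chosen_action_value = torch.randn(batch, n_agents, 1)

# Stateless variant: psi/phi are built from AbsLinear layers, so mixing stays monotonic.
mixer = QGNNMixer(n_agents=n_agents, device="cpu", use_state=False)
q_tot = mixer(chosen_action_value)  # -> (batch, 1)

# State-conditioned variant: hypernetworks generate the psi/phi weights from the state.
state_shape = torch.Size([n_agents, 6])  # e.g. 6 observation features per agent
state = torch.randn(batch, *state_shape)
mixer_state = QGNNMixer(
    n_agents=n_agents,
    device="cpu",
    use_state=True,
    state_shape=state_shape,
    mixing_embed_dim=8,
)
q_tot_state = mixer_state(chosen_action_value, state)  # -> (batch, 1)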