Skip to content

[NOMRG] Example: GNN from_pos #161

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions benchmarl/conf/config.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
defaults:
- experiment: base_experiment
- algorithm: ???
- task: ???
- model: layers/mlp
- algorithm: ippo
- task: vmas/navigation_pos
- model: layers/gnn
- model@critic_model: layers/mlp
- _self_

Expand Down
8 changes: 4 additions & 4 deletions benchmarl/conf/model/layers/gnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ gnn_class: torch_geometric.nn.conv.GraphConv
gnn_kwargs:
aggr: "add"

position_key: null
pos_features: 0
velocity_key: null
vel_features: 0
position_key: pos
pos_features: 2
velocity_key: vel
vel_features: 2

exclude_pos_from_node_features: False
edge_radius: null
11 changes: 11 additions & 0 deletions benchmarl/conf/task/vmas/navigation_pos.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@


# Task configuration for vmas/navigation_pos. Each key maps one-to-one onto a
# field of TaskConfig in benchmarl/environments/vmas/navigation_pos.py.
max_steps: 100  # episode horizon
n_agents: 3  # number of agents in the scenario
collisions: True  # when True the observation also includes a lidar reading (see observation())
agents_with_same_goal: 1
split_goals: False
observe_all_goals: False  # when True each agent observes its offset to every agent's goal
shared_rew: False
lidar_range: 0.35  # presumably the lidar sensor range — confirm against the vmas navigation scenario
agent_radius: 0.1
8 changes: 7 additions & 1 deletion benchmarl/environments/vmas/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from benchmarl.environments.common import Task
from benchmarl.utils import DEVICE_TYPING
from .navigation_pos import NavigationScenario as navigation_pos_scenario


class VmasTask(Task):
Expand Down Expand Up @@ -44,6 +45,7 @@ class VmasTask(Task):
SIMPLE_SPREAD = None
SIMPLE_TAG = None
SIMPLE_WORLD_COMM = None
NAVIGATION_POS = None

def get_env_fun(
self,
Expand All @@ -53,8 +55,12 @@ def get_env_fun(
device: DEVICE_TYPING,
) -> Callable[[], EnvBase]:
config = copy.deepcopy(self.config)
if self is VmasTask.NAVIGATION_POS:
scenario = navigation_pos_scenario()
else:
scenario = self.name.lower()
return lambda: VmasEnv(
scenario=self.name.lower(),
scenario=scenario,
num_envs=num_envs,
continuous_actions=continuous_actions,
seed=seed,
Expand Down
61 changes: 61 additions & 0 deletions benchmarl/environments/vmas/navigation_pos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from dataclasses import dataclass, MISSING

import torch
from vmas import render_interactively
from vmas.simulator.core import Agent


@dataclass
class TaskConfig:
    """Typed schema for the ``vmas/navigation_pos`` task configuration.

    Values are loaded from ``benchmarl/conf/task/vmas/navigation_pos.yaml``.
    ``dataclasses.MISSING`` marks every field as required: no default exists,
    so the config file must supply each value.
    """

    max_steps: int = MISSING  # episode horizon (steps)
    n_agents: int = MISSING  # number of agents in the scenario
    collisions: bool = MISSING  # when True, observation() appends a lidar reading
    agents_with_same_goal: int = MISSING
    observe_all_goals: bool = MISSING  # when True, observe the offset to every agent's goal
    shared_rew: bool = MISSING
    split_goals: bool = MISSING
    lidar_range: float = MISSING  # presumably the lidar sensor range — confirm in vmas
    agent_radius: float = MISSING


from vmas.scenarios.navigation import Scenario


def observation(self, agent: Agent):
    """Build the observation dict for ``agent``.

    The returned dict contains:
      - ``"obs"``: the agent's offset to its goal (or to every agent's goal
        when ``observe_all_goals`` is set), concatenated with the inverted
        lidar measurement when ``collisions`` is enabled;
      - ``"pos"``: the agent's absolute position;
      - ``"vel"``: the agent's absolute velocity.
    """
    if self.observe_all_goals:
        rel_goals = [
            agent.state.pos - other.goal.state.pos for other in self.world.agents
        ]
    else:
        rel_goals = [agent.state.pos - agent.goal.state.pos]

    features = list(rel_goals)
    if self.collisions:
        lidar = agent.sensors[0]
        # Invert the reading so larger values mean "closer obstacle".
        features.append(lidar._max_range - lidar.measure())

    return {
        "obs": torch.cat(features, dim=-1),
        "pos": agent.state.pos,
        "vel": agent.state.vel,
    }


# NOTE(review): the original assigned ``Scenario.observation = observation``,
# which mutates vmas's stock navigation Scenario class in place — merely
# importing this module would change observations for *every* user of the
# plain navigation scenario. A subclass confines the override to
# ``NavigationScenario`` while keeping the exported name and its behavior
# unchanged for benchmarl/environments/vmas/common.py.
class NavigationScenario(Scenario):
    """vmas navigation scenario whose observations also expose "pos"/"vel"."""


NavigationScenario.observation = observation


# Manual-test entry point: render the scenario interactively (vmas's
# render_interactively), with keyboard control of two agents. Not used by
# benchmarl itself.
if __name__ == "__main__":
    render_interactively(
        NavigationScenario(),
        control_two_agents=True,
    )
Loading