[RLlib-contrib] MADDPG. (ray-project#36628)
Showing 9 changed files with 963 additions and 0 deletions.
README.md
@@ -0,0 +1,17 @@
# MADDPG (Multi-Agent Deep Deterministic Policy Gradient)

[MADDPG](https://arxiv.org/abs/1706.02275) is a DDPG variant with a centralized, shared critic: each agent's actor sees only its own observation, while its critic is trained on the joint observations and actions of all agents. The code here is adapted from https://github.com/openai/maddpg to integrate with RLlib's multi-agent APIs. See justinkterry/maddpg-rllib for examples and more information. Note that this implementation follows OpenAI's and is intended for use with the discrete MPE (Multi-Agent Particle Environments) environments. Also note that this method is typically difficult to get working, even after applying all optimizations applicable to the environment at hand. It should be viewed as a research method, mainly useful for reproducing the results of the paper that introduced it.
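
For reference, the core update from the paper: each agent $i$ learns a deterministic policy $\mu_i$ conditioned only on its own observation $o_i$, while its critic $Q_i^{\mu}$ is centralized and conditions on the joint state $x$ and the actions of all $N$ agents:

$$
\nabla_{\theta_i} J(\mu_i) = \mathbb{E}_{x, a \sim \mathcal{D}} \left[ \nabla_{\theta_i} \mu_i(a_i \mid o_i) \, \nabla_{a_i} Q_i^{\mu}(x, a_1, \ldots, a_N) \big|_{a_i = \mu_i(o_i)} \right]
$$

where $\mathcal{D}$ is the replay buffer.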

## Installation

```
conda create -n rllib-maddpg python=3.10
conda activate rllib-maddpg
pip install -r requirements.txt
pip install -e '.[development]'
```

## Usage

[MADDPG Example]()
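
Until the example link above is filled in, here is a minimal sketch condensed from the two-step-game script in this commit (the `Discrete(6)` spaces, `agent_id` overrides, and the `actions_are_logits` env_config flag are all taken from that script):

```python
# Minimal sketch, condensed from the example script in this commit.
import ray
from ray import air, tune
from gymnasium.spaces import Discrete
from ray.rllib.examples.env.two_step_game import TwoStepGame
from ray.rllib.policy.policy import PolicySpec
from rllib_maddpg.maddpg import MADDPG, MADDPGConfig

ray.init()
config = (
    MADDPGConfig()
    .environment(TwoStepGame, env_config={"actions_are_logits": True})
    .framework("tf")
    .multi_agent(
        # One policy per agent; MADDPG identifies agents via the
        # `agent_id` config override.
        policies={
            "pol1": PolicySpec(
                observation_space=Discrete(6),
                action_space=TwoStepGame.action_space,
                config=MADDPGConfig.overrides(agent_id=0),
            ),
            "pol2": PolicySpec(
                observation_space=Discrete(6),
                action_space=TwoStepGame.action_space,
                config=MADDPGConfig.overrides(agent_id=1),
            ),
        },
        policy_mapping_fn=lambda agent_id, *args, **kwargs: (
            "pol2" if agent_id else "pol1"
        ),
    )
)
tune.Tuner(
    MADDPG,
    run_config=air.RunConfig(stop={"timesteps_total": 70000}),
    param_space=config,
).fit()
```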
Example script (two-step game)
@@ -0,0 +1,124 @@
# The two-step game from QMIX: https://arxiv.org/pdf/1803.11485.pdf

import argparse
import logging
import os

from gymnasium.spaces import Dict, Discrete, MultiDiscrete, Tuple
from rllib_maddpg.maddpg import MADDPG, MADDPGConfig

import ray
from ray import air, tune
from ray.rllib.env.multi_agent_env import ENV_STATE
from ray.rllib.examples.env.two_step_game import TwoStepGame
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.test_utils import check_learning_achieved
from ray.tune import register_env

logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser()
# NOTE: `--mixer` is carried over from the QMIX version of this example and
# is not used by MADDPG.
parser.add_argument(
    "--mixer",
    type=str,
    default="qmix",
    choices=["qmix", "vdn", "none"],
    help="The mixer model to use.",
)
parser.add_argument(
    "--run-as-test",
    action="store_true",
)
parser.add_argument(
    "--stop-timesteps", type=int, default=70000, help="Number of timesteps to train."
)
parser.add_argument(
    "--stop-reward", type=float, default=8.0, help="Reward at which we stop training."
)
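# Example invocation (script name hypothetical):
#   python two_step_game_maddpg.py --stop-timesteps 70000 --run-as-test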

if __name__ == "__main__":
    args = parser.parse_args()

    ray.init()

    # Group both agents of the two-step game into a single "group_1" agent
    # whose observation/action spaces are tuples over the per-agent spaces.
    grouping = {
        "group_1": [0, 1],
    }
    obs_space = Tuple(
        [
            Dict(
                {
                    "obs": MultiDiscrete([2, 2, 2, 3]),
                    ENV_STATE: MultiDiscrete([2, 2, 2]),
                }
            ),
            Dict(
                {
                    "obs": MultiDiscrete([2, 2, 2, 3]),
                    ENV_STATE: MultiDiscrete([2, 2, 2]),
                }
            ),
        ]
    )
    act_space = Tuple(
        [
            TwoStepGame.action_space,
            TwoStepGame.action_space,
        ]
    )
    register_env(
        "grouped_twostep",
        lambda config: TwoStepGame(config).with_agent_groups(
            grouping, obs_space=obs_space, act_space=act_space
        ),
    )
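    # NOTE: The MADDPG config below uses the flat (ungrouped) TwoStepGame
    # directly; "grouped_twostep" is registered for parity with the QMIX
    # version of this example.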

    # This MADDPG port is TF-based (see requirements.txt), so build a single
    # "tf" config. Each agent gets its own policy over the flat spaces.
    obs_space = Discrete(6)
    act_space = TwoStepGame.action_space
    config = (
        MADDPGConfig()
        .environment(TwoStepGame, env_config={"actions_are_logits": True})
        .framework("tf")
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
        .training(num_steps_sampled_before_learning_starts=100)
        .multi_agent(
            policies={
                "pol1": PolicySpec(
                    observation_space=obs_space,
                    action_space=act_space,
                    config=MADDPGConfig.overrides(agent_id=0),
                ),
                "pol2": PolicySpec(
                    observation_space=obs_space,
                    action_space=act_space,
                    config=MADDPGConfig.overrides(agent_id=1),
                ),
            },
            # Agent 0 maps to "pol1", agent 1 to "pol2".
            policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
                "pol2" if agent_id else "pol1"
            ),
        )
    )

    stop = {
        "episode_reward_mean": args.stop_reward,
        "timesteps_total": args.stop_timesteps,
    }

    results = tune.Tuner(
        MADDPG,
        run_config=air.RunConfig(stop=stop, verbose=2),
        param_space=config,
    ).fit()

    if args.run_as_test:
        check_learning_achieved(results, args.stop_reward)

    ray.shutdown()
pyproject.toml
@@ -0,0 +1,18 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["src"]

[project]
name = "rllib-maddpg"
authors = [{name = "Anyscale Inc."}]
version = "0.1.0"
description = ""
readme = "README.md"
requires-python = ">=3.7, <3.11"
dependencies = ["gymnasium==0.26.3", "ray[rllib]==2.5.0"]

[project.optional-dependencies]
development = ["pytest>=7.2.2", "pre-commit==2.21.0", "tensorflow==2.11.0"]
requirements.txt
@@ -0,0 +1 @@
tensorflow==2.11.0
src/rllib_maddpg/maddpg/__init__.py
@@ -0,0 +1,7 @@
from rllib_maddpg.maddpg.maddpg import MADDPG, MADDPGConfig

from ray.tune.registry import register_trainable

__all__ = ["MADDPGConfig", "MADDPG"]

register_trainable("rllib-contrib-maddpg", MADDPG)
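
The `register_trainable` call above means the algorithm can also be looked up by its string name once the package is imported. A minimal sketch (name as registered above):

```python
# Sketch: resolve the trainable registered above by its string name.
from ray.tune.registry import get_trainable_cls

import rllib_maddpg.maddpg  # importing the package runs register_trainable()

maddpg_cls = get_trainable_cls("rllib-contrib-maddpg")
```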