-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Multiagent simplerl #5066
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Multiagent simplerl #5066
Changes from all commits
59825a1
2accf19
a6ffbd8
3e26bc3
5126254
fea6d53
7447d88
2dd982d
278ecf2
38dc560
82e5e99
3c88d65
51aba67
bfd1428
72f0370
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
|
||
from mlagents.trainers.tests.simple_test_envs import ( | ||
SimpleEnvironment, | ||
MultiAgentEnvironment, | ||
MemoryEnvironment, | ||
RecordEnvironment, | ||
) | ||
|
@@ -27,7 +28,11 @@ | |
ActionSpecProto, | ||
) | ||
|
||
from mlagents.trainers.tests.dummy_config import ppo_dummy_config, sac_dummy_config | ||
from mlagents.trainers.tests.dummy_config import ( | ||
ppo_dummy_config, | ||
sac_dummy_config, | ||
coma_dummy_config, | ||
) | ||
from mlagents.trainers.tests.check_env_trains import ( | ||
check_environment_trains, | ||
default_reward_processor, | ||
|
@@ -37,11 +42,83 @@ | |
|
||
PPO_TORCH_CONFIG = ppo_dummy_config() | ||
SAC_TORCH_CONFIG = sac_dummy_config() | ||
COMA_TORCH_CONFIG = coma_dummy_config() | ||
|
||
# Tests in this file won't be run on GPU machines | ||
pytestmark = pytest.mark.check_environment_trains | ||
|
||
|
||
@pytest.mark.parametrize("action_sizes", [(0, 1), (1, 0)]) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can this be tested for combinations of rank 1, 2 and 3 observations and with an LSTM config? (To make sure it does not crash at least.) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added some tests for LSTM, variable-length obs, and visual. |
||
def test_simple_coma(action_sizes):
    """Run the COMA trainer on a two-agent environment.

    Builds a ``MultiAgentEnvironment`` with ``num_agents=2`` and the given
    ``action_sizes``, then hands it to ``check_environment_trains`` together
    with an unmodified copy of ``COMA_TORCH_CONFIG``.
    """
    two_agent_env = MultiAgentEnvironment(
        [BRAIN_NAME], action_sizes=action_sizes, num_agents=2
    )
    # attr.evolve with no overrides yields a fresh copy of the shared config.
    trainer_settings = attr.evolve(COMA_TORCH_CONFIG)
    check_environment_trains(two_agent_env, {BRAIN_NAME: trainer_settings})
|
||
|
||
@pytest.mark.parametrize("num_visual", [1, 2])
def test_visual_coma(num_visual):
    """Run the COMA trainer on a two-agent environment with visual observations.

    The environment is created with ``num_visual`` visual observations and
    ``action_sizes=(0, 1)``; the trainer config is ``COMA_TORCH_CONFIG`` with
    the learning rate overridden to ``3.0e-4``.
    """
    visual_env = MultiAgentEnvironment(
        [BRAIN_NAME], action_sizes=(0, 1), num_agents=2, num_visual=num_visual
    )
    # Only the learning rate differs from the shared COMA dummy config.
    trainer_settings = attr.evolve(
        COMA_TORCH_CONFIG,
        hyperparameters=attr.evolve(
            COMA_TORCH_CONFIG.hyperparameters, learning_rate=3.0e-4
        ),
    )
    check_environment_trains(visual_env, {BRAIN_NAME: trainer_settings})
|
||
|
||
@pytest.mark.parametrize("num_var_len", [1, 2])
@pytest.mark.parametrize("num_vector", [0, 1])
@pytest.mark.parametrize("num_vis", [0, 1])
def test_var_len_obs_coma(num_vis, num_vector, num_var_len):
    """Run the COMA trainer with mixed observation types.

    Each parametrized case builds a two-agent ``MultiAgentEnvironment`` with
    the requested number of visual, vector and variable-length observations
    and ``step_size=0.2``. The trainer config is ``COMA_TORCH_CONFIG`` with
    the learning rate overridden to ``3.0e-4``.
    """
    env_options = dict(
        action_sizes=(0, 1),
        num_visual=num_vis,
        num_vector=num_vector,
        num_var_len=num_var_len,
        step_size=0.2,
        num_agents=2,
    )
    mixed_obs_env = MultiAgentEnvironment([BRAIN_NAME], **env_options)

    # Only the learning rate differs from the shared COMA dummy config.
    tuned_hyperparameters = attr.evolve(
        COMA_TORCH_CONFIG.hyperparameters, learning_rate=3.0e-4
    )
    trainer_settings = attr.evolve(
        COMA_TORCH_CONFIG, hyperparameters=tuned_hyperparameters
    )
    check_environment_trains(mixed_obs_env, {BRAIN_NAME: trainer_settings})
|
||
|
||
@pytest.mark.parametrize("action_sizes", [(0, 1), (1, 0)])
@pytest.mark.parametrize("is_multiagent", [True, False])
def test_recurrent_coma(action_sizes, is_multiagent):
    """Run the COMA trainer with a memory (LSTM) network configuration.

    With ``is_multiagent=True`` the run is short (500 steps) on a two-agent
    environment and no success threshold is passed. Otherwise it runs for
    6000 steps on ``MemoryEnvironment`` with a success threshold of 0.9.
    """
    if is_multiagent:
        # This is not a recurrent environment, just check that LSTM doesn't crash.
        env = MultiAgentEnvironment(
            [BRAIN_NAME], action_sizes=action_sizes, num_agents=2
        )
        step_budget = 500
        required_reward = None
    else:
        # Actually test LSTM here.
        env = MemoryEnvironment([BRAIN_NAME], action_sizes=action_sizes)
        step_budget = 6000
        required_reward = 0.9

    memory_settings = NetworkSettings.MemorySettings(memory_size=16)
    recurrent_network = attr.evolve(
        COMA_TORCH_CONFIG.network_settings, memory=memory_settings
    )
    tuned_hyperparameters = attr.evolve(
        COMA_TORCH_CONFIG.hyperparameters,
        learning_rate=1.0e-3,
        batch_size=64,
        buffer_size=128,
    )
    trainer_settings = attr.evolve(
        COMA_TORCH_CONFIG,
        hyperparameters=tuned_hyperparameters,
        network_settings=recurrent_network,
        max_steps=step_budget,
    )
    check_environment_trains(
        env, {BRAIN_NAME: trainer_settings}, success_threshold=required_reward
    )
|
||
|
||
@pytest.mark.parametrize("action_sizes", [(0, 1), (1, 0)]) | ||
def test_simple_ppo(action_sizes): | ||
env = SimpleEnvironment([BRAIN_NAME], action_sizes=action_sizes) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand anything this class does. Please add comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, it is a pretty horrible thing