forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathiterated_prisoners_dilemma_env.py
86 lines (73 loc) · 2.36 KB
/
iterated_prisoners_dilemma_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
##########
# Contribution by the Center on Long-Term Risk:
# https://github.com/longtermrisk/marltoolbox
##########
import argparse
import os
import ray
from ray import tune
from ray.rllib.algorithms.pg import PG
from ray.rllib.examples.env.matrix_sequential_social_dilemma import (
IteratedPrisonersDilemma,
)
# Command-line interface: only the DL framework and the training-iteration
# budget are configurable from the CLI; everything else (env, policies,
# seeds) is fixed in get_rllib_config below.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.",
)
parser.add_argument("--stop-iters", type=int, default=200)
def main(debug, stop_iters=200, framework="tf"):
    """Train PG on the iterated prisoner's dilemma and return the results.

    Args:
        debug: If True, run Ray in local mode; get_rllib_config also caps
            training at 2 iterations in this case.
        stop_iters: Number of training iterations when not in debug mode.
        framework: DL framework specifier ("tf", "tf2", "tfe", or "torch").

    Returns:
        The ``tune.run`` analysis object for the finished experiment.
    """
    # A single training replicate in both modes (the original
    # ``1 if debug else 1`` conditional was dead — both branches were 1).
    train_n_replicates = 1
    seeds = list(range(train_n_replicates))
    ray.init(num_cpus=os.cpu_count(), num_gpus=0, local_mode=debug)
    rllib_config, stop_config = get_rllib_config(seeds, debug, stop_iters, framework)
    tune_analysis = tune.run(
        PG,
        config=rllib_config,
        stop=stop_config,
        checkpoint_freq=0,
        checkpoint_at_end=True,
        name="PG_IPD",
    )
    ray.shutdown()
    return tune_analysis
def get_rllib_config(seeds, debug=False, stop_iters=200, framework="tf"):
    """Build the RLlib config dict and stop criteria for PG on the IPD env.

    Args:
        seeds: Seeds to grid-search over (one trial per seed).
        debug: If True, stop after 2 training iterations.
        stop_iters: Training iterations to run when not debugging.
        framework: DL framework specifier passed through to RLlib.

    Returns:
        A ``(rllib_config, stop_config)`` tuple for ``tune.run``.
    """
    stop_config = {"training_iteration": 2 if debug else stop_iters}

    player_ids = ["player_row", "player_col"]
    env_config = {
        "players_ids": player_ids,
        "max_steps": 20,
        "get_additional_info": True,
    }

    # One independent (but identically specified) policy per player; the
    # comprehension evaluates the spec tuple afresh for each player so the
    # per-policy config dicts are not shared.
    policies = {
        player_id: (
            None,
            IteratedPrisonersDilemma.OBSERVATION_SPACE,
            IteratedPrisonersDilemma.ACTION_SPACE,
            {},
        )
        for player_id in player_ids
    }

    rllib_config = {
        "env": IteratedPrisonersDilemma,
        "env_config": env_config,
        "multiagent": {
            "policies": policies,
            # Agent ids equal policy ids, so the mapping is the identity.
            "policy_mapping_fn": lambda agent_id, **kwargs: agent_id,
        },
        "seed": tune.grid_search(seeds),
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": framework,
    }
    return rllib_config, stop_config
if __name__ == "__main__":
debug_mode = True
args = parser.parse_args()
main(debug_mode, args.stop_iters, args.framework)