custom_metrics_and_callbacks.py
"""Example of using RLlib's debug callbacks.
Here we use callbacks to track the average CartPole pole angle magnitude as a
custom metric.
"""
from typing import Dict, Tuple
import argparse
import numpy as np
import os
import ray
from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import Episode, RolloutWorker
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
parser = argparse.ArgumentParser()
parser.add_argument(
"--framework",
choices=["tf", "tf2", "tfe", "torch"],
default="tf",
help="The DL framework specifier.",
)
parser.add_argument("--stop-iters", type=int, default=2000)
class MyCallbacks(DefaultCallbacks):
def on_episode_start(
self,
*,
worker: RolloutWorker,
base_env: BaseEnv,
policies: Dict[str, Policy],
episode: Episode,
env_index: int,
**kwargs
):
# Make sure this episode has just been started (only initial obs
# logged so far).
assert episode.length == 0, (
"ERROR: `on_episode_start()` callback should be called right "
"after env reset!"
)
print("episode {} (env-idx={}) started.".format(episode.episode_id, env_index))
episode.user_data["pole_angles"] = []
episode.hist_data["pole_angles"] = []
def on_episode_step(
self,
*,
worker: RolloutWorker,
base_env: BaseEnv,
policies: Dict[str, Policy],
episode: Episode,
env_index: int,
**kwargs
):
# Make sure this episode is ongoing.
assert episode.length > 0, (
"ERROR: `on_episode_step()` callback should not be called right "
"after env reset!"
)
pole_angle = abs(episode.last_observation_for()[2])
raw_angle = abs(episode.last_raw_obs_for()[2])
assert pole_angle == raw_angle
episode.user_data["pole_angles"].append(pole_angle)
def on_episode_end(
self,
*,
worker: RolloutWorker,
base_env: BaseEnv,
policies: Dict[str, Policy],
episode: Episode,
env_index: int,
**kwargs
):
# Check if there are multiple episodes in a batch, i.e.
# "batch_mode": "truncate_episodes".
if worker.policy_config["batch_mode"] == "truncate_episodes":
# Make sure this episode is really done.
assert episode.batch_builder.policy_collectors["default_policy"].batches[
-1
]["dones"][-1], (
"ERROR: `on_episode_end()` should only be called "
"after episode is done!"
)
pole_angle = np.mean(episode.user_data["pole_angles"])
print(
"episode {} (env-idx={}) ended with length {} and pole "
"angles {}".format(
episode.episode_id, env_index, episode.length, pole_angle
)
)
episode.custom_metrics["pole_angle"] = pole_angle
episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]
def on_sample_end(self, *, worker: RolloutWorker, samples: SampleBatch, **kwargs):
print("returned sample batch of size {}".format(samples.count))
def on_train_result(self, *, trainer, result: dict, **kwargs):
print(
"trainer.train() result: {} -> {} episodes".format(
trainer, result["episodes_this_iter"]
)
)
# you can mutate the result dict to add new fields to return
result["callback_ok"] = True
def on_learn_on_batch(
self, *, policy: Policy, train_batch: SampleBatch, result: dict, **kwargs
) -> None:
result["sum_actions_in_train_batch"] = np.sum(train_batch["actions"])
print(
"policy.learn_on_batch() result: {} -> sum actions: {}".format(
policy, result["sum_actions_in_train_batch"]
)
)
def on_postprocess_trajectory(
self,
*,
worker: RolloutWorker,
episode: Episode,
agent_id: str,
policy_id: str,
policies: Dict[str, Policy],
postprocessed_batch: SampleBatch,
original_batches: Dict[str, Tuple[Policy, SampleBatch]],
**kwargs
):
print("postprocessed {} steps".format(postprocessed_batch.count))
if "num_batches" not in episode.custom_metrics:
episode.custom_metrics["num_batches"] = 0
episode.custom_metrics["num_batches"] += 1
if __name__ == "__main__":
args = parser.parse_args()
ray.init()
trials = tune.run(
"PG",
stop={
"training_iteration": args.stop_iters,
},
config={
"env": "CartPole-v0",
"num_envs_per_worker": 2,
"callbacks": MyCallbacks,
"framework": args.framework,
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
},
).trials
# Verify episode-related custom metrics are there.
custom_metrics = trials[0].last_result["custom_metrics"]
print(custom_metrics)
assert "pole_angle_mean" in custom_metrics
assert "pole_angle_min" in custom_metrics
assert "pole_angle_max" in custom_metrics
assert "num_batches_mean" in custom_metrics
assert "callback_ok" in trials[0].last_result
# Verify `on_learn_on_batch` custom metrics are there (per policy).
if args.framework == "torch":
info_custom_metrics = custom_metrics["default_policy"]
print(info_custom_metrics)
assert "sum_actions_in_train_batch" in info_custom_metrics