Skip to content

Commit b04395c

Browse files
author
Ervin T
authored
Remove TrainerMetrics and add CSVWriter using new StatsWriter API (#3108)
1 parent 47625be commit b04395c

File tree

10 files changed

+145
-229
lines changed

10 files changed

+145
-229
lines changed

ml-agents/mlagents/trainers/learn.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from mlagents.trainers.exception import TrainerError
1818
from mlagents.trainers.meta_curriculum import MetaCurriculum
1919
from mlagents.trainers.trainer_util import load_config, TrainerFactory
20-
from mlagents.trainers.stats import TensorboardWriter, StatsReporter
20+
from mlagents.trainers.stats import TensorboardWriter, CSVWriter, StatsReporter
2121
from mlagents_envs.environment import UnityEnvironment
2222
from mlagents.trainers.sampler_class import SamplerManager
2323
from mlagents.trainers.exception import SamplerException
@@ -250,9 +250,15 @@ def run_training(
250250
trainer_config = load_config(trainer_config_path)
251251
port = options.base_port + (sub_id * options.num_envs)
252252

253-
# Configure Tensorboard Writers and StatsReporter
253+
# Configure CSV, Tensorboard Writers and StatsReporter
254+
# We assume reward and episode length are needed in the CSV.
255+
csv_writer = CSVWriter(
256+
summaries_dir,
257+
required_fields=["Environment/Cumulative Reward", "Environment/Episode Length"],
258+
)
254259
tb_writer = TensorboardWriter(summaries_dir)
255260
StatsReporter.add_writer(tb_writer)
261+
StatsReporter.add_writer(csv_writer)
256262

257263
if options.env_path is None:
258264
port = 5004 # This is the in Editor Training Port

ml-agents/mlagents/trainers/ppo/trainer.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,6 @@ def update_policy(self):
177177
The reward signal generators must be updated in this method at their own pace.
178178
"""
179179
buffer_length = self.update_buffer.num_experiences
180-
self.trainer_metrics.start_policy_update_timer(
181-
number_experiences=buffer_length,
182-
mean_return=float(np.mean(self.cumulative_returns_since_policy_update)),
183-
)
184180
self.cumulative_returns_since_policy_update.clear()
185181

186182
# Make sure batch_size is a multiple of sequence length. During training, we
@@ -221,7 +217,6 @@ def update_policy(self):
221217
for stat, val in update_stats.items():
222218
self.stats_reporter.add_stat(stat, val)
223219
self.clear_update_buffer()
224-
self.trainer_metrics.end_policy_update()
225220

226221

227222
def discount_rewards(r, gamma=0.99, value_next=0.0):

ml-agents/mlagents/trainers/sac/trainer.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -207,13 +207,8 @@ def update_policy(self) -> None:
207207
If reward_signal_train_interval is met, update the reward signals from the buffer.
208208
"""
209209
if self.step % self.train_interval == 0:
210-
self.trainer_metrics.start_policy_update_timer(
211-
number_experiences=self.update_buffer.num_experiences,
212-
mean_return=float(np.mean(self.cumulative_returns_since_policy_update)),
213-
)
214210
self.update_sac_policy()
215211
self.update_reward_signals()
216-
self.trainer_metrics.end_policy_update()
217212

218213
def update_sac_policy(self) -> None:
219214
"""

ml-agents/mlagents/trainers/stats.py

Lines changed: 81 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,28 @@
22
from typing import List, Dict, NamedTuple
33
import numpy as np
44
import abc
5+
import csv
56
import os
67

78
from mlagents.tf_utils import tf
89

910

11+
class StatsSummary(NamedTuple):
12+
mean: float
13+
std: float
14+
num: int
15+
16+
1017
class StatsWriter(abc.ABC):
1118
"""
1219
A StatsWriter abstract class. A StatsWriter takes in a category, key, scalar value, and step
1320
and writes it out by some method.
1421
"""
1522

1623
@abc.abstractmethod
17-
def write_stats(self, category: str, key: str, value: float, step: int) -> None:
24+
def write_stats(
25+
self, category: str, values: Dict[str, StatsSummary], step: int
26+
) -> None:
1827
pass
1928

2029
@abc.abstractmethod
@@ -24,15 +33,23 @@ def write_text(self, category: str, text: str, step: int) -> None:
2433

2534
class TensorboardWriter(StatsWriter):
2635
def __init__(self, base_dir: str):
36+
"""
37+
A StatsWriter that writes to a Tensorboard summary.
38+
:param base_dir: The directory within which to place all the summaries. Tensorboard files will be written to a
39+
{base_dir}/{category} directory.
40+
"""
2741
self.summary_writers: Dict[str, tf.summary.FileWriter] = {}
2842
self.base_dir: str = base_dir
2943

30-
def write_stats(self, category: str, key: str, value: float, step: int) -> None:
44+
def write_stats(
45+
self, category: str, values: Dict[str, StatsSummary], step: int
46+
) -> None:
3147
self._maybe_create_summary_writer(category)
32-
summary = tf.Summary()
33-
summary.value.add(tag="{}".format(key), simple_value=value)
34-
self.summary_writers[category].add_summary(summary, step)
35-
self.summary_writers[category].flush()
48+
for key, value in values.items():
49+
summary = tf.Summary()
50+
summary.value.add(tag="{}".format(key), simple_value=value.mean)
51+
self.summary_writers[category].add_summary(summary, step)
52+
self.summary_writers[category].flush()
3653

3754
def _maybe_create_summary_writer(self, category: str) -> None:
3855
if category not in self.summary_writers:
@@ -47,10 +64,59 @@ def write_text(self, category: str, text: str, step: int) -> None:
4764
self.summary_writers[category].add_summary(text, step)
4865

4966

50-
class StatsSummary(NamedTuple):
51-
mean: float
52-
std: float
53-
num: int
67+
class CSVWriter(StatsWriter):
    """
    A StatsWriter that appends the mean of each reported statistic to a CSV file,
    one file per category.
    """

    def __init__(self, base_dir: str, required_fields: List[str] = None):
        """
        A StatsWriter that writes to a CSV file.
        :param base_dir: The directory within which to place the CSV file, which will be {base_dir}/{category}.csv.
        :param required_fields: If provided, the CSV writer won't write until these fields have statistics to write for
        them.
        """
        # We need to keep track of the fields in the CSV, as all rows need the same fields.
        self.csv_fields: Dict[str, List[str]] = {}
        # Copy the list so later mutation by the caller cannot change what we require.
        self.required_fields: List[str] = (
            list(required_fields) if required_fields else []
        )
        self.base_dir: str = base_dir

    def write_stats(
        self, category: str, values: Dict[str, StatsSummary], step: int
    ) -> None:
        """
        Append one row (the step followed by the mean of each tracked field) to the
        category's CSV file. Nothing is written until a call supplies all required fields.
        :param category: The category whose CSV file receives the row.
        :param values: Mapping from stat name to its summary for this step.
        :param step: Training step recorded in the first column.
        """
        if not self._maybe_create_csv_file(category, list(values.keys())):
            return
        row = [str(step)]
        # Only record the stats that showed up in the first valid row
        for key in self.csv_fields[category]:
            summary = values.get(key)
            # A field from the header that is missing in this step is recorded as "None".
            row.append(str(summary.mean) if summary is not None else "None")
        # newline="" lets the csv module control line endings; without it, platforms that
        # translate "\n" end up with a blank line after every row.
        with open(self._get_filepath(category), "a", newline="") as file:
            csv.writer(file).writerow(row)

    def _maybe_create_csv_file(self, category: str, keys: List[str]) -> bool:
        """
        If no CSV file exists and the keys have the required values,
        make the CSV file and write the title row.
        Returns True if there is now (or already is) a valid CSV file.
        """
        if category in self.csv_fields:
            return True
        os.makedirs(self.base_dir, exist_ok=True)
        # Only store if the row contains the required fields
        if not all(item in keys for item in self.required_fields):
            return False
        # Freeze the column order; every later row is written against this header.
        self.csv_fields[category] = list(keys)
        with open(self._get_filepath(category), "w", newline="") as file:
            csv.writer(file).writerow(["Steps"] + self.csv_fields[category])
        return True

    def _get_filepath(self, category: str) -> str:
        """
        Return the path of the CSV file for a category: {base_dir}/{category}.csv.
        """
        return os.path.join(self.base_dir, category + ".csv")

    def write_text(self, category: str, text: str, step: int) -> None:
        """
        Text entries have no CSV representation, so they are ignored.
        """
        pass
54120

55121

56122
class StatsReporter:
@@ -87,11 +153,13 @@ def write_stats(self, step: int) -> None:
87153
:param category: The category which to write out the stats.
88154
:param step: Training step which to write these stats as.
89155
"""
156+
values: Dict[str, StatsSummary] = {}
90157
for key in StatsReporter.stats_dict[self.category]:
91158
if len(StatsReporter.stats_dict[self.category][key]) > 0:
92-
stat_mean = float(np.mean(StatsReporter.stats_dict[self.category][key]))
93-
for writer in StatsReporter.writers:
94-
writer.write_stats(self.category, key, stat_mean, step)
159+
stat_summary = self.get_stats_summaries(key)
160+
values[key] = stat_summary
161+
for writer in StatsReporter.writers:
162+
writer.write_stats(self.category, values, step)
95163
del StatsReporter.stats_dict[self.category]
96164

97165
def write_text(self, text: str, step: int) -> None:

ml-agents/mlagents/trainers/tests/test_stats.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,14 @@
22
import os
33
import pytest
44
import tempfile
5+
import csv
56

6-
from mlagents.trainers.stats import StatsReporter, TensorboardWriter
7+
from mlagents.trainers.stats import (
8+
StatsReporter,
9+
TensorboardWriter,
10+
CSVWriter,
11+
StatsSummary,
12+
)
713

814

915
def test_stat_reporter_add_summary_write():
@@ -35,8 +41,12 @@ def test_stat_reporter_add_summary_write():
3541
# Test write_stats
3642
step = 10
3743
statsreporter1.write_stats(step)
38-
mock_writer1.write_stats.assert_called_once_with("category1", "key1", 4.5, step)
39-
mock_writer2.write_stats.assert_called_once_with("category1", "key1", 4.5, step)
44+
mock_writer1.write_stats.assert_called_once_with(
45+
"category1", {"key1": statssummary1}, step
46+
)
47+
mock_writer2.write_stats.assert_called_once_with(
48+
"category1", {"key1": statssummary1}, step
49+
)
4050

4151

4252
def test_stat_reporter_text():
@@ -61,7 +71,8 @@ def test_tensorboard_writer(mock_filewriter, mock_summary):
6171
category = "category1"
6272
with tempfile.TemporaryDirectory(prefix="unittest-") as base_dir:
6373
tb_writer = TensorboardWriter(base_dir)
64-
tb_writer.write_stats("category1", "key1", 1.0, 10)
74+
statssummary1 = StatsSummary(mean=1.0, std=1.0, num=1)
75+
tb_writer.write_stats("category1", {"key1": statssummary1}, 10)
6576

6677
# Test that the filewriter has been created and the directory has been created.
6778
filewriter_dir = "{basedir}/{category}".format(
@@ -78,3 +89,43 @@ def test_tensorboard_writer(mock_filewriter, mock_summary):
7889
mock_summary.return_value, 10
7990
)
8091
mock_filewriter.return_value.flush.assert_called_once()
92+
93+
94+
def test_csv_writer():
    """
    CSVWriter must not create a file until every required field is present, and must
    then write a title row followed by one row per write_stats call.
    """
    category = "category1"
    with tempfile.TemporaryDirectory(prefix="unittest-") as base_dir:
        csv_writer = CSVWriter(base_dir, required_fields=["key1", "key2"])
        statssummary1 = StatsSummary(mean=1.0, std=1.0, num=1)
        csv_path = os.path.join(base_dir, category + ".csv")

        # Only one of the two required keys is present, so no file may be created.
        csv_writer.write_stats(category, {"key1": statssummary1}, 10)
        assert not os.path.exists(csv_path)

        # Both required keys are present: the file is created and each call adds a row.
        full_stats = {"key1": statssummary1, "key2": statssummary1}
        csv_writer.write_stats(category, full_stats, 10)
        csv_writer.write_stats(category, full_stats, 20)
        assert os.path.exists(csv_path)

        with open(csv_path, newline="") as csv_file:
            rows = list(csv.reader(csv_file, delimiter=","))

    # Check the contents, not just the shape: title row, then step and means per call.
    assert rows == [
        ["Steps", "key1", "key2"],
        ["10", "1.0", "1.0"],
        ["20", "1.0", "1.0"],
    ]

ml-agents/mlagents/trainers/tests/test_trainer_metrics.py

Lines changed: 0 additions & 46 deletions
This file was deleted.

ml-agents/mlagents/trainers/tests/test_trainer_util.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import mlagents.trainers.trainer_util as trainer_util
77
from mlagents.trainers.trainer_util import load_config, _load_config
8-
from mlagents.trainers.trainer_metrics import TrainerMetrics
98
from mlagents.trainers.ppo.trainer import PPOTrainer
109
from mlagents.trainers.exception import TrainerConfigError
1110
from mlagents.trainers.brain import BrainParameters
@@ -119,7 +118,6 @@ def mock_constructor(
119118
run_id,
120119
multi_gpu,
121120
):
122-
self.trainer_metrics = TrainerMetrics("", "")
123121
assert brain == brain_params_mock
124122
assert trainer_parameters == expected_config
125123
assert reward_buff_cap == expected_reward_buff_cap
@@ -178,7 +176,6 @@ def mock_constructor(
178176
run_id,
179177
multi_gpu,
180178
):
181-
self.trainer_metrics = TrainerMetrics("", "")
182179
assert brain == brain_params_mock
183180
assert trainer_parameters == expected_config
184181
assert reward_buff_cap == expected_reward_buff_cap

ml-agents/mlagents/trainers/trainer.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from mlagents_envs.exception import UnityException
1010
from mlagents_envs.timers import set_gauge
11-
from mlagents.trainers.trainer_metrics import TrainerMetrics
1211
from mlagents.trainers.tf_policy import TFPolicy
1312
from mlagents.trainers.stats import StatsReporter
1413
from mlagents.trainers.trajectory import Trajectory
@@ -52,9 +51,6 @@ def __init__(
5251
self.stats_reporter = StatsReporter(self.summary_path)
5352
self.cumulative_returns_since_policy_update: List[float] = []
5453
self.is_training = training
55-
self.trainer_metrics = TrainerMetrics(
56-
path=self.summary_path + ".csv", brain_name=self.brain_name
57-
)
5854
self._reward_buffer: Deque[float] = deque(maxlen=reward_buff_cap)
5955
self.policy: TFPolicy = None # type: ignore # this will always get set
6056
self.step: int = 0
@@ -170,13 +166,6 @@ def export_model(self) -> None:
170166
"""
171167
self.policy.export_model()
172168

173-
def write_training_metrics(self) -> None:
174-
"""
175-
Write training metrics to a CSV file
176-
:return:
177-
"""
178-
self.trainer_metrics.write_training_metrics()
179-
180169
def write_summary(self, global_step: int, delta_train_start: float) -> None:
181170
"""
182171
Saves training statistics to Tensorboard.

0 commit comments

Comments
 (0)