-
Notifications
You must be signed in to change notification settings - Fork 0
/
02_train_vs_eval.py
75 lines (55 loc) · 2.4 KB
/
02_train_vs_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import matplotlib.pyplot as plt
from cmx import doc
from matplotlib import ticker
from ml_logger import ML_Logger
doc @ """
# Comparing Two Learning Curves Side-by-side
Here we compare the training performance versus the performance
on the evaluation domain.
We show the training performance in gray, to accentuate the
evaluation curve.
"""

with doc @ """Initialize the loader""":
    # Metric files are looked up under `prefix`, anchored at the current
    # working directory (presumably root/prefix -- confirm against ML_Logger).
    loader = ML_Logger(prefix="data/walker-walk/curl", root=os.getcwd())
with doc @ """Check all the files""":
    # Collect every file matching `**/metrics.pkl` (searched recursively from
    # wd=".") and echo the resulting list into the generated document.
    files = loader.glob(recursive=True, wd=".", query="**/metrics.pkl")
    doc.print(files)
with doc @ """A Single Time Series""":
    def group(xKey="step", yKey="train/episode_reward/mean", color=None, bin=10, label=None, dropna=False,
              path="**/metrics.pkl"):
        """Plot the mean of ``yKey`` against ``xKey`` with a shaded 16%-84% band.

        Metrics are read through the module-level ``loader`` from every file
        matching ``path`` and drawn on the current matplotlib axes.

        Args:
            xKey: metric used for the x axis; its ``@mean`` aggregate is plotted.
            yKey: metric whose ``@mean``, ``@84%`` and ``@16%`` aggregates are read.
            color: matplotlib color shared by the mean line and the shaded band.
            bin: forwarded to ``loader.read_metrics`` as ``bin_size`` (``None`` is
                passed through unchanged). Shadows the builtin ``bin`` but is kept
                because callers pass it by keyword.
            label: legend label for the mean curve.
            dropna: forwarded to ``loader.read_metrics``; this script turns it on
                to get rid of the NaN gaps in the eval curve.
            path: glob of metric files to aggregate. Defaults to the glob that
                was previously hard-coded, so existing calls are unaffected.

        Returns:
            The mean series (``avg``) returned by ``loader.read_metrics``.
        """
        avg, top, bottom, step = loader.read_metrics(
            f"{yKey}@mean", f"{yKey}@84%", f"{yKey}@16%",
            x_key=f"{xKey}@mean", path=path, bin_size=bin, dropna=dropna,
        )
        plt.plot(step, avg, color=color, label=label)
        # Band between the @16% and @84% aggregates (presumably percentiles,
        # i.e. roughly mean +/- 1 std for normally distributed returns).
        plt.fill_between(step, bottom, top, alpha=0.15, color=color)
        return avg
with doc @ "Step 2: Plot", doc.table().figure_row() as r:
    colors = ['#49b8ff', '#444444', '#ff7575', '#66c56c', '#f4b247']
    # Eval is drawn first in light blue and un-binned. Its mean is kept in
    # `avg` for the NaN discussion below. Train follows in dark gray.
    avg = group(yKey="episode_reward/mean", bin=None, color=colors[0], label="Eval")
    group(yKey="train/episode_reward/mean", color=colors[1], label="Train")

    ax = plt.gca()
    # X ticks are shown in whole thousands ("20k", truncated); the origin stays "0".
    ax.xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, _: f"{int(x / 1000)}k" if x else "0"))
    ax.legend()
    ax.set_title("Walker-walk")
    ax.set_xlabel("Steps")
    ax.set_ylabel("Return")

    r.savefig("figures/train_vs_eval.png", title="Train VS Eval", dpi=300, zoom="20%")
    plt.close()
doc @ """
## Where do the empty cuts come from?
These cuts are places where the `avg` is `NaN`. You can just filter this out
in the `group` function.
"""
with doc:
    # `avg` holds the eval mean returned by the first `group` call of the plot cell.
    doc.print(avg)
doc @ """
## How to fix this?
You can turn on the `dropna` flag, which is OFF by default.
"""
# Re-plot the eval curve with dropna=True so the NaN entries no longer cut the line.
with doc, doc.table().figure_row() as r:
    avg = group(yKey="episode_reward/mean", bin=None, color=colors[0], label="Eval", dropna=True)
    plt.gca().xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, _: f"{int(x / 1000)}k" if x else "0"))
    plt.legend()
    plt.title("Walker-walk")
    plt.xlabel("Steps")
    plt.ylabel("Return")
    r.savefig("figures/train_vs_eval_dropna.png", title="Train VS Eval", dpi=300, zoom="20%")
    # Close the figure like the first plot cell does, so pyplot does not keep it open.
    plt.close()
# Presumably writes the accumulated document out -- confirm against cmx.
doc.flush()