Skip to content
This repository was archived by the owner on Feb 24, 2022. It is now read-only.

Commit 60dbb78

Browse files
committed
Add live plotting of key metrics (experimental, unstable)
1 parent f45e68b commit 60dbb78

File tree

2 files changed

+145
-2
lines changed

2 files changed

+145
-2
lines changed
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import time
2+
from collections import OrderedDict
3+
4+
import numpy as np
5+
import pandas as pd
6+
import matplotlib.pyplot as plt
7+
8+
class Metric(object):
9+
"""Named sequence of x and y values, with optional plotting helpers."""
10+
11+
def __init__(self, name):
12+
self.name = name
13+
self.reset()
14+
15+
def collect(self, x, y):
16+
self.xdata.append(x)
17+
self.ydata.append(y)
18+
19+
def plot(self, ax):
20+
self.plot_obj, = ax.plot(self.xdata, self.ydata, 'o-', label=self.name)
21+
22+
def refresh(self):
23+
self.plot_obj.set_data(self.xdata, self.ydata)
24+
25+
def reset(self):
26+
self.xdata = []
27+
self.ydata = []
28+
29+
30+
class Reporter(object):
31+
"""Collect metrics, analyze and report summary statistics."""
32+
33+
def __init__(self, metrics=[], live_plot=False):
34+
self.metrics = OrderedDict()
35+
self.live_plot = live_plot
36+
37+
for name in metrics:
38+
self.metrics[name] = Metric(name)
39+
40+
if self.live_plot:
41+
if not plt.isinteractive():
42+
plt.ion()
43+
self.plot()
44+
45+
print "Reporter.__init__(): Initialized with metrics: {}".format(metrics) # [debug]
46+
47+
def collect(self, name, x, y):
48+
if not name in self.metrics:
49+
self.metrics[name] = Metric(name)
50+
if self.live_plot:
51+
self.metrics[name].plot(self.ax)
52+
self.ax.legend() # add new metric to legend
53+
print "Reporter.collect(): New metric added: {}".format(name) # [debug]
54+
self.metrics[name].collect(x, y)
55+
if self.live_plot:
56+
self.metrics[name].refresh()
57+
58+
def plot(self):
59+
if not hasattr(self, 'fig') or not hasattr(self, 'ax'):
60+
self.fig, self.ax = plt.subplots()
61+
for name in self.metrics:
62+
self.metrics[name].plot(self.ax)
63+
#self.ax.set_autoscalex_on(True)
64+
#self.ax.set_autoscaley_on(True)
65+
self.ax.grid()
66+
self.ax.legend()
67+
else:
68+
for name in self.metrics:
69+
self.metrics[name].refresh()
70+
self.refresh_plot()
71+
72+
def refresh_plot(self):
73+
self.ax.relim()
74+
self.ax.autoscale_view()
75+
self.fig.canvas.draw()
76+
self.fig.canvas.flush_events()
77+
plt.draw()
78+
79+
def show_plot(self):
80+
if plt.isinteractive():
81+
plt.ioff()
82+
self.plot()
83+
plt.show()
84+
85+
def summary(self):
86+
return [pd.Series(metric.ydata, index=metric.xdata, name=name) for name, metric in self.metrics.iteritems()]
87+
88+
def reset(self):
89+
for name in self.metrics:
90+
self.metrics[name].reset()
91+
if self.live_plot:
92+
self.metrics[name].refresh()
93+
94+
95+
def test_reporter():
96+
plt.ion()
97+
rep = Reporter(metrics=['reward', 'flubber'], live_plot=True)
98+
for i in xrange(100):
99+
rep.collect('reward', i, np.random.random())
100+
if i % 10 == 1:
101+
rep.collect('flubber', i, np.random.random() * 2 + 1)
102+
rep.refresh_plot()
103+
time.sleep(0.01)
104+
rep.plot()
105+
summary = rep.summary()
106+
print "Summary ({} metrics):-".format(len(summary))
107+
for metric in summary:
108+
print "Name: {}, samples: {}, type: {}".format(metric.name, len(metric), metric.dtype)
109+
print "Mean: {}, s.d.: {}".format(metric.mean(), metric.std())
110+
#print metric[:5] # [debug]
111+
plt.ioff()
112+
plt.show()
113+
114+
115+
if __name__ == '__main__':
116+
test_reporter()

projects/smartcab/smartcab/simulator.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
import random
44
import importlib
55

6+
import numpy as np
7+
8+
from analysis import Reporter
9+
610
class Simulator(object):
711
"""Simulates agents in a dynamic smartcab environment.
812
@@ -21,7 +25,7 @@ class Simulator(object):
2125
'orange' : (255, 128, 0)
2226
}
2327

24-
def __init__(self, env, size=None, update_delay=1.0, display=True):
28+
def __init__(self, env, size=None, update_delay=1.0, display=True, live_plot=False):
2529
self.env = env
2630
self.size = size if size is not None else ((self.env.grid_size[0] + 1) * self.env.block_size, (self.env.grid_size[1] + 1) * self.env.block_size)
2731
self.width, self.height = self.size
@@ -34,7 +38,7 @@ def __init__(self, env, size=None, update_delay=1.0, display=True):
3438
self.start_time = None
3539
self.current_time = 0.0
3640
self.last_updated = 0.0
37-
self.update_delay = update_delay
41+
self.update_delay = update_delay # duration between each step (in secs)
3842

3943
self.display = display
4044
if self.display:
@@ -59,8 +63,14 @@ def __init__(self, env, size=None, update_delay=1.0, display=True):
5963
self.display = False
6064
print "Simulator.__init__(): Error initializing GUI objects; display disabled.\n{}: {}".format(e.__class__.__name__, e)
6165

66+
# Setup metrics to report
67+
self.live_plot = live_plot
68+
self.rep = Reporter(metrics=['net_reward', 'avg_net_reward', 'final_deadline', 'success'], live_plot=self.live_plot)
69+
self.avg_net_reward_window = 10
70+
6271
def run(self, n_trials=1):
6372
self.quit = False
73+
self.rep.reset()
6474
for trial in xrange(n_trials):
6575
print "Simulator.run(): Trial {}".format(trial) # [debug]
6676
self.env.reset()
@@ -90,6 +100,7 @@ def run(self, n_trials=1):
90100
# Update environment
91101
if self.current_time - self.last_updated >= self.update_delay:
92102
self.env.step()
103+
# TODO: Log step data
93104
self.last_updated = self.current_time
94105

95106
# Render GUI and sleep
@@ -105,6 +116,22 @@ def run(self, n_trials=1):
105116
if self.quit:
106117
break
107118

119+
# Collect/update metrics
120+
self.rep.collect('net_reward', trial, self.env.trial_data['net_reward']) # total reward obtained in this trial
121+
self.rep.collect('avg_net_reward', trial, np.mean(self.rep.metrics['net_reward'].ydata[-self.avg_net_reward_window:])) # rolling mean of reward
122+
self.rep.collect('final_deadline', trial, self.env.trial_data['final_deadline']) # final deadline value (time remaining)
123+
self.rep.collect('success', trial, self.env.trial_data['success'])
124+
if self.live_plot:
125+
self.rep.refresh_plot() # autoscales axes, draws stuff and flushes events
126+
127+
# Report final metrics
128+
if self.display:
129+
self.pygame.display.quit() # need to shutdown pygame before showing metrics plot
130+
# TODO: Figure out why having both game and plot displays makes things crash!
131+
132+
if self.live_plot:
133+
self.rep.show_plot() # holds till user closes plot window
134+
108135
def render(self):
109136
# Clear screen
110137
self.screen.fill(self.bg_color)

0 commit comments

Comments
 (0)