diff --git a/common.py b/common.py index c9c166d..e0e94e7 100644 --- a/common.py +++ b/common.py @@ -1,19 +1,24 @@ from collections import defaultdict +import matplotlib +from matplotlib import cm +matplotlib.use('Agg') from mpl_toolkits.mplot3d import Axes3D import numpy as np import pylab as plt from random import randint, random import cPickle +from datetime import datetime HIT, STICK = 1, 0 + def action_value_to_value_function(action_value_function): value_function = defaultdict(float) keys = action_value_function.keys() for key in keys: - dealer, player, action = key[0], key[1], key[2] + dealer, player, action = key hit_reward = action_value_function.get((dealer, player, HIT)) stick_reward = action_value_function.get((dealer, player, STICK)) @@ -64,8 +69,11 @@ def plot_value_function(value_function, title): plt.title(title) ax.set_xlabel("Dealer Showing") ax.set_ylabel("Player Sum") - ax.plot_surface(X, Y, Z, rstride=1, cstride=1) - plt.show() + ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.jet) + fname = title+str(datetime.now())+".png" + print "Saving", fname + plt.savefig(fname) + # plt.show() def save(data, file):