-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
127 lines (102 loc) · 3.99 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# -*- coding: utf-8 -*-
"""utils.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1TFi1zqTuwH9BqIQSJgGgxmA7rxm9SGIU
"""
import matplotlib.pyplot as plt
import numpy as np
import gym
def plotLearning(x, scores, epsilons, filename, lines=None):
    """Save a two-axis learning curve to ``filename``.

    Epsilon is drawn as a line against the left y-axis, and the running
    average of ``scores`` (the current entry plus up to 20 preceding ones)
    is drawn as a scatter against the right y-axis.

    Args:
        x: x coordinates used for both the epsilon line and the score scatter.
        scores: sequence of scores; must support ``len`` and slicing.
        epsilons: epsilon values plotted against ``x``.
        filename: destination handed to ``plt.savefig``.
        lines: optional iterable of x positions; a vertical line is drawn
            at each one.
    """
    figure = plt.figure()
    # Both axes occupy the same grid cell; the second is created with
    # frame_on=False and carries the score scatter.
    eps_axes = figure.add_subplot(111, label="1")
    score_axes = figure.add_subplot(111, label="2", frame_on=False)

    eps_axes.plot(x, epsilons, color="C0")
    eps_axes.set_xlabel("Game", color="C0")
    eps_axes.set_ylabel("Epsilon", color="C0")
    for which in ('x', 'y'):
        eps_axes.tick_params(axis=which, colors="C0")

    # Running mean over a trailing window of at most 21 scores.
    n_scores = len(scores)
    running_avg = np.empty(n_scores)
    for stop in range(1, n_scores + 1):
        start = max(0, stop - 21)
        running_avg[stop - 1] = np.mean(scores[start:stop])

    score_axes.scatter(x, running_avg, color="C1")
    # The x axis is already labelled by the epsilon axes, so hide this one.
    score_axes.axes.get_xaxis().set_visible(False)
    score_axes.yaxis.tick_right()
    score_axes.set_ylabel('Score', color="C1")
    score_axes.yaxis.set_label_position('right')
    score_axes.tick_params(axis='y', colors="C1")

    for marker in (lines if lines is not None else ()):
        plt.axvline(x=marker)

    plt.savefig(filename)
class SkipEnv(gym.Wrapper):
    """Repeat every action for ``skip`` consecutive steps of the wrapped env.

    The rewards of the repeated steps are summed, and the observation,
    done flag and info of the last executed step are returned.
    """

    def __init__(self, env=None, skip=4):
        """Wrap ``env`` so that each action is applied ``skip`` times.

        Args:
            env: environment to wrap.
            skip: number of inner steps per outer step; must be at least 1.

        Raises:
            ValueError: if ``skip`` is smaller than 1. Without this check
                ``step`` would run zero inner steps and fail with an
                ``UnboundLocalError`` when building its return value.
        """
        if skip < 1:
            raise ValueError("skip must be >= 1, got {!r}".format(skip))
        super(SkipEnv, self).__init__(env)
        self._skip = skip

    def step(self, action):
        """Apply ``action`` up to ``skip`` times and accumulate the reward.

        Stops early as soon as the wrapped env reports ``done``.

        Returns:
            Tuple ``(obs, total_reward, done, info)`` where ``obs``, ``done``
            and ``info`` come from the last inner step that was executed.
        """
        t_reward = 0.0
        done = False
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            t_reward += reward
            if done:
                break
        return obs, t_reward, done, info

    def reset(self):
        """Reset the wrapped env and return its initial observation."""
        # _obs_buffer is kept for compatibility; nothing in this class reads it.
        self._obs_buffer = []
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs
class PreProcessFrame(gym.ObservationWrapper):
    """Turn colour frames into 80x80x1 ``uint8`` grayscale observations."""

    def __init__(self, env=None):
        """Wrap ``env`` and declare the (80, 80, 1) ``uint8`` observation space."""
        super(PreProcessFrame, self).__init__(env)
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(80, 80, 1), dtype=np.uint8)

    def observation(self, obs):
        """Return the processed version of ``obs`` (see ``process``)."""
        return PreProcessFrame.process(obs)

    @staticmethod
    def process(frame):
        """Convert ``frame`` to grayscale, crop and downsample it.

        The first three channels are combined with weights 0.299, 0.587 and
        0.114. Rows 35..194 are kept, and every second row and column is
        retained. The result is reshaped to (80, 80, 1), so the input needs
        at least 195 rows and 159-160 columns.

        Args:
            frame: array with at least three channels on its last axis.

        Returns:
            ``uint8`` array of shape (80, 80, 1).
        """
        rgb = frame.astype(np.float32)
        red, green, blue = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
        gray = 0.299 * red + 0.587 * green + 0.114 * blue
        cropped = gray[35:195:2, ::2]
        return cropped.reshape(80, 80, 1).astype(np.uint8)
class MoveImgChannel(gym.ObservationWrapper):
    """Move the channel axis of each observation from position 2 to 0."""

    def __init__(self, env):
        """Wrap ``env`` and declare the reordered observation space.

        The new shape is ``(shape[-1], shape[0], shape[1])`` of the wrapped
        space, declared as ``float32`` with bounds 0.0 and 1.0.
        """
        super(MoveImgChannel, self).__init__(env)
        src_shape = self.observation_space.shape
        channel_first = (src_shape[-1], src_shape[0], src_shape[1])
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=channel_first, dtype=np.float32)

    def observation(self, observation):
        """Return ``observation`` with axis 2 moved to the front."""
        return np.moveaxis(observation, source=2, destination=0)
class ScaleFrame(gym.ObservationWrapper):
    """Convert observations to ``float32`` and divide them by 255."""

    def observation(self, obs):
        """Return a new ``float32`` array equal to ``obs / 255.0``."""
        as_float32 = np.asarray(obs).astype(np.float32)
        return as_float32 / 255.0
class BufferWrapper(gym.ObservationWrapper):
    """Stack the most recent ``n_steps`` observations along axis 0.

    A rolling ``float32`` buffer is kept on the instance. Each new
    observation drops the oldest entry and is written to the last slot.
    """

    def __init__(self, env, n_steps):
        """Wrap ``env`` and declare the stacked observation space.

        Args:
            env: environment to wrap; its observation space must expose
                ``low`` and ``high`` arrays.
            n_steps: how many times the bounds are repeated along axis 0,
                i.e. how many observations the buffer holds.
        """
        super(BufferWrapper, self).__init__(env)
        inner_space = env.observation_space
        self.observation_space = gym.spaces.Box(
            inner_space.low.repeat(n_steps, axis=0),
            inner_space.high.repeat(n_steps, axis=0),
            dtype=np.float32)
        # Allocate up front so observation() also works before the first
        # reset() instead of failing with AttributeError.
        self.buffer = np.zeros_like(self.observation_space.low,
                                    dtype=np.float32)

    def reset(self):
        """Zero the buffer, reset the wrapped env and return the first stack."""
        self.buffer = np.zeros_like(self.observation_space.low,
                                    dtype=np.float32)
        return self.observation(self.env.reset())

    def observation(self, observation):
        """Push ``observation`` into the buffer and return the stacked frames.

        Returns:
            A copy of the buffer. The internal buffer is overwritten in
            place on the next call, so returning it directly would let
            previously returned observations change under the caller.
        """
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = observation
        return self.buffer.copy()
def make_env(env_name, *, skip=4, n_steps=4):
    """Create ``env_name`` with gym and apply the preprocessing wrappers.

    Wrappers are applied in this order: ``SkipEnv`` (action repeat),
    ``PreProcessFrame`` (80x80x1 grayscale), ``MoveImgChannel`` (channel
    axis first), ``BufferWrapper`` (stack of recent observations) and
    ``ScaleFrame`` (``float32`` divided by 255).

    Args:
        env_name: environment id passed to ``gym.make``.
        skip: number of inner steps per action, forwarded to ``SkipEnv``.
            Defaults to 4, the value ``SkipEnv`` uses on its own.
        n_steps: number of observations stacked by ``BufferWrapper``.
            Defaults to 4, the previously hard-coded value.

    Returns:
        The fully wrapped environment.
    """
    env = gym.make(env_name)
    env = SkipEnv(env, skip)
    env = PreProcessFrame(env)
    env = MoveImgChannel(env)
    env = BufferWrapper(env, n_steps)
    return ScaleFrame(env)