-
Notifications
You must be signed in to change notification settings - Fork 146
/
utils.py
111 lines (92 loc) · 3.74 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import collections
import cv2
import numpy as np
import matplotlib.pyplot as plt
import gym
def plot_learning_curve(x, scores, epsilons, filename, lines=None):
    """Plot epsilon and the running-average score against ``x`` and save it.

    Two axes are overlaid on a single figure. The first draws ``epsilons``
    as a line with its y-axis on the left; the second scatters the running
    average of ``scores`` with its y-axis on the right. The running average
    at index ``t`` is the mean of ``scores[max(0, t - 20):t + 1]``, i.e. of
    at most the 21 most recent scores.

    Args:
        x: x-coordinates shared by both series.
        scores: sequence of scores to smooth and plot.
        epsilons: epsilon values plotted against ``x``.
        filename: destination handed to ``plt.savefig``.
        lines: optional iterable of x positions at which a vertical line
            is drawn; ``None`` draws no lines.
    """
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, label="1")
        # Second axes drawn on top of the first, without its own frame.
        ax2 = fig.add_subplot(111, label="2", frame_on=False)

        ax.plot(x, epsilons, color="C0")
        ax.set_xlabel("Training Steps", color="C0")
        ax.set_ylabel("Epsilon", color="C0")
        ax.tick_params(axis='x', colors="C0")
        ax.tick_params(axis='y', colors="C0")

        N = len(scores)
        running_avg = np.empty(N)
        for t in range(N):
            running_avg[t] = np.mean(scores[max(0, t-20):(t+1)])

        ax2.scatter(x, running_avg, color="C1")
        # Only the first axes shows an x-axis; the overlay hides its own.
        ax2.axes.get_xaxis().set_visible(False)
        ax2.yaxis.tick_right()
        ax2.set_ylabel('Score', color="C1")
        ax2.yaxis.set_label_position('right')
        ax2.tick_params(axis='y', colors="C1")

        if lines is not None:
            for line in lines:
                plt.axvline(x=line)

        plt.savefig(filename)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)
class RepeatActionAndMaxFrame(gym.Wrapper):
    """Repeat each action and return the element-wise max of two frames.

    ``step`` applies the same action up to ``repeat`` times, sums the
    rewards, and returns the element-wise maximum of the two slots of a
    2-frame buffer that is written alternately (slot ``i % 2`` on the
    ``i``-th repeat).

    modified from:
    https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On/blob/master/Chapter06/lib/wrappers.py
    """

    def __init__(self, env=None, repeat=4):
        """Wrap ``env`` so that every action is repeated ``repeat`` times."""
        super().__init__(env)
        self.repeat = repeat
        self.shape = env.observation_space.low.shape
        self.frame_buffer = self._new_frame_buffer()

    def _new_frame_buffer(self):
        """Return a zeroed buffer holding two frames of the wrapped env."""
        # Allocate (2, *frame_shape) explicitly. ``np.zeros_like`` on the
        # tuple ``(2, shape)`` would try to build an array from the ragged
        # tuple itself instead of allocating two frames.
        # The dtype follows the observation space so the max frame keeps the
        # frames' dtype -- assumes observations match the space's dtype.
        frame_dtype = self.env.observation_space.low.dtype
        return np.zeros((2, *self.shape), dtype=frame_dtype)

    def step(self, action):
        """Apply ``action`` up to ``self.repeat`` times.

        Returns:
            A ``(max_frame, total_reward, done, info)`` tuple, where
            ``max_frame`` is the element-wise maximum of the two buffered
            frames, ``total_reward`` is the sum of the rewards collected,
            and ``done``/``info`` come from the last underlying step.
        """
        t_reward = 0.0
        done = False
        info = {}  # keeps ``info`` defined even when ``repeat`` is 0
        for i in range(self.repeat):
            obs, reward, done, info = self.env.step(action)
            t_reward += reward
            self.frame_buffer[i % 2] = obs
            if done:
                # Stop repeating once the episode has ended.
                break

        max_frame = np.maximum(self.frame_buffer[0], self.frame_buffer[1])
        return max_frame, t_reward, done, info

    def reset(self):
        """Reset the wrapped env and seed the buffer with the first frame."""
        obs = self.env.reset()
        self.frame_buffer = self._new_frame_buffer()
        self.frame_buffer[0] = obs
        return obs
class PreprocessFrame(gym.ObservationWrapper):
    """Convert frames to grayscale, resize them, and scale pixels by 1/255.

    ``shape`` is given with the channel count last and is stored
    channel-first as ``(shape[2], shape[0], shape[1])``; the declared
    observation space is a ``Box`` over [0, 1] with that shape.
    """

    def __init__(self, shape, env=None):
        """Store the channel-first target shape and declare the new space.

        Args:
            shape: target shape with the channel count in the last position.
                The grayscale reshape in ``observation`` only works when
                that channel count is 1.
            env: environment to wrap.
        """
        super().__init__(env)
        self.shape = (shape[2], shape[0], shape[1])
        self.observation_space = gym.spaces.Box(
            low=0, high=1.0, shape=self.shape, dtype=np.float32)

    def observation(self, obs):
        """Return ``obs`` as a grayscale frame of shape ``self.shape``.

        The frame is converted from RGB to grayscale, resized with
        ``INTER_AREA`` interpolation, reshaped to ``self.shape`` and divided
        by 255.0.
        """
        gray = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)

        # cv2.resize expects dsize as (width, height), i.e. (columns, rows),
        # while self.shape[1:] is (rows, columns). Swap them so the resized
        # array already has shape self.shape[1:] and the reshape below only
        # adds the channel axis instead of scrambling non-square frames.
        _, rows, cols = self.shape
        resized = cv2.resize(gray, (cols, rows), interpolation=cv2.INTER_AREA)

        frame = np.array(resized, dtype=np.uint8).reshape(self.shape)

        # No axis swap here: the frame is already channel-first and must keep
        # the shape declared by ``self.observation_space``.
        # NOTE(review): the division yields NumPy's default float result
        # rather than the float32 declared on the space -- confirm whether
        # consumers need an explicit cast.
        return frame / 255.0
class StackFrames(gym.ObservationWrapper):
    """Observation wrapper that returns the last ``n_steps`` frames stacked.

    The declared observation space repeats the wrapped space's ``low`` and
    ``high`` bounds ``n_steps`` times along axis 0. Every observation handed
    back is the frame deque converted to an array and reshaped to that
    space's shape.
    """

    def __init__(self, env, n_steps):
        """Wrap ``env`` and create a deque holding at most ``n_steps`` frames."""
        super().__init__(env)
        wrapped_space = env.observation_space
        stacked_low = wrapped_space.low.repeat(n_steps, axis=0)
        stacked_high = wrapped_space.high.repeat(n_steps, axis=0)
        self.observation_space = gym.spaces.Box(
            stacked_low, stacked_high, dtype=np.float32)
        self.stack = collections.deque(maxlen=n_steps)

    def _stacked_array(self):
        """Return the deque contents reshaped to the observation-space shape."""
        target_shape = self.observation_space.low.shape
        return np.array(self.stack).reshape(target_shape)

    def reset(self):
        """Reset the wrapped env and fill the whole stack with its first frame."""
        self.stack.clear()
        first_frame = self.env.reset()
        # Fill every slot with the same initial observation.
        self.stack.extend([first_frame] * self.stack.maxlen)
        return self._stacked_array()

    def observation(self, observation):
        """Append ``observation`` (evicting the oldest) and return the stack."""
        self.stack.append(observation)
        return self._stacked_array()
def make_env(env_name, shape=(84, 84, 1), skip=4):
    """Build a Gym environment wrapped for frame-based agents.

    The environment created by ``gym.make(env_name)`` is wrapped, in order,
    by ``RepeatActionAndMaxFrame`` (with ``skip`` repeats),
    ``PreprocessFrame`` (with ``shape``) and ``StackFrames`` (stacking
    ``skip`` observations).

    Args:
        env_name: environment id passed to ``gym.make``.
        shape: target frame shape forwarded to ``PreprocessFrame``.
        skip: used both as the action-repeat count and as the number of
            stacked observations.

    Returns:
        The fully wrapped environment.
    """
    base_env = gym.make(env_name)
    repeated_env = RepeatActionAndMaxFrame(base_env, skip)
    preprocessed_env = PreprocessFrame(shape, repeated_env)
    return StackFrames(preprocessed_env, skip)