Skip to content

Commit 5e76b69

Browse files
MorvanZhouMorvan Zhou
authored andcommitted
simplify code
1 parent c603541 commit 5e76b69

File tree

1 file changed

+60
-77
lines changed

1 file changed

+60
-77
lines changed

contents/5.2_Prioritized_Replay_DQN/RL_brain.py

Lines changed: 60 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -19,49 +19,38 @@ class SumTree(object):
1919
"""
2020
This SumTree code is modified version and the original code is from:
2121
https://github.com/jaara/AI-blog/blob/master/SumTree.py
22-
22+
2323
Story the data with it priority in tree and data frameworks.
2424
"""
2525
data_pointer = 0
2626

2727
def __init__(self, capacity):
28-
self.capacity = capacity # for all priority values
29-
self.tree = np.zeros(2*capacity - 1)
28+
self.capacity = capacity # for all priority values
29+
self.tree = np.zeros(2 * capacity - 1)
3030
# [--------------Parent nodes-------------][-------leaves to recode priority-------]
3131
# size: capacity - 1 size: capacity
32-
self.data = np.zeros(capacity, dtype=object) # for all transitions
32+
self.data = np.zeros(capacity, dtype=object) # for all transitions
3333
# [--------------data frame-------------]
3434
# size: capacity
3535

36-
def add_new_priority(self, p, data):
37-
leaf_idx = self.data_pointer + self.capacity - 1
38-
39-
self.data[self.data_pointer] = data # update data_frame
40-
self.update(leaf_idx, p) # update tree_frame
36+
def add(self, p, data):
37+
tree_idx = self.data_pointer + self.capacity - 1
38+
self.data[self.data_pointer] = data # update data_frame
39+
self.update(tree_idx, p) # update tree_frame
4140

4241
self.data_pointer += 1
4342
if self.data_pointer >= self.capacity: # replace when exceed the capacity
4443
self.data_pointer = 0
4544

4645
def update(self, tree_idx, p):
4746
change = p - self.tree[tree_idx]
48-
4947
self.tree[tree_idx] = p
50-
self._propagate_change(tree_idx, change)
51-
52-
def _propagate_change(self, tree_idx, change):
53-
"""change the sum of priority value in all parent nodes"""
54-
parent_idx = (tree_idx - 1) // 2
55-
self.tree[parent_idx] += change
56-
if parent_idx != 0:
57-
self._propagate_change(parent_idx, change)
58-
59-
def get_leaf(self, lower_bound):
60-
leaf_idx = self._retrieve(lower_bound) # search the max leaf priority based on the lower_bound
61-
data_idx = leaf_idx - self.capacity + 1
62-
return [leaf_idx, self.tree[leaf_idx], self.data[data_idx]]
48+
# then propagate the change through tree
49+
while tree_idx != 0: # this method is faster than the recursive loop in the reference code
50+
tree_idx = (tree_idx - 1) // 2
51+
self.tree[tree_idx] += change
6352

64-
def _retrieve(self, lower_bound, parent_idx=0):
53+
def get_leaf(self, v):
6554
"""
6655
Tree structure and array storage:
6756
@@ -75,32 +64,36 @@ def _retrieve(self, lower_bound, parent_idx=0):
7564
Array type for storing:
7665
[0,1,2,3,4,5,6]
7766
"""
78-
left_child_idx = 2 * parent_idx + 1
79-
right_child_idx = left_child_idx + 1
67+
parent_idx = 0
68+
while True: # the while loop is faster than the method in the reference code
69+
cl_idx = 2 * parent_idx + 1 # this leaf's left and right kids
70+
cr_idx = cl_idx + 1
71+
if cl_idx >= len(self.tree): # reach bottom, end search
72+
leaf_idx = parent_idx
73+
break
74+
else: # downward search, always search for a higher priority node
75+
if v <= self.tree[cl_idx]:
76+
parent_idx = cl_idx
77+
else:
78+
v -= self.tree[cl_idx]
79+
parent_idx = cr_idx
8080

81-
if left_child_idx >= len(self.tree): # end search when no more child
82-
return parent_idx
83-
84-
if self.tree[left_child_idx] == self.tree[right_child_idx]:
85-
return self._retrieve(lower_bound, np.random.choice([left_child_idx, right_child_idx]))
86-
if lower_bound <= self.tree[left_child_idx]: # downward search, always search for a higher priority node
87-
return self._retrieve(lower_bound, left_child_idx)
88-
else:
89-
return self._retrieve(lower_bound-self.tree[left_child_idx], right_child_idx)
81+
data_idx = leaf_idx - self.capacity + 1
82+
return leaf_idx, self.tree[leaf_idx], self.data[data_idx]
9083

9184
@property
92-
def root_priority(self):
93-
return self.tree[0] # the root
85+
def total_p(self):
86+
return self.tree[0] # the root
9487

9588

96-
class Memory(object): # stored as ( s, a, r, s_ ) in SumTree
89+
class Memory(object): # stored as ( s, a, r, s_ ) in SumTree
9790
"""
9891
This SumTree code is modified version and the original code is from:
9992
https://github.com/jaara/AI-blog/blob/master/Seaquest-DDQN-PER.py
10093
"""
10194
epsilon = 0.01 # small amount to avoid zero priority
102-
alpha = 0.6 # [0~1] convert the importance of TD error to priority
103-
beta = 0.4 # importance-sampling, from initial value increasing to 1
95+
alpha = 0.6 # [0~1] convert the importance of TD error to priority
96+
beta = 0.4 # importance-sampling, from initial value increasing to 1
10497
beta_increment_per_sampling = 0.001
10598
abs_err_upper = 1. # clipped abs error
10699

@@ -111,37 +104,29 @@ def store(self, transition):
111104
max_p = np.max(self.tree.tree[-self.tree.capacity:])
112105
if max_p == 0:
113106
max_p = self.abs_err_upper
114-
self.tree.add_new_priority(max_p, transition) # set the max p for new p
107+
self.tree.add(max_p, transition) # set the max p for new p
115108

116109
def sample(self, n):
117-
batch_idx, batch_memory, ISWeights = [], [], []
118-
segment = self.tree.root_priority / n
119-
self.beta = np.min([1, self.beta + self.beta_increment_per_sampling]) # max = 1
110+
b_idx, b_memory, ISWeights = np.empty((n,), dtype=np.int32), np.empty((n, self.tree.data[0].size)), np.empty((n, 1))
111+
pri_seg = self.tree.total_p / n # priority segment
112+
self.beta = np.min([1., self.beta + self.beta_increment_per_sampling]) # max = 1
120113

121-
min_prob = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.root_priority
122-
maxiwi = np.power(self.tree.capacity * min_prob, -self.beta) # for later normalizing ISWeights
114+
min_prob = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total_p # for later calculate ISweight
123115
for i in range(n):
124-
a = segment * i
125-
b = segment * (i + 1)
126-
lower_bound = np.random.uniform(a, b)
127-
idx, p, data = self.tree.get_leaf(lower_bound)
128-
prob = p / self.tree.root_priority
129-
ISWeights.append(self.tree.capacity * prob)
130-
batch_idx.append(idx)
131-
batch_memory.append(data)
132-
133-
ISWeights = np.vstack(ISWeights)
134-
ISWeights = np.power(ISWeights, -self.beta) / maxiwi # normalize
135-
return batch_idx, np.vstack(batch_memory), ISWeights
136-
137-
def update(self, idx, error):
138-
p = self._get_priority(error)
139-
self.tree.update(idx, p)
140-
141-
def _get_priority(self, error):
142-
error += self.epsilon # avoid 0
143-
clipped_error = np.clip(error, 0, self.abs_err_upper)
144-
return np.power(clipped_error, self.alpha)
116+
a, b = pri_seg * i, pri_seg * (i + 1)
117+
v = np.random.uniform(a, b)
118+
idx, p, data = self.tree.get_leaf(v)
119+
prob = p / self.tree.total_p
120+
ISWeights[i, 0] = np.power(prob/min_prob, -self.beta)
121+
b_idx[i], b_memory[i, :] = idx, data
122+
return b_idx, b_memory, ISWeights
123+
124+
def batch_update(self, tree_idx, abs_errors):
125+
abs_errors += self.epsilon # convert to abs and avoid 0
126+
clipped_errors = np.minimum(abs_errors, self.abs_err_upper)
127+
ps = np.power(clipped_errors, self.alpha)
128+
for ti, p in zip(tree_idx, ps):
129+
self.tree.update(ti, p)
145130

146131

147132
class DQNPrioritizedReplay:
@@ -197,15 +182,15 @@ def __init__(
197182
self.cost_his = []
198183

199184
def _build_net(self):
200-
def build_layers(s, c_names, n_l1, w_initializer, b_initializer):
185+
def build_layers(s, c_names, n_l1, w_initializer, b_initializer, trainable):
201186
with tf.variable_scope('l1'):
202-
w1 = tf.get_variable('w1', [self.n_features, n_l1], initializer=w_initializer, collections=c_names)
203-
b1 = tf.get_variable('b1', [1, n_l1], initializer=b_initializer, collections=c_names)
187+
w1 = tf.get_variable('w1', [self.n_features, n_l1], initializer=w_initializer, collections=c_names, trainable=trainable)
188+
b1 = tf.get_variable('b1', [1, n_l1], initializer=b_initializer, collections=c_names, trainable=trainable)
204189
l1 = tf.nn.relu(tf.matmul(s, w1) + b1)
205190

206191
with tf.variable_scope('l2'):
207-
w2 = tf.get_variable('w2', [n_l1, self.n_actions], initializer=w_initializer, collections=c_names)
208-
b2 = tf.get_variable('b2', [1, self.n_actions], initializer=b_initializer, collections=c_names)
192+
w2 = tf.get_variable('w2', [n_l1, self.n_actions], initializer=w_initializer, collections=c_names, trainable=trainable)
193+
b2 = tf.get_variable('b2', [1, self.n_actions], initializer=b_initializer, collections=c_names, trainable=trainable)
209194
out = tf.matmul(l1, w2) + b2
210195
return out
211196

@@ -219,7 +204,7 @@ def build_layers(s, c_names, n_l1, w_initializer, b_initializer):
219204
['eval_net_params', tf.GraphKeys.GLOBAL_VARIABLES], 20, \
220205
tf.random_normal_initializer(0., 0.3), tf.constant_initializer(0.1) # config of layers
221206

222-
self.q_eval = build_layers(self.s, c_names, n_l1, w_initializer, b_initializer)
207+
self.q_eval = build_layers(self.s, c_names, n_l1, w_initializer, b_initializer, True)
223208

224209
with tf.variable_scope('loss'):
225210
if self.prioritized:
@@ -234,7 +219,7 @@ def build_layers(s, c_names, n_l1, w_initializer, b_initializer):
234219
self.s_ = tf.placeholder(tf.float32, [None, self.n_features], name='s_') # input
235220
with tf.variable_scope('target_net'):
236221
c_names = ['target_net_params', tf.GraphKeys.GLOBAL_VARIABLES]
237-
self.q_next = build_layers(self.s_, c_names, n_l1, w_initializer, b_initializer)
222+
self.q_next = build_layers(self.s_, c_names, n_l1, w_initializer, b_initializer, False)
238223

239224
def store_transition(self, s, a, r, s_):
240225
if self.prioritized: # prioritized replay
@@ -285,9 +270,7 @@ def learn(self):
285270
feed_dict={self.s: batch_memory[:, :self.n_features],
286271
self.q_target: q_target,
287272
self.ISWeights: ISWeights})
288-
for i in range(len(tree_idx)): # update priority
289-
idx = tree_idx[i]
290-
self.memory.update(idx, abs_errors[i])
273+
self.memory.batch_update(tree_idx, abs_errors) # update priority
291274
else:
292275
_, self.cost = self.sess.run([self._train_op, self.loss],
293276
feed_dict={self.s: batch_memory[:, :self.n_features],

0 commit comments

Comments
 (0)