@@ -19,49 +19,38 @@ class SumTree(object):
19
19
"""
20
20
This SumTree code is modified version and the original code is from:
21
21
https://github.com/jaara/AI-blog/blob/master/SumTree.py
22
-
22
+
23
23
Story the data with it priority in tree and data frameworks.
24
24
"""
25
25
data_pointer = 0
26
26
27
27
def __init__ (self , capacity ):
28
- self .capacity = capacity # for all priority values
29
- self .tree = np .zeros (2 * capacity - 1 )
28
+ self .capacity = capacity # for all priority values
29
+ self .tree = np .zeros (2 * capacity - 1 )
30
30
# [--------------Parent nodes-------------][-------leaves to recode priority-------]
31
31
# size: capacity - 1 size: capacity
32
- self .data = np .zeros (capacity , dtype = object ) # for all transitions
32
+ self .data = np .zeros (capacity , dtype = object ) # for all transitions
33
33
# [--------------data frame-------------]
34
34
# size: capacity
35
35
36
- def add_new_priority (self , p , data ):
37
- leaf_idx = self .data_pointer + self .capacity - 1
38
-
39
- self .data [self .data_pointer ] = data # update data_frame
40
- self .update (leaf_idx , p ) # update tree_frame
36
+ def add (self , p , data ):
37
+ tree_idx = self .data_pointer + self .capacity - 1
38
+ self .data [self .data_pointer ] = data # update data_frame
39
+ self .update (tree_idx , p ) # update tree_frame
41
40
42
41
self .data_pointer += 1
43
42
if self .data_pointer >= self .capacity : # replace when exceed the capacity
44
43
self .data_pointer = 0
45
44
46
45
def update (self , tree_idx , p ):
47
46
change = p - self .tree [tree_idx ]
48
-
49
47
self .tree [tree_idx ] = p
50
- self ._propagate_change (tree_idx , change )
51
-
52
- def _propagate_change (self , tree_idx , change ):
53
- """change the sum of priority value in all parent nodes"""
54
- parent_idx = (tree_idx - 1 ) // 2
55
- self .tree [parent_idx ] += change
56
- if parent_idx != 0 :
57
- self ._propagate_change (parent_idx , change )
58
-
59
- def get_leaf (self , lower_bound ):
60
- leaf_idx = self ._retrieve (lower_bound ) # search the max leaf priority based on the lower_bound
61
- data_idx = leaf_idx - self .capacity + 1
62
- return [leaf_idx , self .tree [leaf_idx ], self .data [data_idx ]]
48
+ # then propagate the change through tree
49
+ while tree_idx != 0 : # this method is faster than the recursive loop in the reference code
50
+ tree_idx = (tree_idx - 1 ) // 2
51
+ self .tree [tree_idx ] += change
63
52
64
- def _retrieve (self , lower_bound , parent_idx = 0 ):
53
+ def get_leaf (self , v ):
65
54
"""
66
55
Tree structure and array storage:
67
56
@@ -75,32 +64,36 @@ def _retrieve(self, lower_bound, parent_idx=0):
75
64
Array type for storing:
76
65
[0,1,2,3,4,5,6]
77
66
"""
78
- left_child_idx = 2 * parent_idx + 1
79
- right_child_idx = left_child_idx + 1
67
+ parent_idx = 0
68
+ while True : # the while loop is faster than the method in the reference code
69
+ cl_idx = 2 * parent_idx + 1 # this leaf's left and right kids
70
+ cr_idx = cl_idx + 1
71
+ if cl_idx >= len (self .tree ): # reach bottom, end search
72
+ leaf_idx = parent_idx
73
+ break
74
+ else : # downward search, always search for a higher priority node
75
+ if v <= self .tree [cl_idx ]:
76
+ parent_idx = cl_idx
77
+ else :
78
+ v -= self .tree [cl_idx ]
79
+ parent_idx = cr_idx
80
80
81
- if left_child_idx >= len (self .tree ): # end search when no more child
82
- return parent_idx
83
-
84
- if self .tree [left_child_idx ] == self .tree [right_child_idx ]:
85
- return self ._retrieve (lower_bound , np .random .choice ([left_child_idx , right_child_idx ]))
86
- if lower_bound <= self .tree [left_child_idx ]: # downward search, always search for a higher priority node
87
- return self ._retrieve (lower_bound , left_child_idx )
88
- else :
89
- return self ._retrieve (lower_bound - self .tree [left_child_idx ], right_child_idx )
81
+ data_idx = leaf_idx - self .capacity + 1
82
+ return leaf_idx , self .tree [leaf_idx ], self .data [data_idx ]
90
83
91
84
@property
92
- def root_priority (self ):
93
- return self .tree [0 ] # the root
85
+ def total_p (self ):
86
+ return self .tree [0 ] # the root
94
87
95
88
96
- class Memory (object ): # stored as ( s, a, r, s_ ) in SumTree
89
+ class Memory (object ): # stored as ( s, a, r, s_ ) in SumTree
97
90
"""
98
91
This SumTree code is modified version and the original code is from:
99
92
https://github.com/jaara/AI-blog/blob/master/Seaquest-DDQN-PER.py
100
93
"""
101
94
epsilon = 0.01 # small amount to avoid zero priority
102
- alpha = 0.6 # [0~1] convert the importance of TD error to priority
103
- beta = 0.4 # importance-sampling, from initial value increasing to 1
95
+ alpha = 0.6 # [0~1] convert the importance of TD error to priority
96
+ beta = 0.4 # importance-sampling, from initial value increasing to 1
104
97
beta_increment_per_sampling = 0.001
105
98
abs_err_upper = 1. # clipped abs error
106
99
@@ -111,37 +104,29 @@ def store(self, transition):
111
104
max_p = np .max (self .tree .tree [- self .tree .capacity :])
112
105
if max_p == 0 :
113
106
max_p = self .abs_err_upper
114
- self .tree .add_new_priority (max_p , transition ) # set the max p for new p
107
+ self .tree .add (max_p , transition ) # set the max p for new p
115
108
116
109
def sample (self , n ):
117
- batch_idx , batch_memory , ISWeights = [], [], []
118
- segment = self .tree .root_priority / n
119
- self .beta = np .min ([1 , self .beta + self .beta_increment_per_sampling ]) # max = 1
110
+ b_idx , b_memory , ISWeights = np . empty (( n ,), dtype = np . int32 ), np . empty (( n , self . tree . data [ 0 ]. size )), np . empty (( n , 1 ))
111
+ pri_seg = self .tree .total_p / n # priority segment
112
+ self .beta = np .min ([1. , self .beta + self .beta_increment_per_sampling ]) # max = 1
120
113
121
- min_prob = np .min (self .tree .tree [- self .tree .capacity :]) / self .tree .root_priority
122
- maxiwi = np .power (self .tree .capacity * min_prob , - self .beta ) # for later normalizing ISWeights
114
+ min_prob = np .min (self .tree .tree [- self .tree .capacity :]) / self .tree .total_p # for later calculate ISweight
123
115
for i in range (n ):
124
- a = segment * i
125
- b = segment * (i + 1 )
126
- lower_bound = np .random .uniform (a , b )
127
- idx , p , data = self .tree .get_leaf (lower_bound )
128
- prob = p / self .tree .root_priority
129
- ISWeights .append (self .tree .capacity * prob )
130
- batch_idx .append (idx )
131
- batch_memory .append (data )
132
-
133
- ISWeights = np .vstack (ISWeights )
134
- ISWeights = np .power (ISWeights , - self .beta ) / maxiwi # normalize
135
- return batch_idx , np .vstack (batch_memory ), ISWeights
136
-
137
- def update (self , idx , error ):
138
- p = self ._get_priority (error )
139
- self .tree .update (idx , p )
140
-
141
- def _get_priority (self , error ):
142
- error += self .epsilon # avoid 0
143
- clipped_error = np .clip (error , 0 , self .abs_err_upper )
144
- return np .power (clipped_error , self .alpha )
116
+ a , b = pri_seg * i , pri_seg * (i + 1 )
117
+ v = np .random .uniform (a , b )
118
+ idx , p , data = self .tree .get_leaf (v )
119
+ prob = p / self .tree .total_p
120
+ ISWeights [i , 0 ] = np .power (prob / min_prob , - self .beta )
121
+ b_idx [i ], b_memory [i , :] = idx , data
122
+ return b_idx , b_memory , ISWeights
123
+
124
+ def batch_update (self , tree_idx , abs_errors ):
125
+ abs_errors += self .epsilon # convert to abs and avoid 0
126
+ clipped_errors = np .minimum (abs_errors , self .abs_err_upper )
127
+ ps = np .power (clipped_errors , self .alpha )
128
+ for ti , p in zip (tree_idx , ps ):
129
+ self .tree .update (ti , p )
145
130
146
131
147
132
class DQNPrioritizedReplay :
@@ -197,15 +182,15 @@ def __init__(
197
182
self .cost_his = []
198
183
199
184
def _build_net (self ):
200
- def build_layers (s , c_names , n_l1 , w_initializer , b_initializer ):
185
+ def build_layers (s , c_names , n_l1 , w_initializer , b_initializer , trainable ):
201
186
with tf .variable_scope ('l1' ):
202
- w1 = tf .get_variable ('w1' , [self .n_features , n_l1 ], initializer = w_initializer , collections = c_names )
203
- b1 = tf .get_variable ('b1' , [1 , n_l1 ], initializer = b_initializer , collections = c_names )
187
+ w1 = tf .get_variable ('w1' , [self .n_features , n_l1 ], initializer = w_initializer , collections = c_names , trainable = trainable )
188
+ b1 = tf .get_variable ('b1' , [1 , n_l1 ], initializer = b_initializer , collections = c_names , trainable = trainable )
204
189
l1 = tf .nn .relu (tf .matmul (s , w1 ) + b1 )
205
190
206
191
with tf .variable_scope ('l2' ):
207
- w2 = tf .get_variable ('w2' , [n_l1 , self .n_actions ], initializer = w_initializer , collections = c_names )
208
- b2 = tf .get_variable ('b2' , [1 , self .n_actions ], initializer = b_initializer , collections = c_names )
192
+ w2 = tf .get_variable ('w2' , [n_l1 , self .n_actions ], initializer = w_initializer , collections = c_names , trainable = trainable )
193
+ b2 = tf .get_variable ('b2' , [1 , self .n_actions ], initializer = b_initializer , collections = c_names , trainable = trainable )
209
194
out = tf .matmul (l1 , w2 ) + b2
210
195
return out
211
196
@@ -219,7 +204,7 @@ def build_layers(s, c_names, n_l1, w_initializer, b_initializer):
219
204
['eval_net_params' , tf .GraphKeys .GLOBAL_VARIABLES ], 20 , \
220
205
tf .random_normal_initializer (0. , 0.3 ), tf .constant_initializer (0.1 ) # config of layers
221
206
222
- self .q_eval = build_layers (self .s , c_names , n_l1 , w_initializer , b_initializer )
207
+ self .q_eval = build_layers (self .s , c_names , n_l1 , w_initializer , b_initializer , True )
223
208
224
209
with tf .variable_scope ('loss' ):
225
210
if self .prioritized :
@@ -234,7 +219,7 @@ def build_layers(s, c_names, n_l1, w_initializer, b_initializer):
234
219
self .s_ = tf .placeholder (tf .float32 , [None , self .n_features ], name = 's_' ) # input
235
220
with tf .variable_scope ('target_net' ):
236
221
c_names = ['target_net_params' , tf .GraphKeys .GLOBAL_VARIABLES ]
237
- self .q_next = build_layers (self .s_ , c_names , n_l1 , w_initializer , b_initializer )
222
+ self .q_next = build_layers (self .s_ , c_names , n_l1 , w_initializer , b_initializer , False )
238
223
239
224
def store_transition (self , s , a , r , s_ ):
240
225
if self .prioritized : # prioritized replay
@@ -285,9 +270,7 @@ def learn(self):
285
270
feed_dict = {self .s : batch_memory [:, :self .n_features ],
286
271
self .q_target : q_target ,
287
272
self .ISWeights : ISWeights })
288
- for i in range (len (tree_idx )): # update priority
289
- idx = tree_idx [i ]
290
- self .memory .update (idx , abs_errors [i ])
273
+ self .memory .batch_update (tree_idx , abs_errors ) # update priority
291
274
else :
292
275
_ , self .cost = self .sess .run ([self ._train_op , self .loss ],
293
276
feed_dict = {self .s : batch_memory [:, :self .n_features ],
0 commit comments