Skip to content

Commit db08ca8

Browse files
committed
update prioritized_replay_dqn
1 parent 184ea0b commit db08ca8

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

contents/5.2_Prioritized_Replay_DQN/RL_brain.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ class SumTree(object):
3333

3434
def __init__(self, capacity):
3535
self.capacity = capacity # for all priority values
36-
self.tree = np.zeros(2 * capacity - 1)
36+
self.tree = np.zeros(2 * capacity - 1) #给定树的大小,一个sumTree的点和
3737
# [--------------Parent nodes-------------][-------leaves to record priority-------]
3838
# size: capacity - 1 size: capacity
39-
self.data = np.zeros(capacity, dtype=object) # for all transitions
39+
self.data = np.zeros(capacity, dtype=object) # for all transitions 都记录到了叶子节点中
4040
# [--------------data frame-------------]
4141
# size: capacity
4242

@@ -45,7 +45,9 @@ def add(self, p, data):
4545
self.data[self.data_pointer] = data # update data_frame
4646
self.update(tree_idx, p) # update tree_frame
4747

48-
self.data_pointer += 1
48+
self.data_pointer += 1 # 数据点增加
49+
50+
#完成一轮后要重新开始添到capacity对应的节点中
4951
if self.data_pointer >= self.capacity: # replace when exceed the capacity
5052
self.data_pointer = 0
5153

@@ -55,7 +57,7 @@ def update(self, tree_idx, p):
5557
# then propagate the change through tree
5658
while tree_idx != 0: # this method is faster than the recursive loop in the reference code
5759
tree_idx = (tree_idx - 1) // 2
58-
self.tree[tree_idx] += change
60+
self.tree[tree_idx] += change #sumTree
5961

6062
def get_leaf(self, v):
6163
"""
@@ -73,10 +75,10 @@ def get_leaf(self, v):
7375
"""
7476
parent_idx = 0
7577
while True: # the while loop is faster than the method in the reference code
76-
cl_idx = 2 * parent_idx + 1 # this leaf's left and right kids
77-
cr_idx = cl_idx + 1
78+
cl_idx = 2 * parent_idx + 1 # this leaf's left and right kids,如何表示左子节点位置
79+
cr_idx = cl_idx + 1 #如何表示父节点的子节点右节点位置
7880
if cl_idx >= len(self.tree): # reach bottom, end search
79-
leaf_idx = parent_idx
81+
leaf_idx = parent_idx #到达bottom结束search
8082
break
8183
else: # downward search, always search for a higher priority node
8284
if v <= self.tree[cl_idx]:
@@ -108,7 +110,7 @@ def __init__(self, capacity):
108110
self.tree = SumTree(capacity)
109111

110112
def store(self, transition):
111-
max_p = np.max(self.tree.tree[-self.tree.capacity:]) #找到优先级最高的
113+
max_p = np.max(self.tree.tree[-self.tree.capacity:]) #找到优先级最高的那个位置
112114
if max_p == 0: #如果为0,设置一个默认值
113115
max_p = self.abs_err_upper
114116
self.tree.add(max_p, transition) # set the max p for new p,添加那个优先级
@@ -123,9 +125,9 @@ def sample(self, n):
123125
for i in range(n):
124126
a, b = pri_seg * i, pri_seg * (i + 1)
125127
v = np.random.uniform(a, b)
126-
idx, p, data = self.tree.get_leaf(v)
127-
prob = p / self.tree.total_p
128-
ISWeights[i, 0] = np.power(prob/min_prob, -self.beta)
128+
idx, p, data = self.tree.get_leaf(v) #获取叶子节点
129+
prob = p / self.tree.total_p #给出一个比例
130+
ISWeights[i, 0] = np.power(prob/min_prob, -self.beta) #权重计算IS(importance sampling) w_j = (N·P(j))^(-β) / max_i w_i
129131
b_idx[i], b_memory[i, :] = idx, data
130132
return b_idx, b_memory, ISWeights
131133

0 commit comments

Comments
 (0)