@@ -33,10 +33,10 @@ class SumTree(object):
33
33
34
34
def __init__ (self , capacity ):
35
35
self .capacity = capacity # for all priority values
36
- self .tree = np .zeros (2 * capacity - 1 )
36
+ self .tree = np .zeros (2 * capacity - 1 ) #给定树的大小,一个sumTree的点和
37
37
# [--------------Parent nodes-------------][-------leaves to record priority-------]
38
38
# size: capacity - 1 size: capacity
39
- self .data = np .zeros (capacity , dtype = object ) # for all transitions
39
+ self .data = np .zeros (capacity , dtype = object ) # for all transitions 都记录到了叶子节点中
40
40
# [--------------data frame-------------]
41
41
# size: capacity
42
42
@@ -45,7 +45,9 @@ def add(self, p, data):
45
45
self .data [self .data_pointer ] = data # update data_frame
46
46
self .update (tree_idx , p ) # update tree_frame
47
47
48
- self .data_pointer += 1
48
+ self .data_pointer += 1 # 数据点增加
49
+
50
+ #完成一轮后要重新开始添到capacity对应的节点中
49
51
if self .data_pointer >= self .capacity : # replace when exceed the capacity
50
52
self .data_pointer = 0
51
53
@@ -55,7 +57,7 @@ def update(self, tree_idx, p):
55
57
# then propagate the change through tree
56
58
while tree_idx != 0 : # this method is faster than the recursive loop in the reference code
57
59
tree_idx = (tree_idx - 1 ) // 2
58
- self .tree [tree_idx ] += change
60
+ self .tree [tree_idx ] += change #sumTree
59
61
60
62
def get_leaf (self , v ):
61
63
"""
@@ -73,10 +75,10 @@ def get_leaf(self, v):
73
75
"""
74
76
parent_idx = 0
75
77
while True : # the while loop is faster than the method in the reference code
76
- cl_idx = 2 * parent_idx + 1 # this leaf's left and right kids
77
- cr_idx = cl_idx + 1
78
+ cl_idx = 2 * parent_idx + 1 # this leaf's left and right kids,如何表示左子节点位置
79
+ cr_idx = cl_idx + 1 #如何表示父节点的子节点右节点位置
78
80
if cl_idx >= len (self .tree ): # reach bottom, end search
79
- leaf_idx = parent_idx
81
+ leaf_idx = parent_idx #到达bottom结束search
80
82
break
81
83
else : # downward search, always search for a higher priority node
82
84
if v <= self .tree [cl_idx ]:
@@ -108,7 +110,7 @@ def __init__(self, capacity):
108
110
self .tree = SumTree (capacity )
109
111
110
112
def store (self , transition ):
111
- max_p = np .max (self .tree .tree [- self .tree .capacity :]) #找到优先级最高的
113
+ max_p = np .max (self .tree .tree [- self .tree .capacity :]) #找到优先级最高的那个位置
112
114
if max_p == 0 : #如果为0,设置一个默认值
113
115
max_p = self .abs_err_upper
114
116
self .tree .add (max_p , transition ) # set the max p for new p,添加那个优先级
@@ -123,9 +125,9 @@ def sample(self, n):
123
125
for i in range (n ):
124
126
a , b = pri_seg * i , pri_seg * (i + 1 )
125
127
v = np .random .uniform (a , b )
126
- idx , p , data = self .tree .get_leaf (v )
127
- prob = p / self .tree .total_p
128
- ISWeights [i , 0 ] = np .power (prob / min_prob , - self .beta )
128
+ idx , p , data = self .tree .get_leaf (v ) #获取叶子节点
129
+ prob = p / self .tree .total_p #给出一个比例
130
+ ISWeights [i , 0 ] = np .power (prob / min_prob , - self .beta ) # importance-sampling (IS) weight: w_j = (N·P(j))^(-β) / max_i w_i
129
131
b_idx [i ], b_memory [i , :] = idx , data
130
132
return b_idx , b_memory , ISWeights
131
133
0 commit comments