My implementation of Prioritized Sequence Experience Replay. The paper link is: https://arxiv.org/pdf/1905.12726.pdf
There are output plots: episodes_rewards & step_rewards.
I learnt this data structure from: http://www.sefidian.com/2022/11/09/sumtree-data-structure-for-prioritized-experience-replay-per-explained-with-python-code/
# The ‘sum-tree’ data structure used here is very similar in spirit to the array representation
# of a binary heap. However, instead of the usual heap property, the value of a parent node is
# the sum of its children. Leaf nodes store the transition priorities and the internal nodes are
# intermediate sums, with the parent node containing the sum over all priorities, p_total. This
# provides an efficient way of calculating the cumulative sum of priorities, allowing O(log N) updates
# and sampling. (Appendix B.2.1, Proportional prioritization)
# Additional useful links
# Good tutorial about SumTree data structure: https://adventuresinmachinelearning.com/sumtree-introduction-python/
# How to represent full binary tree as array: https://stackoverflow.com/questions/8256222/binary-tree-represented-using-array
class SumTree:
    """Array-backed binary sum tree paired with a fixed-size circular data buffer.

    The tree has ``2 * size - 1`` nodes laid out like a binary heap: the
    children of node ``i`` are ``2 * i + 1`` and ``2 * i + 2``.  The last
    ``size`` nodes are leaves holding one priority each; every internal node
    holds the sum of its two children, so ``nodes[0]`` is the total priority.
    Data slot ``k`` corresponds to tree node ``k + size - 1``.

    Priorities are expected to be non-negative; this is not enforced.
    """

    def __init__(self, size: int) -> None:
        """Create an empty tree with room for ``size`` entries.

        Raises:
            ValueError: if ``size`` is smaller than 1 (no valid tree exists).
        """
        if size < 1:
            raise ValueError(f"size must be a positive integer, got {size!r}")
        self.nodes = [0] * (2 * size - 1)  # internal sums followed by leaf priorities
        self.data = [None] * size  # payload stored for each leaf
        self.size = size
        self.count = 0  # next data slot that ``add`` will write to
        self.real_size = 0  # number of slots written so far, capped at ``size``

    @property
    def total(self):
        """Sum of all leaf priorities (the value stored at the root)."""
        return self.nodes[0]

    def _leaf_index(self, data_idx: int) -> int:
        """Map a data slot to its leaf position in ``nodes``, rejecting bad slots."""
        # A negative slot would otherwise land on an internal node and
        # silently corrupt the partial sums.
        if not 0 <= data_idx < self.size:
            raise IndexError(
                f"data_idx {data_idx!r} out of range for SumTree of size {self.size}"
            )
        return data_idx + self.size - 1

    def propagate(self, idx: int, delta_value) -> None:
        """Add ``delta_value`` to every ancestor of tree node ``idx``, up to the root."""
        while idx > 0:
            idx = (idx - 1) // 2  # step to the parent
            self.nodes[idx] += delta_value

    def update(self, data_idx: int, value) -> None:
        """Set the priority of data slot ``data_idx`` to ``value``.

        Raises:
            IndexError: if ``data_idx`` is not in ``range(size)``.
        """
        idx = self._leaf_index(data_idx)
        delta_value = value - self.nodes[idx]
        self.nodes[idx] = value
        self.propagate(idx, delta_value)

    def add(self, value, data) -> None:
        """Store ``data`` with priority ``value`` in the next slot.

        Slots are written in order and wrap around, so once the buffer is
        full the oldest entry is overwritten.
        """
        self.data[self.count] = data
        self.update(self.count, value)
        self.count = (self.count + 1) % self.size
        self.real_size = min(self.size, self.real_size + 1)

    def get(self, cumsum) -> tuple:
        """Find the leaf selected by the cumulative priority ``cumsum``.

        Walks down from the root: go left when ``cumsum`` fits inside the left
        subtree's sum, otherwise subtract that sum and go right.

        Returns:
            ``(data_idx, priority, data)`` for the selected leaf.  If every
            priority is zero, a zero-priority leaf is returned.

        Raises:
            ValueError: if ``cumsum`` is not within ``[0, total]``.
        """
        if not 0 <= cumsum <= self.total:
            raise ValueError(
                f"cumsum must be within [0, {self.total!r}], got {cumsum!r}"
            )
        nodes = self.nodes
        idx = 0
        while 2 * idx + 1 < len(nodes):
            left, right = 2 * idx + 1, 2 * idx + 2
            left_sum = nodes[left]
            # Also go right when the left subtree is empty but the right one is
            # not: otherwise ``cumsum == 0`` would descend into an unused,
            # zero-priority slot even though real entries exist.
            if cumsum > left_sum or (left_sum == 0 and nodes[right] > 0):
                idx = right
                cumsum = cumsum - left_sum
            else:
                idx = left
        data_idx = idx - self.size + 1
        return data_idx, nodes[idx], self.data[data_idx]

    def get_priority(self, data_idx: int):
        """Return the priority stored for data slot ``data_idx``.

        Raises:
            IndexError: if ``data_idx`` is not in ``range(size)``.
        """
        return self.nodes[self._leaf_index(data_idx)]

    def __repr__(self) -> str:
        return f"SumTree(nodes={self.nodes!r}, data={self.data!r})"
# Test the sum tree
if __name__ == '__main__':
    # Small manual smoke test: fill a tree, change one priority, then sample
    # a few cumulative sums and print what comes back.

    def print_tree(tree):
        """Print the total, the node array and the data buffer, then a blank line."""
        for label, value in (
            ("Tree Total:", tree.total),
            ("Tree Nodes:", tree.nodes),
            ("Tree Data:", tree.data),
        ):
            print(label, value)
        print()

    tree_size = 5
    tree = SumTree(tree_size)

    # Fill every slot; slot i gets priority i + 1.
    print("Adding data to the tree...")
    for i in range(tree_size):
        tree.add(i + 1, f"Data-{i}")
    print_tree(tree)

    # Raise the priority of the third item (slot 2) to 10.
    print("Updating priority...")
    update_index, new_priority = 2, 10
    tree.update(update_index, new_priority)
    print_tree(tree)

    # Look up the entries selected by a few cumulative sums.
    print("Retrieving data based on cumulative sum...")
    for cumsum in (5, 15, 20):
        _, node_value, data = tree.get(cumsum)
        print(f"Cumulative Sum: {cumsum} -> Retrieved: {data} with Priority: {node_value}")
    # NOTE(review): the original indentation was lost; this trailing blank
    # line is assumed to follow the loop rather than each iteration - confirm.
    print()

