@@ -80,27 +80,22 @@ def __init__(self, capacity):
80
80
self .memory_probability = deque (maxlen = capacity )
81
81
82
82
def td_error_to_prior(self, td_error, lengths):
    """Turn per-step TD errors into one replay priority per sequence.

    The priority mixes the largest and the average absolute TD error of
    each sequence: ``eta * max|td| + (1 - eta) * mean|td|``.

    Args:
        td_error: torch tensor of TD errors; it is reduced along dim 1.
            Presumably shaped (batch, sequence, ...) -- TODO confirm.
        lengths: one length per sequence. ``burn_in_length`` is
            subtracted from each before it is used as the divisor of
            the mean.

    Returns:
        numpy array with one priority per sequence (flattened).

    ``eta`` and ``burn_in_length`` are names resolved outside this
    method; their definitions are not visible in this part of the file.
    """
    # Take the magnitude once, per step, so errors of opposite sign
    # cannot cancel each other out; detach because priorities are
    # plain numbers and need no gradient.
    abs_td_error = td_error.detach().abs()

    abs_td_error_sum = abs_td_error.sum(dim=1, keepdim=True).view(-1).numpy()
    # max(dim=...) returns (values, indices); only the values are needed.
    prior_max = abs_td_error.max(dim=1, keepdim=True)[0].view(-1).numpy()

    # NOTE(review): a sequence whose length equals burn_in_length makes
    # the divisor zero -- confirm callers never push such sequences.
    lengths_burn = [length - burn_in_length for length in lengths]
    prior_mean = abs_td_error_sum / lengths_burn

    prior = eta * prior_max + (1 - eta) * prior_mean
    return prior
88
91
89
92
def push(self, td_error, batch, lengths):
    """Store a batch of sequences together with their replay priorities.

    Args:
        td_error: TD errors for the batch, forwarded unchanged to
            ``td_error_to_prior``.
        batch: object exposing ``state``, ``next_state``, ``action``,
            ``reward``, ``mask``, ``step`` and ``rnn_state``, each
            indexable by sequence position.
            batch.state[local_mini_batch, sequence_length, item]
        lengths: one length per sequence, stored next to its transition.

    Appends to ``self.memory`` and ``self.memory_probability`` in step,
    so index ``k`` of one always describes index ``k`` of the other.
    """
    prior = self.td_error_to_prior(td_error, lengths)

    # NOTE(review): the loop bound is len(batch); this presumes that
    # len(batch) equals the number of sequences -- confirm for the batch
    # type used by callers.
    for i in range(len(batch)):
        transition = Transition(
            batch.state[i],
            batch.next_state[i],
            batch.action[i],
            batch.reward[i],
            batch.mask[i],
            batch.step[i],
            batch.rnn_state[i],
        )
        self.memory.append([transition, lengths[i]])
        # The per-sequence prior is stored directly as the priority.
        self.memory_probability.append(prior[i])
104
99
105
100
def sample (self , batch_size ):
106
101
probability = np .array (self .memory_probability )
0 commit comments