1
+ import random
2
+ import numpy as np
3
+ import torch
4
+ from collections import namedtuple , deque
5
+
6
+ from config import gamma , batch_size , alpha , beta
7
+
8
# One n-step replay record. NOTE(review): `mask` presumably is the
# episode-continuation flag and `step` the n-step horizon used to build
# `reward` — confirm against the agent/learner code.
Transition = namedtuple('Transition', ('state', 'next_state', 'action', 'reward', 'mask', 'step'))
9
+
10
class N_Step_Buffer(object):
    """Accumulates consecutive environment steps for n-step returns.

    Steps are pushed one at a time; ``sample`` collapses everything stored
    so far into a single n-step transition (discounted reward sum, first
    state/action, last next_state/mask) and clears the buffer.
    """

    def __init__(self):
        self.memory = []  # list of [state, next_state, action, reward, mask]
        self.step = 0     # number of steps accumulated since last reset

    def push(self, state, next_state, action, reward, mask):
        """Append one environment step to the buffer."""
        self.step += 1
        self.memory.append([state, next_state, action, reward, mask])

    def sample(self):
        """Collapse the buffered steps into one n-step transition.

        Returns ``[state, next_state, action, reward, mask, step]`` where
        ``state``/``action`` come from the first buffered step,
        ``next_state``/``mask`` from the last, and ``reward`` is the
        discounted return ``sum_t gamma^t * r_t``. The buffer is reset
        before returning.
        """
        [state, _, action, _, _] = self.memory[0]
        [_, next_state, _, _, mask] = self.memory[-1]

        # Backward recursion G = r_t + gamma * G for the discounted return.
        # BUG FIX: the original used `sum_reward += reward + gamma * sum_reward`,
        # which expands to G = r + (1 + gamma) * G and double-counts the
        # accumulator; plain assignment is the correct n-step return.
        sum_reward = 0
        for t in reversed(range(len(self.memory))):
            [_, _, _, reward, _] = self.memory[t]
            sum_reward = reward + gamma * sum_reward
        reward = sum_reward
        step = self.step
        self.reset()

        return [state, next_state, action, reward, mask, step]

    def reset(self):
        """Discard all buffered steps."""
        self.memory = []
        self.step = 0

    def __len__(self):
        return len(self.memory)
39
+
40
+
41
class LocalBuffer(object):
    """Unbounded per-actor list of n-step transitions."""

    def __init__(self):
        self.memory = []

    def push(self, state, next_state, action, reward, mask, step):
        """Store one n-step transition."""
        entry = Transition(state, next_state, action, reward, mask, step)
        self.memory.append(entry)

    def sample(self):
        """Return every stored transition as one field-wise batched Transition."""
        return Transition(*zip(*self.memory))

    def reset(self):
        """Drop all stored transitions."""
        self.memory = []

    def __len__(self):
        return len(self.memory)
58
+
59
class Memory(object):
    """Prioritized replay memory (proportional variant, Schaul et al. 2015).

    Transitions are stored with a raw priority; minibatches are drawn with
    probability P(i) = p_i^alpha / sum_k p_k^alpha and returned with
    max-normalized importance-sampling weights w_i = (N * P(i))^-beta.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        # deques evict the oldest entry automatically at capacity;
        # the two deques stay index-aligned.
        self.memory = deque(maxlen=capacity)
        self.memory_probability = deque(maxlen=capacity)

    def push(self, state, next_state, action, reward, mask, step, prior):
        """Store a transition with its raw (un-exponentiated) priority."""
        self.memory.append(Transition(state, next_state, action, reward, mask, step))
        self.memory_probability.append(prior)

    def sample(self):
        """Draw a prioritized minibatch of ``batch_size`` transitions.

        Returns ``(indexes, batch, weights)``: buffer positions (needed by
        ``update_prior``), a Transition of batched fields, and max-normalized
        importance-sampling weights.
        """
        probability = torch.Tensor(self.memory_probability)
        probability = probability.pow(alpha)
        probability = probability / probability.sum()

        p = probability.numpy()

        indexes = np.random.choice(range(len(self.memory_probability)), batch_size, p=p)

        transitions = [self.memory[idx] for idx in indexes]
        batch = Transition(*zip(*transitions))

        # FIX: IS weights must use the actual sampling probability P(i)
        # (after the alpha exponent and normalization) and the current
        # buffer size N. The original used the raw priorities and the fixed
        # capacity, which skews the correction whenever alpha != 1 or the
        # buffer is not yet full.
        sampled_p = torch.Tensor(p[indexes])
        n = len(self.memory)
        weights = (n * sampled_p).pow(-beta)
        weights = weights / weights.max()

        return indexes, batch, weights

    def update_prior(self, indexes, priors):
        """Overwrite the priorities at ``indexes`` with the matching ``priors``."""
        for priors_idx, idx in enumerate(indexes):
            self.memory_probability[idx] = priors[priors_idx]

    def __len__(self):
        return len(self.memory)
96
+
97
+
98
+
0 commit comments