 from model import QNet
 from tensorboardX import SummaryWriter

+from memory import Memory
 from config import env_name, goal_score, log_interval, device, lr, gamma


@@ -24,7 +25,6 @@ def main():
     print('action size:', num_actions)

     net = QNet(num_inputs, num_actions)
-
     writer = SummaryWriter('logs')

     net.to(device)
@@ -33,9 +33,9 @@ def main():
     steps = 0
     loss = 0
     k = 0
-    for e in range(3000):
+    for e in range(30000):
         done = False
-        memory = []
+        memory = Memory()

         score = 0
         state = env.reset()
@@ -56,22 +56,13 @@ def main():

             action_one_hot = torch.zeros(2)
             action_one_hot[action] = 1
-            memory.append([state, next_state, action_one_hot, reward, mask])
+            memory.push(state, next_state, action_one_hot, reward, mask)

             score += reward
             state = next_state

         sum_reward = 0
-        memory.reverse()
-        states, actions, rewards, masks = [], [], [], []
-        for t, transition in enumerate(memory):
-            state, next_state, action, reward, mask = transition
-            sum_reward = (reward + gamma * sum_reward)
-            states.append(state)
-            actions.append(action)
-            rewards.append(sum_reward)
-            masks.append(mask)
-        loss = QNet.train_model(net, (states, actions, rewards, masks), k)
+        loss = QNet.train_model(net, memory.sample(), k)
         k += 1

         score = score if score == 500.0 else score + 1
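
The Memory class imported above lives in memory.py, which this diff does not show. A minimal sketch of what it might contain, assuming only the usage visible here: a no-argument constructor, push(state, next_state, action, reward, mask), and sample() returning the stored transitions. The Transition namedtuple and its field order are hypothetical, inferred from the push() call in main().

from collections import namedtuple

# Hypothetical field names, inferred from the push() call in main()
Transition = namedtuple('Transition',
                        ('state', 'next_state', 'action', 'reward', 'mask'))

class Memory:
    def __init__(self):
        # On-policy, per-episode buffer: main() creates a fresh Memory()
        # at the start of every episode
        self.memory = []

    def push(self, state, next_state, action, reward, mask):
        # Store one transition of the current episode
        self.memory.append(Transition(state, next_state, action, reward, mask))

    def sample(self):
        # Return everything collected so far as a Transition of tuples
        # (batch.state, batch.action, ...), ready for QNet.train_model
        return Transition(*zip(*self.memory))

    def __len__(self):
        return len(self.memory)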
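
Note that the discounted-return loop deleted from main() (sum_reward = reward + gamma * sum_reward over the reversed episode) must now happen somewhere downstream, presumably inside QNet.train_model or Memory; the diff does not show where. For reference, the same recurrence written as a standalone helper; the plain list-of-rewards interface is an assumption for illustration only.

def discounted_returns(rewards, gamma):
    # Same recurrence as the deleted code: walk the episode backwards and
    # accumulate sum_reward = reward + gamma * sum_reward
    returns = []
    sum_reward = 0
    for reward in reversed(rewards):
        sum_reward = reward + gamma * sum_reward
        returns.append(sum_reward)
    returns.reverse()  # back to time order (the deleted code left them reversed)
    return returns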