 from config import env_name, initial_exploration, batch_size, update_target, goal_score, log_interval, device, replay_memory_capacity, lr


+def get_action(state, target_net, epsilon, env):
+    if np.random.rand() <= epsilon:
+        return env.action_space.sample()
+    else:
+        return target_net.get_action(state)


 def update_target_model(online_net, target_net):
     # Target <- Net
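The get_action helper added above is epsilon-greedy action selection: with probability epsilon it samples a random action from the environment's action space, and otherwise it takes the action the target network currently rates best, so on average a fraction epsilon of the actions are exploratory. A minimal standalone sketch of the same rule, where num_actions and greedy_action are hypothetical stand-ins rather than names from this repo:

    import numpy as np

    def epsilon_greedy(greedy_action, epsilon, num_actions):
        # Explore: with probability epsilon, pick an action uniformly at random.
        if np.random.rand() <= epsilon:
            return np.random.randint(num_actions)
        # Exploit: otherwise keep the current greedy choice.
        return greedy_action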
@@ -43,6 +48,7 @@ def main():
     target_net.train()
     memory = Memory(replay_memory_capacity)
     running_score = 0
+    epsilon = 1.0
     steps = 0
     loss = 0

@@ -55,10 +61,9 @@
         state = state.unsqueeze(0)

         while not done:
-
             steps += 1

-            action = target_net.get_action(state)
+            action = get_action(state, target_net, epsilon, env)
             next_state, reward, done, _ = env.step(action)

             next_state = torch.Tensor(next_state)
@@ -74,6 +79,8 @@
             state = next_state

             if steps > initial_exploration:
+                epsilon -= 0.00005
+                epsilon = max(epsilon, 0.1)

                 batch = memory.sample(batch_size)
                 loss = QNet.train_model(online_net, target_net, optimizer, batch)
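Once steps exceeds initial_exploration, the two lines added above decay epsilon linearly by 0.00005 per step and clamp it at 0.1, so exploration falls from 1.0 to the 0.1 floor over (1.0 - 0.1) / 0.00005 = 18,000 decayed steps and then stays there. A quick standalone check of that schedule, independent of the training loop:

    epsilon = 1.0
    for step in range(1, 25001):
        epsilon = max(epsilon - 0.00005, 0.1)
        if epsilon <= 0.1:
            print('epsilon reaches its 0.1 floor after', step, 'decay steps')  # about 18,000
            break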
@@ -84,8 +91,8 @@
         score = score if score == 500.0 else score + 1
         running_score = 0.99 * running_score + 0.01 * score
         if e % log_interval == 0:
-            print('{} episode | score: {:.2f}'.format(
-                e, running_score))
+            print('{} episode | score: {:.2f} | epsilon: {:.2f}'.format(
+                e, running_score, epsilon))
             writer.add_scalar('log/score', float(running_score), e)
             writer.add_scalar('log/loss', float(loss), e)
