 import torch.nn.functional as F
 from tensorboardX import SummaryWriter
 
-from model import QRDQN
+from model import IQN
 from memory import Memory
 
 from config import env_name, initial_exploration, batch_size, update_target, goal_score, log_interval, device, replay_memory_capacity, lr
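
The import swap implies that model.py now defines an IQN class. The network itself is not part of this diff; the sketch below only illustrates what an implicit quantile network with that interface could look like (state embedding, cosine embedding of sampled quantile fractions, element-wise combination, per-action output). The layer sizes and the helper names embedding_dim, num_cosines, and get_action are assumptions, not names taken from this repository.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class IQN(nn.Module):
    """Implicit quantile network: maps (state, taus) to per-action quantile values."""

    def __init__(self, num_inputs, num_actions, embedding_dim=64, num_cosines=64):
        super().__init__()
        self.num_cosines = num_cosines
        self.state_fc = nn.Linear(num_inputs, embedding_dim)    # psi(x)
        self.cosine_fc = nn.Linear(num_cosines, embedding_dim)  # phi(tau)
        self.out_fc = nn.Linear(embedding_dim, num_actions)

    def forward(self, state, taus):
        # state: (batch, num_inputs); taus: (batch, num_taus) sampled in [0, 1)
        psi = F.relu(self.state_fc(state))                               # (batch, embed)
        i = torch.arange(1, self.num_cosines + 1, device=taus.device).float()
        cos = torch.cos(taus.unsqueeze(-1) * i * math.pi)                # (batch, num_taus, num_cosines)
        phi = F.relu(self.cosine_fc(cos))                                # (batch, num_taus, embed)
        # element-wise combination, then one value per action for each tau
        return self.out_fc(psi.unsqueeze(1) * phi)                       # (batch, num_taus, num_actions)

    def get_action(self, state, num_taus=32):
        # act greedily on the mean over sampled quantile fractions
        taus = torch.rand(1, num_taus, device=state.device)
        q = self.forward(state.unsqueeze(0), taus).mean(dim=1)           # (1, num_actions)
        return q.argmax(dim=1).item()
```
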
@@ -37,8 +37,8 @@ def main():
     print('state size:', num_inputs)
     print('action size:', num_actions)
 
-    online_net = QRDQN(num_inputs, num_actions)
-    target_net = QRDQN(num_inputs, num_actions)
+    online_net = IQN(num_inputs, num_actions)
+    target_net = IQN(num_inputs, num_actions)
     update_target_model(online_net, target_net)
 
     optimizer = optim.Adam(online_net.parameters(), lr=lr)
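
update_target_model is untouched by this patch and not shown here. In DQN-style trainers it is usually a hard copy of the online parameters into the target network; a minimal sketch, assuming that convention:

```python
def update_target_model(online_net, target_net):
    # hard update: overwrite the target network with the online network's weights
    target_net.load_state_dict(online_net.state_dict())
```
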
@@ -82,7 +82,7 @@ def main():
                 epsilon = max(epsilon, 0.1)
 
                 batch = memory.sample(batch_size)
-                loss = QRDQN.train_model(online_net, target_net, optimizer, batch)
+                loss = IQN.train_model(online_net, target_net, optimizer, batch)
 
                 if steps % update_target == 0:
                     update_target_model(online_net, target_net)
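
The only behavioural change in the training loop is the switch from QRDQN.train_model to IQN.train_model; that method lives in model.py and is not shown in this diff. Below is a minimal sketch of an IQN update with the pairwise quantile Huber loss, assuming the forward pass sketched above, a batch namedtuple with fields state, next_state, action, reward, and mask, and illustrative hyperparameters num_taus and gamma. In the repository it would sit on the IQN class as a @staticmethod so that the IQN.train_model(...) call above works as written.

```python
import torch
import torch.nn.functional as F

def train_model(online_net, target_net, optimizer, batch, num_taus=32, gamma=0.99):
    # batch fields are assumed to be sequences of tensors / scalars; mask = 0 at terminal transitions
    states = torch.stack(batch.state)
    next_states = torch.stack(batch.next_state)
    actions = torch.tensor(batch.action, dtype=torch.long)
    rewards = torch.tensor(batch.reward, dtype=torch.float)
    masks = torch.tensor(batch.mask, dtype=torch.float)
    batch_size = states.size(0)

    # fresh quantile fractions for the online and target networks
    taus = torch.rand(batch_size, num_taus, device=states.device)
    taus_prime = torch.rand(batch_size, num_taus, device=states.device)

    # Z(s, a; tau) for the actions actually taken
    z = online_net(states, taus)                                         # (batch, num_taus, num_actions)
    z_a = z.gather(2, actions.view(-1, 1, 1).expand(-1, num_taus, 1)).squeeze(2)

    with torch.no_grad():
        # greedy next action under the target network's mean over tau'
        z_next = target_net(next_states, taus_prime)
        a_star = z_next.mean(dim=1).argmax(dim=1)
        z_next_a = z_next.gather(2, a_star.view(-1, 1, 1).expand(-1, num_taus, 1)).squeeze(2)
        target = rewards.unsqueeze(1) + gamma * masks.unsqueeze(1) * z_next_a

    # pairwise TD errors delta[b, i, j] = target_j - z_i
    td = target.unsqueeze(1) - z_a.unsqueeze(2)                          # (batch, num_taus, num_taus)
    huber = F.smooth_l1_loss(td, torch.zeros_like(td), reduction='none')
    # quantile Huber loss: |tau_i - 1{delta < 0}| * huber(delta),
    # averaged over both tau dimensions (the paper sums over one; a constant factor only)
    loss = (torch.abs(taus.unsqueeze(2) - (td < 0).float()) * huber).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss
```

Compared with QR-DQN, the quantile fractions here are drawn anew for every update rather than being a fixed, evenly spaced set, which is the essential difference this rename reflects.
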