Skip to content

Commit 3310e20

Browse files
committed
Fix: IQN
1 parent 390a9e3 commit 3310e20

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

distributional/2-IQN/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
from config import batch_size, gamma, quantile_embedding_dim, num_tau_sample, num_tau_prime_sample, num_quantile_sample
77

8-
class QRDQN(nn.Module):
8+
class IQN(nn.Module):
99
def __init__(self, num_inputs, num_outputs):
10-
super(QRDQN, self).__init__()
10+
super(IQN, self).__init__()
1111
self.num_inputs = num_inputs
1212
self.num_outputs = num_outputs
1313

distributional/2-IQN/train.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torch.nn.functional as F
1111
from tensorboardX import SummaryWriter
1212

13-
from model import QRDQN
13+
from model import IQN
1414
from memory import Memory
1515

1616
from config import env_name, initial_exploration, batch_size, update_target, goal_score, log_interval, device, replay_memory_capacity, lr
@@ -37,8 +37,8 @@ def main():
3737
print('state size:', num_inputs)
3838
print('action size:', num_actions)
3939

40-
online_net = QRDQN(num_inputs, num_actions)
41-
target_net = QRDQN(num_inputs, num_actions)
40+
online_net = IQN(num_inputs, num_actions)
41+
target_net = IQN(num_inputs, num_actions)
4242
update_target_model(online_net, target_net)
4343

4444
optimizer = optim.Adam(online_net.parameters(), lr=lr)
@@ -82,7 +82,7 @@ def main():
8282
epsilon = max(epsilon, 0.1)
8383

8484
batch = memory.sample(batch_size)
85-
loss = QRDQN.train_model(online_net, target_net, optimizer, batch)
85+
loss = IQN.train_model(online_net, target_net, optimizer, batch)
8686

8787
if steps % update_target == 0:
8888
update_target_model(online_net, target_net)

0 commit comments

Comments
 (0)