
Commit 6e86643 (1 parent: 7c179f3)

Fix: Rainbow Error

14 files changed: +87 -63 lines

rainbow/1-dqn/config.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 env_name = 'CartPole-v1'
 gamma = 0.99
 batch_size = 32
-lr = 0.1
+lr = 0.001
 initial_exploration = 1000
 goal_score = 200
 log_interval = 10
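The same learning-rate drop (0.1 to 0.001) is applied to the DQN, DoubleDQN, DuelDQN, PER, and NoisyNet configs below; a step size of 0.1 is far too aggressive for TD learning on CartPole and the Q-values tend to diverge. A minimal sketch of how `lr` is consumed, assuming an Adam optimizer and a stand-in network (neither is shown in this diff):

import torch.nn as nn
import torch.optim as optim

lr = 0.001  # was 0.1 before this commit

# Stand-in Q-network; CartPole-v1 has 4 observation dims and 2 actions.
online_net = nn.Sequential(nn.Linear(4, 128), nn.ReLU(), nn.Linear(128, 2))

# With a step size of 0.1 the Q-values oscillate and the running score
# never approaches goal_score; 1e-3 is the conventional small-DQN choice.
optimizer = optim.Adam(online_net.parameters(), lr=lr)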

rainbow/2-DoubleDQN/config.py

Lines changed: 2 additions & 2 deletions

@@ -3,10 +3,10 @@
 env_name = 'CartPole-v1'
 gamma = 0.99
 batch_size = 32
-lr = 0.1
+lr = 0.001
 initial_exploration = 1000
 goal_score = 200
-log_interval=10
+log_interval = 10
 update_target = 100
 replay_memory_capacity = 1000
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

rainbow/3-DuelDQN/config.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 env_name = 'CartPole-v1'
 gamma = 0.99
 batch_size = 32
-lr = 0.1
+lr = 0.001
 initial_exploration = 1000
 goal_score = 200
 log_interval = 10

rainbow/5-per/config.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 env_name = 'CartPole-v1'
 gamma = 0.99
 batch_size = 32
-lr = 0.1
+lr = 0.001
 initial_exploration = 1000
 goal_score = 200
 log_interval = 10

rainbow/6-Nosiy_net/config.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 env_name = 'CartPole-v1'
 gamma = 0.99
 batch_size = 32
-lr = 0.1
+lr = 0.001
 initial_exploration = 1000
 goal_score = 200
 log_interval = 10

rainbow/6-Nosiy_net/model.py

Lines changed: 4 additions & 2 deletions

@@ -67,7 +67,6 @@ def train_model(cls, online_net, target_net, optimizer, batch):
         rewards = torch.Tensor(batch.reward)
         masks = torch.Tensor(batch.mask)
 
-        target_net.fc2.reset_noise()
         pred = online_net(states).squeeze(1)
         next_pred = target_net(next_states).squeeze(1)
 
@@ -79,11 +78,14 @@ def train_model(cls, online_net, target_net, optimizer, batch):
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
+        online_net.reset_noise()
 
         return loss
 
     def get_action(self, input):
-        self.fc2.reset_noise()
         qvalue = self.forward(input)
        _, action = torch.max(qvalue, 1)
         return action.numpy()[0]
+
+    def reset_noise(self):
+        self.fc2.reset_noise()
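This change consolidates noise resampling into a single `reset_noise` method called once per optimizer step on the online network: noise is no longer resampled inside `get_action` (which would re-randomize the policy on every action) or on the target network during training. The diff implies `fc2` is a noisy layer; since the layer itself is not part of this commit, the following is only a sketch of a factorized-Gaussian NoisyLinear in the style of Fortunato et al. (2017):

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class NoisyLinear(nn.Module):
    """Factorized-Gaussian noisy linear layer (assumed implementation)."""
    def __init__(self, in_features, out_features, sigma_init=0.5):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
        self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features))
        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))
        # Noise tensors are buffers: part of the module's state, not trained.
        self.register_buffer('weight_eps', torch.zeros(out_features, in_features))
        self.register_buffer('bias_eps', torch.zeros(out_features))
        bound = 1 / math.sqrt(in_features)
        self.weight_mu.data.uniform_(-bound, bound)
        self.bias_mu.data.uniform_(-bound, bound)
        self.weight_sigma.data.fill_(sigma_init / math.sqrt(in_features))
        self.bias_sigma.data.fill_(sigma_init / math.sqrt(in_features))
        self.reset_noise()

    @staticmethod
    def _scaled_noise(size):
        x = torch.randn(size)
        return x.sign() * x.abs().sqrt()

    def reset_noise(self):
        # Factorized noise: O(in + out) samples instead of O(in * out).
        eps_in = self._scaled_noise(self.in_features)
        eps_out = self._scaled_noise(self.out_features)
        self.weight_eps.copy_(eps_out.outer(eps_in))
        self.bias_eps.copy_(eps_out)

    def forward(self, x):
        return F.linear(x, self.weight_mu + self.weight_sigma * self.weight_eps,
                        self.bias_mu + self.bias_sigma * self.bias_eps)

Exposing `reset_noise` on the network rather than reaching into `fc2` from call sites also means that adding more noisy layers later only requires extending one method.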

rainbow/6-Nosiy_net/train.py

Lines changed: 11 additions & 4 deletions

@@ -14,6 +14,11 @@
 from config import env_name, initial_exploration, batch_size, update_target, goal_score, log_interval, device, replay_memory_capacity, lr
 
 
+def get_action(state, target_net, epsilon, env):
+    if np.random.rand() <= epsilon:
+        return env.action_space.sample()
+    else:
+        return target_net.get_action(state)
 
 def update_target_model(online_net, target_net):
     # Target <- Net
@@ -43,6 +48,7 @@ def main():
     target_net.train()
     memory = Memory(replay_memory_capacity)
     running_score = 0
+    epsilon = 1.0
     steps = 0
     loss = 0
 
@@ -55,10 +61,9 @@ def main():
         state = state.unsqueeze(0)
 
         while not done:
-
             steps += 1
 
-            action = target_net.get_action(state)
+            action = get_action(state, target_net, epsilon, env)
             next_state, reward, done, _ = env.step(action)
 
             next_state = torch.Tensor(next_state)
@@ -74,6 +79,8 @@ def main():
             state = next_state
 
             if steps > initial_exploration:
+                epsilon -= 0.00005
+                epsilon = max(epsilon, 0.1)
 
                 batch = memory.sample(batch_size)
                 loss = QNet.train_model(online_net, target_net, optimizer, batch)
@@ -84,8 +91,8 @@ def main():
         score = score if score == 500.0 else score + 1
         running_score = 0.99 * running_score + 0.01 * score
         if e % log_interval == 0:
-            print('{} episode | score: {:.2f} '.format(
-                e, running_score))
+            print('{} episode | score: {:.2f} | epsilon: {:.2f}'.format(
+                e, running_score, epsilon))
         writer.add_scalar('log/score', float(running_score), e)
         writer.add_scalar('log/loss', float(loss), e)
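The new epsilon-greedy schedule starts at 1.0, decays by 0.00005 on each environment step once the `initial_exploration` warm-up (1000 steps) has passed, and floors at 0.1. A small sketch of the resulting schedule as a pure function (the helper name is illustrative, not from the repo):

def epsilon_at(step, warmup=1000, decay=0.00005, floor=0.1):
    """Epsilon after `step` environment steps, mirroring the loop in train.py."""
    if step <= warmup:
        return 1.0
    return max(1.0 - decay * (step - warmup), floor)

# The decay from 1.0 to the 0.1 floor spans (1.0 - 0.1) / 0.00005 = 18,000
# steps after the warm-up, since the decrement sits inside `while not done:`.
for step in (1000, 5000, 10000, 19000, 50000):
    print(step, round(epsilon_at(step), 3))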

rainbow/7-distributional_c51/model.py

Lines changed: 1 addition & 2 deletions

@@ -16,7 +16,6 @@ def __init__(self, num_inputs, num_outputs):
         self.z = torch.Tensor([V_min + i * self.dz for i in range(num_support)])
 
         self.fc1 = nn.Linear(num_inputs, 128)
-        # self.fc2 = nn.Linear(128, 128)
         self.fc2 = nn.Linear(128, num_outputs * num_support)
 
         for m in self.modules():
@@ -82,7 +81,7 @@ def train_model(cls, online_net, target_net, optimizer, batch):
         m_prob = cls.get_m(rewards, masks, prob_next_states_action)
         m_prob = torch.tensor(m_prob)
 
-        m_prob = m_prob / torch.sum(m_prob, dim=1, keepdim=True)
+        m_prob = (m_prob / torch.sum(m_prob, dim=1, keepdim=True)).detach()
         expand_dim_action = torch.unsqueeze(actions, -1)
         p = torch.sum(online_net(states) * expand_dim_action.float(), dim=1)
         loss = -torch.sum(m_prob * torch.log(p + 1e-20), 1)
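The substantive change here is the `.detach()`: in C51 the projected target distribution m is a fixed cross-entropy target, so no gradient should flow back through it; only the online network's predicted distribution p receives gradients. A toy illustration of the detached cross-entropy step (shapes and names are illustrative, not the repo's):

import torch

num_support = 8
logits = torch.randn(2, num_support, requires_grad=True)
p = torch.softmax(logits, dim=1)        # online net's predicted distribution

m_prob = torch.rand(2, num_support)     # stand-in projected target distribution
m_prob = (m_prob / torch.sum(m_prob, dim=1, keepdim=True)).detach()

# Cross-entropy between fixed target and prediction; 1e-20 guards log(0).
loss = -torch.sum(m_prob * torch.log(p + 1e-20), 1).mean()
loss.backward()
print(logits.grad.shape)  # torch.Size([2, 8]): only the prediction side gets gradients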

rainbow/8-Not_Distributional/config.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 env_name = 'CartPole-v1'
 gamma = 0.99
 batch_size = 32
-lr = 0.01
+lr = 0.001
 initial_exploration = 1000
 goal_score = 200
 log_interval = 10

rainbow/8-Not_Distributional/memory.py

Lines changed: 1 addition & 0 deletions

@@ -68,6 +68,7 @@ def sample(self, batch_size, net, target_net, beta):
 
 
         td_error = QNet.get_td_error(net, target_net, batch.state, batch.next_state, batch.action, batch.reward, batch.mask)
+        td_error = td_error.detach()
 
         td_error_idx = 0
         for idx in indexes:
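In prioritized experience replay, the TD error computed in `sample` is only used to assign sampling priorities, not as a training signal, so detaching it stops the memory from retaining the autograd graph of every priority it writes back. A minimal sketch of the write-back pattern under that assumption (the `Memory` internals are not shown in this diff):

import torch

def update_priorities(priorities, indexes, td_error, eps=1e-6):
    """Write |TD error| back as the priority of each sampled transition."""
    # Detach first: priorities are bookkeeping floats, and keeping the
    # autograd graph alive for each one would leak memory across batches.
    td_error = td_error.detach().abs()
    for i, idx in enumerate(indexes):
        priorities[idx] = td_error[i].item() + eps

# Usage with a fake batch of 4 sampled transitions:
priorities = [1.0] * 10
td = torch.randn(4, requires_grad=True) * 2.0
update_priorities(priorities, [0, 3, 5, 7], td)
print(priorities)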
