
Add R2D1 agents #248

Merged: 26 commits, Jun 23, 2020
Changes from 1 commit
Commits (26)
be144ba
Add R2D1 DQNAgent
jinPrelude May 6, 2020
720926e
Fix descriptions & Rename R2D1 to R2D1DQN
jinPrelude May 25, 2020
10393f1
Rebase commit to a302479
jinPrelude Jun 11, 2020
22bc2f0
Rebase commit to 990c78a
jinPrelude Jun 11, 2020
f96e963
Fix several issues commented
jinPrelude Jun 11, 2020
a5df8f1
Delete PrioritizedRecurrentReplayBuffer
jinPrelude Jun 11, 2020
264e611
Resolved issues commented
jinPrelude Jun 11, 2020
cfd194e
Fix issues commented
jinPrelude Jun 11, 2020
8e428e0
Change R2D1Learner parent class & Add descriptions
jinPrelude Jun 11, 2020
5826c01
Fix issues commented
jinPrelude Jun 15, 2020
d58a3b2
Change descriptions
jinPrelude Jun 15, 2020
8add670
Delete unnecessary configs
jinPrelude Jun 15, 2020
ce8a112
Change descriptions due to the length limit
jinPrelude Jun 15, 2020
4b9b1b8
Fix several issues commented
jinPrelude Jun 16, 2020
1be2c3f
Fix contiguous issues
jinPrelude Jun 17, 2020
eb231e4
Add __init__.py in recurrent
jinPrelude Jun 17, 2020
bb37068
Rebase commit to 815a1ca
jinPrelude Jun 18, 2020
4066dbf
Use no grad tensor to select action
jinPrelude Jun 22, 2020
a3725c1
Modify documentation
Jun 22, 2020
b7f8201
Add torch.no_grad to select action in other class
Jun 22, 2020
1d3cf20
Add R2D1DQN ResNet config
jinPrelude Jun 23, 2020
d3a2e59
Fix R2D1DQN ResNet config
jinPrelude Jun 23, 2020
083a447
Merge recurrent_replay_buffer into replay_bufer
jinPrelude Jun 23, 2020
c3056b3
Add R2D1 on readme
jinPrelude Jun 23, 2020
b0ed25a
Remove off-framestack explanation
jinPrelude Jun 23, 2020
abcc34b
Change r2d1 configs' framestack 1 to 4
jinPrelude Jun 23, 2020
Add torch.no_grad to select action in other class
khkim committed Jun 22, 2020
commit b7f82016da1f98d96b10c9045e798da9d19b6490
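For context, the pattern this commit spreads across the agents is to run inference-only action selection under `torch.no_grad()`, so rollouts do not record an autograd graph. A minimal sketch of that pattern; the network and names here are illustrative stand-ins, not the repository's classes:

```python
import numpy as np
import torch
import torch.nn as nn

# Illustrative Q-network; the real agents wrap calls to
# self.learner.dqn / self.learner.actor the same way.
q_net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))

def select_action(state: np.ndarray, epsilon: float = 0.1) -> np.ndarray:
    if epsilon > np.random.random():
        return np.array(np.random.randint(2))  # epsilon-greedy exploration
    state_t = torch.from_numpy(state).float().unsqueeze(0)
    with torch.no_grad():  # inference only: no autograd graph is recorded
        selected_action = q_net(state_t).argmax()
    return selected_action.cpu().numpy()

print(select_action(np.zeros(4, dtype=np.float32)))
```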
1 change: 1 addition & 0 deletions rl_algorithms/common/abstract/agent.py
@@ -84,6 +84,7 @@ def set_wandb(self):
name=f"{self.log_cfg.agent}/{self.log_cfg.curr_time}",
)
wandb.config.update(vars(self.args))
wandb.config.update(self.hyper_params)
shutil.copy(self.args.cfg_path, os.path.join(wandb.run.dir, "config.py"))

def interim_test(self):
10 changes: 0 additions & 10 deletions rl_algorithms/common/networks/brain.py
@@ -66,15 +66,6 @@ def __init__(
self.action_size = head_cfg.configs.output_size
"""Initialize. Generate different structure whether it has CNN module or not."""
Brain.__init__(self, backbone_cfg, head_cfg)
if not backbone_cfg:
self.backbone = identity
head_cfg.configs.input_size = head_cfg.configs.state_size[0]

else:
self.backbone = build_backbone(backbone_cfg)
head_cfg.configs.input_size = self.calculate_fc_input_size(
head_cfg.configs.state_size
)
self.fc = nn.Linear(head_cfg.configs.input_size, gru_cfg.rnn_hidden_size,)
self.gru = nn.GRU(
gru_cfg.rnn_hidden_size + self.action_size + 1, # 1 is for prev_reward
@@ -131,7 +122,6 @@ def forward(
dim=2,
)
hidden = torch.transpose(hidden, 0, 1)
hidden = None if hidden is None else hidden

# Unroll gru
gru_out, hidden = self.gru(gru_input, hidden)
3 changes: 2 additions & 1 deletion rl_algorithms/ddpg/agent.py
@@ -106,7 +106,8 @@ def select_action(self, state: np.ndarray) -> np.ndarray:
):
return np.array(self.env_info.action_space.sample())

selected_action = self.learner.actor(state).detach().cpu().numpy()
with torch.no_grad():
selected_action = self.learner.actor(state).detach().cpu().numpy()

if not self.args.test:
noise = self.noise.sample()
5 changes: 3 additions & 2 deletions rl_algorithms/dqn/agent.py
@@ -126,8 +126,9 @@ def select_action(self, state: np.ndarray) -> np.ndarray:
if not self.args.test and self.epsilon > np.random.random():
selected_action = np.array(self.env.action_space.sample())
else:
state = self._preprocess_state(state)
selected_action = self.learner.dqn(state).argmax()
with torch.no_grad():
state = self._preprocess_state(state)
selected_action = self.learner.dqn(state).argmax()
selected_action = selected_action.detach().cpu().numpy()
return selected_action

49 changes: 25 additions & 24 deletions rl_algorithms/dqn/losses.py
@@ -45,35 +45,38 @@ def __call__(
gamma_with_terminal = gamma_with_terminal.repeat(
head_cfg.configs.n_tau_prime_samples, 1
)
with torch.no_grad():
# Get the indices of the maximium Q-value across the action dimension.
# Shape of replay_next_qt_argmax: (n_tau_prime_samples x batch_size) x 1.
next_actions = model(next_states).argmax(dim=1) # double Q
next_actions = next_actions[:, None]
next_actions = next_actions.repeat(head_cfg.configs.n_tau_prime_samples, 1)

# Shape of next_target_values: (n_tau_prime_samples x batch_size) x 1.
target_quantile_values, _ = target_model.forward_(
next_states, head_cfg.configs.n_tau_prime_samples
)
target_quantile_values = target_quantile_values.gather(1, next_actions)
target_quantile_values = (
rewards + gamma_with_terminal * target_quantile_values
)
target_quantile_values = target_quantile_values.detach()

# Get the indices of the maximium Q-value across the action dimension.
# Shape of replay_next_qt_argmax: (n_tau_prime_samples x batch_size) x 1.
next_actions = model(next_states).argmax(dim=1) # double Q
next_actions = next_actions[:, None]
next_actions = next_actions.repeat(head_cfg.configs.n_tau_prime_samples, 1)

# Shape of next_target_values: (n_tau_prime_samples x batch_size) x 1.
target_quantile_values, _ = target_model.forward_(
next_states, head_cfg.configs.n_tau_prime_samples
)
target_quantile_values = target_quantile_values.gather(1, next_actions)
target_quantile_values = rewards + gamma_with_terminal * target_quantile_values
target_quantile_values = target_quantile_values.detach()

# Reshape to n_tau_prime_samples x batch_size x 1 since this is
# the manner in which the target_quantile_values are tiled.
target_quantile_values = target_quantile_values.view(
head_cfg.configs.n_tau_prime_samples, batch_size, 1
)
# Reshape to n_tau_prime_samples x batch_size x 1 since this is
# the manner in which the target_quantile_values are tiled.
target_quantile_values = target_quantile_values.view(
head_cfg.configs.n_tau_prime_samples, batch_size, 1
)

# Transpose dimensions so that the dimensionality is batch_size x
# n_tau_prime_samples x 1 to prepare for computation of Bellman errors.
target_quantile_values = torch.transpose(target_quantile_values, 0, 1)
# Transpose dimensions so that the dimensionality is batch_size x
# n_tau_prime_samples x 1 to prepare for computation of Bellman errors.
target_quantile_values = torch.transpose(target_quantile_values, 0, 1)

# Get quantile values: (n_tau_samples x batch_size) x action_dim.
quantile_values, quantiles = model.forward_(
states, head_cfg.configs.n_tau_samples
)

reshaped_actions = actions[:, None].repeat(head_cfg.configs.n_tau_samples, 1)
chosen_action_quantile_values = quantile_values.gather(
1, reshaped_actions.long()
@@ -121,7 +124,6 @@ def __call__(
quantiles = quantiles[:, None, :, :].repeat(
1, head_cfg.configs.n_tau_prime_samples, 1, 1
)
quantiles = quantiles.to(device)

# Shape: batch_size x n_tau_prime_samples x n_tau_samples x 1.
quantile_huber_loss = (
@@ -236,7 +238,6 @@ def __call__(
# = r otherwise
masks = 1 - dones
target = rewards + gamma * next_q_value * masks
target = target.to(device)

# calculate dq loss
dq_loss_element_wise = F.smooth_l1_loss(
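As an aside, the target computation that the loss above moves under `torch.no_grad()` follows the usual double-DQN recipe: the online network picks the next action and the target network evaluates it. A hedged sketch with hypothetical networks and shapes (the IQN loss does the same thing per quantile sample):

```python
import torch
import torch.nn as nn

online_q = nn.Linear(4, 3)  # stand-ins for self.dqn / self.dqn_target
target_q = nn.Linear(4, 3)

def double_dqn_target(next_states, rewards, dones, gamma=0.99):
    # next_states: (batch, 4); rewards, dones: (batch, 1)
    with torch.no_grad():
        next_actions = online_q(next_states).argmax(dim=1, keepdim=True)  # online net selects
        next_values = target_q(next_states).gather(1, next_actions)       # target net evaluates
        return rewards + gamma * (1.0 - dones) * next_values

batch = 8
targets = double_dqn_target(torch.randn(batch, 4), torch.rand(batch, 1), torch.zeros(batch, 1))
print(targets.shape)  # torch.Size([8, 1])
```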
2 changes: 1 addition & 1 deletion rl_algorithms/dqn/networks.py
@@ -199,7 +199,7 @@ def forward_(

quantile_values = super(IQNMLP, self).forward(quantile_net)

return quantile_values, quantiles
return quantile_values, quantiles.to(device)

def forward(self, state: torch.Tensor) -> torch.Tensor:
"""Forward method implementation."""
1 change: 1 addition & 0 deletions rl_algorithms/recurrent/dqn_agent.py
@@ -78,6 +78,7 @@ def select_action(
state, hidden_state, prev_action, prev_reward
)
selected_action = selected_action.detach().argmax().cpu().numpy()

if not self.args.test and self.epsilon > np.random.random():
selected_action = np.array(self.env.action_space.sample())
return selected_action, hidden_state
10 changes: 8 additions & 2 deletions rl_algorithms/recurrent/learner.py
@@ -89,9 +89,10 @@ def update_model(

weights, indices = experience_1[-3:-1]
gamma = self.hyper_params.gamma
burn_in_step = self.gru_cfg.burn_in_step

dq_loss_element_wise, q_values = self.loss_fn(
self.dqn, self.dqn_target, experience_1, gamma, self.head_cfg
self.dqn, self.dqn_target, experience_1, gamma, self.head_cfg, burn_in_step
)
dq_loss = torch.mean(dq_loss_element_wise * weights)

@@ -100,7 +101,12 @@
gamma = self.hyper_params.gamma ** self.hyper_params.n_step

dq_loss_n_element_wise, q_values_n = self.loss_fn(
self.dqn, self.dqn_target, experience_n, gamma, self.head_cfg
self.dqn,
self.dqn_target,
experience_n,
gamma,
self.head_cfg,
burn_in_step,
)

# to update loss and priorities
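The new `burn_in_step` argument threaded into the losses supports R2D2-style burn-in: the first part of each stored sequence is unrolled without gradients only to warm up the recurrent state, and the loss is computed on the remainder. A minimal sketch of that idea under assumed shapes and a toy network (not the repository's slice_r2d1_arguments):

```python
import torch
import torch.nn as nn

class TinyRecurrentQ(nn.Module):
    """Toy GRU Q-network used only to illustrate burn-in."""

    def __init__(self, state_dim: int, action_dim: int, hidden_size: int = 32):
        super().__init__()
        self.gru = nn.GRU(state_dim, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, action_dim)

    def forward(self, states, hidden=None):
        out, hidden = self.gru(states, hidden)  # states: (batch, seq, state_dim)
        return self.head(out), hidden

batch, burn_in_step, train_len, state_dim, action_dim = 4, 5, 10, 8, 3
model = TinyRecurrentQ(state_dim, action_dim)
sequence = torch.randn(batch, burn_in_step + train_len, state_dim)

burnin_states = sequence[:, :burn_in_step]  # warms up the hidden state only
train_states = sequence[:, burn_in_step:]   # loss is computed on this segment

with torch.no_grad():                       # no gradients through the burn-in unroll
    _, hidden = model(burnin_states)

q_values, _ = model(train_states, hidden)   # gradients flow only through this unroll
print(q_values.shape)                       # torch.Size([4, 10, 3])
```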
89 changes: 47 additions & 42 deletions rl_algorithms/recurrent/losses.py
@@ -29,11 +29,11 @@ def __call__(
experiences: Tuple[torch.Tensor, ...],
gamma: float,
head_cfg: ConfigDict,
burn_in_step: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Return R2D1DQN loss and Q-values.
TODO: Combine with DQNLoss
"""

"""Return R2D1DQN loss and Q-values."""
# TODO: Combine with DQNLoss
output_size = head_cfg.configs.output_size
(
burnin_states_tuple,
states_tuple,
@@ -46,7 +46,7 @@ def __call__(
burnin_dones_tuple,
agent_dones,
init_rnn_state,
) = slice_r2d1_arguments(experiences, head_cfg)
) = slice_r2d1_arguments(experiences, burn_in_step, output_size)

with torch.no_grad():
_, target_rnn_state = target_model(
@@ -72,33 +72,37 @@
init_rnn_state[burnin_invalid_mask] = 0
target_rnn_state[burnin_target_invalid_mask] = 0

qs, _ = model(
q_values, _ = model(
states_tuple[0],
init_rnn_state,
prev_actions_tuple[0],
prev_rewards_tuple[0],
)
q = qs.gather(-1, agent_actions)
q_value = q_values.gather(-1, agent_actions)

with torch.no_grad():
target_qs, _ = target_model(
target_q_values, _ = target_model(
states_tuple[1],
target_rnn_state,
prev_actions_tuple[1],
prev_rewards_tuple[1],
)
next_qs, _ = model(
next_q_values, _ = model(
states_tuple[1],
target_rnn_state,
prev_actions_tuple[0],
prev_rewards_tuple[0],
)
next_a = torch.argmax(next_qs, dim=-1)
target_q = target_qs.gather(-1, next_a.unsqueeze(-1))
next_action = torch.argmax(next_q_values, dim=-1)
target_q_value = target_q_values.gather(-1, next_action.unsqueeze(-1))

target = agent_rewards + gamma * target_q * (1 - agent_dones)
dq_loss_element_wise = F.smooth_l1_loss(q, target.detach(), reduction="none")
target = agent_rewards + gamma * target_q_value * (1 - agent_dones)
dq_loss_element_wise = F.smooth_l1_loss(
q_value, target.detach(), reduction="none"
)
delta = abs(torch.mean(dq_loss_element_wise, dim=1))
return delta, q

return delta, q_value


@LOSSES.register_module
@@ -110,10 +114,11 @@ def __call__(
experiences: Tuple[torch.Tensor, ...],
gamma: float,
head_cfg: ConfigDict,
burn_in_step: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Return element-wise C51 loss and Q-values.
TODO: Combine with C51Loss
"""
"""Return element-wise C51 loss and Q-values."""
# TODO: Combine with IQNLoss
output_size = head_cfg.configs.output_size
(
burnin_states_tuple,
states_tuple,
Expand All @@ -126,7 +131,7 @@ def __call__(
burnin_dones_tuple,
agent_dones,
init_rnn_state,
) = slice_r2d1_arguments(experiences, head_cfg)
) = slice_r2d1_arguments(experiences, burn_in_step, output_size)

batch_size = states_tuple[0].shape[0]
sequence_size = states_tuple[0].shape[1]
@@ -235,10 +240,11 @@ def __call__(
experiences: Tuple[torch.Tensor, ...],
gamma: float,
head_cfg: ConfigDict,
burn_in_step: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Return R2D1 loss and Q-values.
TODO: Combine with IQNLoss
"""
"""Return R2D1 loss and Q-values."""
# TODO: Combine with IQNLoss
output_size = head_cfg.configs.output_size
(
burnin_states_tuple,
states_tuple,
Expand All @@ -251,7 +257,7 @@ def __call__(
burnin_dones_tuple,
agent_dones,
init_rnn_state,
) = slice_r2d1_arguments(experiences, head_cfg)
) = slice_r2d1_arguments(experiences, burn_in_step, output_size)

batch_size = states_tuple[0].shape[0]
sequence_size = states_tuple[0].shape[1]
@@ -289,18 +295,17 @@ def __call__(
gamma_with_terminal = gamma_with_terminal.repeat(
head_cfg.configs.n_tau_prime_samples, 1, 1
)

# Get the indices of the maximium Q-value across the action dimension.
# Shape of replay_next_qt_argmax: (n_tau_prime_samples x batch_size) x 1.
qs, _ = model(
next_actions, _ = model(
states_tuple[1],
target_rnn_state,
prev_actions_tuple[1],
prev_rewards_tuple[1],
)
qs = qs.argmax(dim=-1)
qs = qs[:, :, None]
qs = qs.repeat(head_cfg.configs.n_tau_prime_samples, 1, 1)
).argmax(dim=-1)
next_actions = next_actions[:, :, None]
next_actions = next_actions.repeat(head_cfg.configs.n_tau_prime_samples, 1, 1)

with torch.no_grad():
# Shape of next_target_values: (n_tau_prime_samples x batch_size) x 1.
target_quantile_values, _, _ = target_model.forward_(
@@ -310,21 +315,21 @@
prev_rewards_tuple[1],
head_cfg.configs.n_tau_prime_samples,
)
target_quantile_values = target_quantile_values.gather(-1, qs)
target_quantile_values = (
agent_rewards + gamma_with_terminal * target_quantile_values
)
target_quantile_values = target_quantile_values.detach()
target_quantile_values = target_quantile_values.gather(-1, next_actions)
target_quantile_values = (
agent_rewards + gamma_with_terminal * target_quantile_values
)
target_quantile_values = target_quantile_values.detach()

# Reshape to n_tau_prime_samples x batch_size x 1 since this is
# the manner in which the target_quantile_values are tiled.
target_quantile_values = target_quantile_values.view(
head_cfg.configs.n_tau_prime_samples, batch_size, sequence_size, 1
)
# Reshape to n_tau_prime_samples x batch_size x 1 since this is
# the manner in which the target_quantile_values are tiled.
target_quantile_values = target_quantile_values.view(
head_cfg.configs.n_tau_prime_samples, batch_size, sequence_size, 1
)

# Transpose dimensions so that the dimensionality is batch_size x
# n_tau_prime_samples x 1 to prepare for computation of Bellman errors.
target_quantile_values = torch.transpose(target_quantile_values, 0, 1)
# Transpose dimensions so that the dimensionality is batch_size x
# n_tau_prime_samples x 1 to prepare for computation of Bellman errors.
target_quantile_values = torch.transpose(target_quantile_values, 0, 1)

# Get quantile values: (n_tau_samples x batch_size) x action_dim.
quantile_values, quantiles, _ = model.forward_(
@@ -334,6 +339,7 @@
prev_rewards_tuple[0],
head_cfg.configs.n_tau_samples,
)

reshaped_actions = agent_actions.repeat(head_cfg.configs.n_tau_samples, 1, 1)
chosen_action_quantile_values = quantile_values.gather(
-1, reshaped_actions.long()
@@ -383,7 +389,6 @@ def __call__(
quantiles = quantiles[:, None, :, :, :].repeat(
1, head_cfg.configs.n_tau_prime_samples, 1, 1, 1
)
quantiles = quantiles.to(device)

# Shape: batch_size x n_tau_prime_samples x n_tau_samples x sequence_length x 1.
quantile_huber_loss = (
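For reference, the `quantile_huber_loss` term assembled at the end of these IQN-style losses weights a Huber penalty on the Bellman errors by how far each sampled quantile sits from the error's sign, as in the IQN paper. A hedged sketch with assumed broadcastable shapes:

```python
import torch

def quantile_huber_loss(bellman_errors, quantiles, kappa=1.0):
    # bellman_errors, quantiles: broadcastable, e.g.
    # (batch, n_tau_prime_samples, n_tau_samples, 1)
    abs_err = bellman_errors.abs()
    huber = torch.where(
        abs_err <= kappa,
        0.5 * bellman_errors ** 2,
        kappa * (abs_err - 0.5 * kappa),
    )
    # |tau - 1{delta < 0}|: penalize over- and under-estimation asymmetrically per quantile
    weight = torch.abs(quantiles - (bellman_errors.detach() < 0).float())
    return weight * huber / kappa

errors = torch.randn(8, 64, 32, 1)
taus = torch.rand(8, 1, 32, 1)
print(quantile_huber_loss(errors, taus).shape)  # torch.Size([8, 64, 32, 1])
```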