Fix ACER #326

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open. Wants to merge 3 commits into base: master.
16 changes: 9 additions & 7 deletions rl_algorithms/acer/agent.py
@@ -4,7 +4,6 @@
import numpy as np
import torch
from torch.distributions import Categorical
import torch.nn.functional as F
import wandb

from rl_algorithms.acer.buffer import ReplayMemory
@@ -85,15 +84,17 @@ def __init__(
self.memory = ReplayMemory(
self.hyper_params.buffer_size, self.hyper_params.n_rollout
)
self.transition = []

def select_action(self, state: np.ndarray) -> Tuple[int, torch.Tensor]:
"""Select action from input space."""
state = numpy2floattensor(state, self.learner.device)
with torch.no_grad():
prob = F.softmax(self.learner.actor_target(state).squeeze(), 0) + 1e-8
action_dist = Categorical(prob)
logits = self.learner.actor_target(state)
action_dist = Categorical(logits=logits)
selected_action = action_dist.sample().item()
return selected_action, prob.cpu().numpy()
self.transition = [action_dist.probs.cpu().numpy()]
return selected_action

def step(self, action: int) -> Tuple[np.ndarray, np.float64, bool, dict]:
"""Take an action and return the reponse of the env"""
@@ -130,12 +131,12 @@ def train(self):
if self.is_render and self.i_episode >= self.render_after:
self.env.render()

action, prob = self.select_action(state)
action = self.select_action(state)
next_state, reward, done, _ = self.step(action)
done_mask = 0.0 if done else 1.0
self.episode_step += 1
transition = (state, action, reward / 100.0, prob, done_mask)
seq_data.append(transition)
self.transition.extend((state, action, reward / 100.0, done_mask))
seq_data.append(self.transition)
state = next_state
score += reward
if done:
@@ -157,6 +158,7 @@

if self.i_episode % self.save_period == 0:
self.learner.save_params(self.i_episode)
self.interim_test()

self.env.close()
self.learner.save_params(self.i_episode)
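
Because the probabilities are already sitting in self.transition, the rollout loop only extends it with the remaining fields, so every element appended to seq_data is ordered (prob, state, action, reward, done_mask), matching the unpacking order in ReplayMemory.add below. A sketch of one rollout under that convention, reusing select_action_sketch from above with a generic Gym-style env and collecting a full episode (the repo chunks rollouts into n_rollout steps, which is omitted here):

def collect_rollout(env, policy_net, device):
    """Sketch: gather one episode of (prob, state, action, reward, done_mask) transitions."""
    seq_data = []
    state = env.reset()
    done = False
    while not done:
        action, transition = select_action_sketch(policy_net, state, device)  # [probs]
        next_state, reward, done, _ = env.step(action)
        done_mask = 0.0 if done else 1.0
        # Order must match ReplayMemory.add, which unpacks:
        #   prob, state, action, reward, done_mask = transition
        transition.extend((state, action, reward / 100.0, done_mask))
        seq_data.append(transition)
        state = next_state
    return seq_data  # memory.add(seq_data) stores the whole rollout at once
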
4 changes: 2 additions & 2 deletions rl_algorithms/acer/buffer.py
@@ -40,13 +40,13 @@ def add(self, seq_data: list):
If the buffer is empty, its internal buffers are initialized to match the sizes of the arguments.
"""
if self.num_in_buffer == 0:
state, action, reward, prob, done_mask = seq_data[0]
prob, state, action, reward, done_mask = seq_data[0]
self._initialize_buffers(state, prob)

self.idx = (self.idx + 1) % (self.buffer_size - 1)

for i, transition in enumerate(seq_data):
state, action, reward, prob, done_mask = transition
prob, state, action, reward, done_mask = transition
self.obs_buf[self.idx][i] = state
self.acts_buf[self.idx][i] = action
self.rews_buf[self.idx][i] = reward
54 changes: 40 additions & 14 deletions rl_algorithms/acer/learner.py
@@ -9,7 +9,7 @@
from rl_algorithms.common.abstract.learner import Learner
import rl_algorithms.common.helper_functions as common_utils
from rl_algorithms.common.networks.brain import Brain
from rl_algorithms.registry import LEARNERS
from rl_algorithms.registry import LEARNERS, build_backbone
from rl_algorithms.utils.config import ConfigDict


@@ -55,10 +55,39 @@ def __init__(

def _init_network(self):
"""Initialize network and optimizer."""
self.actor = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(self.device)
self.critic = Brain(self.backbone_cfg.critic, self.head_cfg.critic).to(
self.device
)
if self.backbone_cfg.shared_actor_critic:
shared_backbone = build_backbone(self.backbone_cfg.shared_actor_critic)
self.actor = Brain(
self.backbone_cfg.shared_actor_critic,
self.head_cfg.actor,
shared_backbone,
)
self.critic = Brain(
self.backbone_cfg.shared_actor_critic,
self.head_cfg.critic,
shared_backbone,
)
self.actor_target = Brain(
self.backbone_cfg.shared_actor_critic,
self.head_cfg.actor,
shared_backbone,
)

else:
self.actor = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(
self.device
)
self.critic = Brain(self.backbone_cfg.critic, self.head_cfg.critic).to(
self.device
)
self.actor_target = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(
self.device
)
self.actor = self.actor.to(self.device)
self.actor_target = self.actor_target.to(self.device)
self.critic = self.critic.to(self.device)

self.actor_target.load_state_dict(self.actor.state_dict())
# create optimizer
self.actor_optim = optim.Adam(
self.actor.parameters(), lr=self.optim_cfg.lr, eps=self.optim_cfg.adam_eps
@@ -67,11 +96,6 @@ def _init_network(self):
self.critic.parameters(), lr=self.optim_cfg.lr, eps=self.optim_cfg.adam_eps
)

self.actor_target = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(
self.device
)
self.actor_target.load_state_dict(self.actor.state_dict())

if self.load_from is not None:
self.load_params(self.load_from)

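When backbone_cfg.shared_actor_critic is set, a single backbone is built once via build_backbone and handed to the actor, critic, and target-actor Brains so all three reuse the same feature extractor. The sketch below shows the shared-torso pattern in plain PyTorch; it assumes nothing about Brain or build_backbone beyond the call shapes visible in the diff, and the layer sizes are invented:

import torch.nn as nn

class HeadOnSharedTorso(nn.Module):
    """Sketch: a head module that reuses a shared feature extractor."""

    def __init__(self, torso: nn.Module, feature_dim: int, out_dim: int):
        super().__init__()
        self.torso = torso  # same module object, so parameters are shared
        self.head = nn.Linear(feature_dim, out_dim)

    def forward(self, x):
        return self.head(self.torso(x))

shared_torso = nn.Sequential(nn.Linear(4, 128), nn.ReLU())  # toy backbone
actor = HeadOnSharedTorso(shared_torso, 128, 2)             # policy logits
critic = HeadOnSharedTorso(shared_torso, 128, 2)            # per-action Q-values
actor_target = HeadOnSharedTorso(shared_torso, 128, 2)
actor_target.load_state_dict(actor.state_dict())            # start targets in sync
# Gradients from both the actor and critic losses flow into the same torso
# parameters; note that actor_target also shares the torso, mirroring the diff.
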
@@ -85,26 +109,28 @@ def update_model(self, experience: Tuple) -> torch.Tensor:
done = done.to(self.device)

pi = F.softmax(self.actor(state), 1)
log_pi = torch.log(pi + 1e-8)

q = self.critic(state)
q_i = q.gather(1, action)
pi_i = pi.gather(1, action)
log_pi_i = torch.log(pi_i + 1e-8)

with torch.no_grad():
v = (q * pi).sum(1).unsqueeze(1)
rho = pi / (prob + 1e-8)
rho = torch.exp(log_pi - torch.log(prob + 1e-8))
rho_i = rho.gather(1, action)
rho_bar = rho_i.clamp(max=self.hyper_params.c)

q_ret = self.q_retrace(
reward, done, q_i, v, rho_bar, self.hyper_params.gamma
).to(self.device)

loss_f = -rho_bar * torch.log(pi_i + 1e-8) * (q_ret - v)
loss_f = -rho_bar * log_pi_i * (q_ret - v)
loss_bc = (
-(1 - (self.hyper_params.c / rho)).clamp(min=0)
* pi.detach()
* torch.log(pi + 1e-8)
* log_pi
* (q.detach() - v)
)

@@ -114,7 +140,7 @@ def update_model(self, experience: Tuple) -> torch.Tensor:
g = loss_f + loss_bc
pi_target = F.softmax(self.actor_target(state), 1)
# gradient of KL(P || Q) w.r.t. Q is -P / Q
k = -pi_target / (pi + 1e-8)
k = -(torch.exp(torch.log(pi_target + 1e-8) - (log_pi)))
k_dot_g = k * g
tr = (
g
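
update_model now keeps the policy in log space: the importance ratio pi / mu is formed as exp(log pi - log mu), avoiding a division of two possibly tiny probabilities, and the same log-probabilities feed the truncated term, the bias correction, and the trust-region direction k = -pi_target / pi. The standalone sketch below reproduces those pieces with toy tensors in place of the networks and a random stand-in for the Retrace target q_ret; the batch size, action count, and c = 10 are illustrative only:

import torch
import torch.nn.functional as F

eps = 1e-8
c = 10.0                                                    # truncation level (hyper_params.c)

logits = torch.randn(5, 3)                                  # current actor output (toy)
target_logits = torch.randn(5, 3)                           # actor_target output (toy)
q = torch.randn(5, 3)                                       # critic Q-values (toy)
behaviour_prob = torch.softmax(torch.randn(5, 3), dim=1)    # mu, saved at acting time
action = torch.randint(0, 3, (5, 1))
q_ret = torch.randn(5, 1)                                   # stand-in for the Retrace target

pi = F.softmax(logits, dim=1)
log_pi = torch.log(pi + eps)
q_i = q.gather(1, action)           # would feed q_retrace and the critic loss (omitted here)
pi_i = pi.gather(1, action)
log_pi_i = torch.log(pi_i + eps)

with torch.no_grad():
    v = (q * pi).sum(1, keepdim=True)                       # V(s) = E_pi[Q(s, .)]
    # pi / mu computed in log space: exp(log pi - log mu)
    rho = torch.exp(log_pi - torch.log(behaviour_prob + eps))
    rho_i = rho.gather(1, action)
    rho_bar = rho_i.clamp(max=c)                            # truncated importance weight

loss_f = -rho_bar * log_pi_i * (q_ret - v)                  # truncated policy-gradient term
loss_bc = -(1 - c / rho).clamp(min=0) * pi.detach() * log_pi * (q.detach() - v)

g = loss_f + loss_bc                                        # broadcasts over the action dim
pi_target = F.softmax(target_logits, dim=1)
# gradient of KL(pi_target || pi) w.r.t. pi is -pi_target / pi, again formed in log space
k = -torch.exp(torch.log(pi_target + eps) - log_pi)
k_dot_g = k * g                                             # feeds the trust-region update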