Commit 5ce3eb8: a3c in the cartpole
philtabor committed Mar 15, 2021 (1 parent: 733e452)
Showing 1 changed file with 168 additions and 0 deletions: ReinforcementLearning/PolicyGradient/A3C/pytorch/a3c.py
@@ -0,0 +1,168 @@
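# Asynchronous Advantage Actor-Critic (A3C) on CartPole, using PyTorch
# multiprocessing: several worker processes train a shared global network.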
import gym
import torch as T
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

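# Adam variant whose per-parameter moment estimates live in shared memory,
# so gradient statistics are shared across all worker processes. Note that
# the integer 'step' counter stays local to each process.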
class SharedAdam(T.optim.Adam):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
                 weight_decay=0):
        super(SharedAdam, self).__init__(params, lr=lr, betas=betas, eps=eps,
                                         weight_decay=weight_decay)

        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state['exp_avg'] = T.zeros_like(p.data)
                state['exp_avg_sq'] = T.zeros_like(p.data)

                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()

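# Actor-critic network: two separate 128-unit hidden layers feed a policy
# head (pi, action logits) and a value head (v). Also holds the rollout
# buffer for the current batch of steps.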
class ActorCritic(nn.Module):
    def __init__(self, input_dims, n_actions, gamma=0.99):
        super(ActorCritic, self).__init__()

        self.gamma = gamma

        self.pi1 = nn.Linear(*input_dims, 128)
        self.v1 = nn.Linear(*input_dims, 128)
        self.pi = nn.Linear(128, n_actions)
        self.v = nn.Linear(128, 1)

        self.rewards = []
        self.actions = []
        self.states = []

    def remember(self, state, action, reward):
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)

    def clear_memory(self):
        self.states = []
        self.actions = []
        self.rewards = []

    def forward(self, state):
        pi1 = F.relu(self.pi1(state))
        v1 = F.relu(self.v1(state))

        pi = self.pi(pi1)
        v = self.v(v1)

        return pi, v

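# Discounted returns for the stored rollout, bootstrapped from the critic's
# estimate of the final state's value when the episode has not terminated.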
    def calc_R(self, done):
        states = T.tensor(self.states, dtype=T.float)
        _, v = self.forward(states)

        R = v[-1]*(1-int(done))

        batch_return = []
        for reward in self.rewards[::-1]:
            R = reward + self.gamma*R
            batch_return.append(R)
        batch_return.reverse()
        batch_return = T.tensor(batch_return, dtype=T.float)

        return batch_return

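# Combined loss: MSE critic loss plus the policy-gradient actor loss,
# weighted by the advantage (returns - values). As written, the advantage
# is not detached, so the actor term also backpropagates into the critic.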
    def calc_loss(self, done):
        states = T.tensor(self.states, dtype=T.float)
        actions = T.tensor(self.actions, dtype=T.float)

        returns = self.calc_R(done)

        pi, values = self.forward(states)
        values = values.squeeze()
        critic_loss = (returns-values)**2

        probs = T.softmax(pi, dim=1)
        dist = Categorical(probs)
        log_probs = dist.log_prob(actions)
        actor_loss = -log_probs*(returns-values)

        total_loss = (critic_loss + actor_loss).mean()

        return total_loss

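# Sample an action from the categorical distribution over the policy logits.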
    def choose_action(self, observation):
        state = T.tensor([observation], dtype=T.float)
        pi, v = self.forward(state)
        probs = T.softmax(pi, dim=1)
        dist = Categorical(probs)
        action = dist.sample().numpy()[0]

        return action

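# One worker process: owns a local copy of the network and its own env,
# and trains against the shared global network and optimizer.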
class Agent(mp.Process):
    def __init__(self, global_actor_critic, optimizer, input_dims, n_actions,
                 gamma, lr, name, global_ep_idx, env_id):
        super(Agent, self).__init__()
        self.local_actor_critic = ActorCritic(input_dims, n_actions, gamma)
        self.global_actor_critic = global_actor_critic
        self.name = 'w%02i' % name
        self.episode_idx = global_ep_idx
        self.env = gym.make(env_id)
        self.optimizer = optimizer

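# Core A3C loop: every T_MAX steps (or at episode end), compute the loss on
# the local rollout, copy local gradients onto the global parameters, step
# the shared optimizer, then sync local weights back from the global net.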
    def run(self):
        t_step = 1
        while self.episode_idx.value < N_GAMES:
            done = False
            observation = self.env.reset()
            score = 0
            self.local_actor_critic.clear_memory()
            while not done:
                action = self.local_actor_critic.choose_action(observation)
                observation_, reward, done, info = self.env.step(action)
                score += reward
                self.local_actor_critic.remember(observation, action, reward)
                if t_step % T_MAX == 0 or done:
                    loss = self.local_actor_critic.calc_loss(done)
                    self.optimizer.zero_grad()
                    loss.backward()
                    for local_param, global_param in zip(
                            self.local_actor_critic.parameters(),
                            self.global_actor_critic.parameters()):
                        global_param._grad = local_param.grad
                    self.optimizer.step()
                    self.local_actor_critic.load_state_dict(
                        self.global_actor_critic.state_dict())
                    self.local_actor_critic.clear_memory()
                t_step += 1
                observation = observation_
            with self.episode_idx.get_lock():
                self.episode_idx.value += 1
            print(self.name, 'episode ', self.episode_idx.value, 'reward %.1f' % score)

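# Hyperparameters and worker setup. N_GAMES and T_MAX are module-level
# globals read inside Agent.run, which relies on the default 'fork' start
# method (Linux); written against the classic gym API (4-tuple step()).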
if __name__ == '__main__':
    lr = 1e-4
    env_id = 'CartPole-v0'
    n_actions = 2
    input_dims = [4]
    N_GAMES = 3000
    T_MAX = 5
    global_actor_critic = ActorCritic(input_dims, n_actions)
    global_actor_critic.share_memory()
    optim = SharedAdam(global_actor_critic.parameters(), lr=lr,
                       betas=(0.92, 0.999))
    global_ep = mp.Value('i', 0)

    workers = [Agent(global_actor_critic,
                     optim,
                     input_dims,
                     n_actions,
                     gamma=0.99,
                     lr=lr,
                     name=i,
                     global_ep_idx=global_ep,
                     env_id=env_id) for i in range(mp.cpu_count())]
    [w.start() for w in workers]
    [w.join() for w in workers]
