Commit: bipedal walker update
nikhilbarhate99 committed Jul 28, 2019
1 parent c997c1b commit 1ed2854
Showing 1 changed file with 25 additions and 27 deletions.
52 changes: 25 additions & 27 deletions PPO_continuous.py
@@ -20,24 +20,24 @@ def clear_memory(self):
         del self.rewards[:]
 
 class ActorCritic(nn.Module):
-    def __init__(self, state_dim, action_dim, n_var, action_std):
+    def __init__(self, state_dim, action_dim, action_std):
         super(ActorCritic, self).__init__()
         # action mean range -1 to 1
         self.actor = nn.Sequential(
-                nn.Linear(state_dim, n_var),
+                nn.Linear(state_dim, 64),
                 nn.Tanh(),
-                nn.Linear(n_var, n_var),
+                nn.Linear(64, 32),
                 nn.Tanh(),
-                nn.Linear(n_var, action_dim),
+                nn.Linear(32, action_dim),
                 nn.Tanh()
                 )
         # critic
         self.critic = nn.Sequential(
-                nn.Linear(state_dim, n_var),
+                nn.Linear(state_dim, 64),
                 nn.Tanh(),
-                nn.Linear(n_var, n_var),
+                nn.Linear(64, 32),
                 nn.Tanh(),
-                nn.Linear(n_var, 1)
+                nn.Linear(32, 1)
                 )
         self.action_var = torch.full((action_dim,), action_std*action_std).to(device)

@@ -73,17 +73,16 @@ def evaluate(self, state, action):
         return action_logprobs, torch.squeeze(state_value), dist_entropy
 
 class PPO:
-    def __init__(self, state_dim, action_dim, n_latent_var, action_std, lr, betas, gamma, K_epochs, eps_clip):
+    def __init__(self, state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip):
         self.lr = lr
         self.betas = betas
         self.gamma = gamma
         self.eps_clip = eps_clip
         self.K_epochs = K_epochs
 
-        self.policy = ActorCritic(state_dim, action_dim, n_latent_var, action_std).to(device)
-        self.optimizer = torch.optim.Adam(self.policy.parameters(),
-                                          lr=lr, betas=betas)
-        self.policy_old = ActorCritic(state_dim, action_dim, n_latent_var, action_std).to(device)
+        self.policy = ActorCritic(state_dim, action_dim, action_std).to(device)
+        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)
+        self.policy_old = ActorCritic(state_dim, action_dim, action_std).to(device)
 
         self.MseLoss = nn.MSELoss()
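The update() method that consumes K_epochs and eps_clip is unchanged and elided from this diff. As a reminder of what those two hyperparameters control, here is a minimal, generic sketch of the clipped surrogate objective that PPO minimizes for K_epochs passes over each collected batch; it is not a copy of this repository's update().

import torch

def clipped_surrogate(logprobs, old_logprobs, advantages, eps_clip=0.2):
    # probability ratio pi_theta(a|s) / pi_theta_old(a|s), computed in log space
    ratios = torch.exp(logprobs - old_logprobs)
    surr1 = ratios * advantages
    surr2 = torch.clamp(ratios, 1 - eps_clip, 1 + eps_clip) * advantages
    # pessimistic (clipped) bound, negated so an optimizer can minimize it
    return -torch.min(surr1, surr2).mean()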

@@ -132,21 +131,20 @@ def update(self, memory):
 
 def main():
     ############## Hyperparameters ##############
-    env_name = "LunarLanderContinuous-v2"
+    env_name = "BipedalWalker-v2"
     render = False
-    solved_reward = 200         # stop training if avg_reward > solved_reward
+    solved_reward = 300         # stop training if avg_reward > solved_reward
     log_interval = 20           # print avg reward in the interval
     max_episodes = 10000        # max training episodes
-    max_timesteps = 500         # max timesteps in one episode
+    max_timesteps = 1500        # max timesteps in one episode
 
     update_timestep = 4000      # update policy every n timesteps
-    action_std = 0.8            # constant std for action distribution (Multivariate Normal)
-    K_epochs = 100              # update policy for K epochs
+    action_std = 0.5            # constant std for action distribution (Multivariate Normal)
+    K_epochs = 80               # update policy for K epochs
     eps_clip = 0.2              # clip parameter for PPO
     gamma = 0.99                # discount factor
 
-    n_latent_var = 64           # number of variables in hidden layer
-    lr = 0.00025                # parameters for Adam optimizer
+    lr = 0.0003                 # parameters for Adam optimizer
     betas = (0.9, 0.999)
 
     random_seed = None
@@ -164,7 +162,7 @@ def main():
         np.random.seed(random_seed)
 
     memory = Memory()
-    ppo = PPO(state_dim, action_dim, n_latent_var, action_std, lr, betas, gamma, K_epochs, eps_clip)
+    ppo = PPO(state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip)
     print(lr,betas)
 
     # logging variables
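The episode loop between this hunk and the next is unchanged and therefore not shown. Purely for orientation, the sketch below outlines the usual cadence: transitions accumulate in memory, and every update_timestep steps the policy is optimized and the buffer cleared. Names such as ppo.select_action and memory.rewards are assumed from the surrounding script rather than visible in this diff.

# illustrative only; relies on env, ppo, memory and the hyperparameters defined above
time_step = 0
for i_episode in range(1, max_episodes + 1):
    state = env.reset()
    for t in range(max_timesteps):
        time_step += 1
        action = ppo.select_action(state, memory)   # assumed to also record state/action/logprob
        state, reward, done, _ = env.step(action)
        memory.rewards.append(reward)
        if time_step % update_timestep == 0:        # optimize for K_epochs, then start fresh
            ppo.update(memory)
            memory.clear_memory()
            time_step = 0
        if done:
            break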
@@ -196,11 +194,16 @@
 
         avg_length += t
 
-        # # stop training if avg_reward > solved_reward
+        # stop training if avg_reward > solved_reward
         if running_reward > (log_interval*solved_reward):
             print("########## Solved! ##########")
-            torch.save(ppo.policy.state_dict(), './PPO_Continuous_{}.pth'.format(env_name))
+            torch.save(ppo.policy.state_dict(), './PPO_continuous_solved_{}.pth'.format(env_name))
             break
 
+        # save every 500 episodes
+        if i_episode % 500 == 0:
+            torch.save(ppo.policy.state_dict(), './PPO_continuous_{}.pth'.format(env_name))
+
         # logging
         if i_episode % log_interval == 0:
             avg_length = int(avg_length/log_interval)
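With a checkpoint now written every 500 episodes (plus a separate file when the environment is solved), evaluating or resuming later amounts to loading the saved state dict back into the networks. A minimal sketch, assuming the ppo object and env_name defined above:

# illustrative only; assumes ppo and env_name exist as in the script above
checkpoint = torch.load('./PPO_continuous_{}.pth'.format(env_name))
ppo.policy.load_state_dict(checkpoint)
ppo.policy_old.load_state_dict(checkpoint)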
@@ -213,8 +216,3 @@ def main():
 
 if __name__ == '__main__':
     main()