Commit
policy gradients in tensorflow 2
philtabor committed Sep 7, 2020
1 parent 9a9d5b9 commit 986c658
Showing 4 changed files with 118 additions and 0 deletions.
(One of the four changed files could not be displayed in the diff view and is omitted below.)
36 changes: 36 additions & 0 deletions ReinforcementLearning/PolicyGradient/reinforce/tensorflow2/main.py
@@ -0,0 +1,36 @@
# if you have more than 1 gpu, use device '0' or '1' to assign to a gpu
#import os
#os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
#os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import gym
import numpy as np
from reinforce_tf2 import Agent
from utils import plotLearning

if __name__ == '__main__':
    agent = Agent(alpha=0.0005, gamma=0.99, n_actions=4)

    env = gym.make('LunarLander-v2')
    score_history = []

    num_episodes = 2000

    for i in range(num_episodes):
        done = False
        score = 0
        observation = env.reset()
        while not done:
            action = agent.choose_action(observation)
            observation_, reward, done, info = env.step(action)
            agent.store_transition(observation, action, reward)
            observation = observation_
            score += reward
        score_history.append(score)

        agent.learn()
        avg_score = np.mean(score_history[-100:])
        print('episode: ', i, 'score: %.1f' % score,
              'average score %.1f' % avg_score)

    filename = 'lunar-lander.png'
    plotLearning(score_history, filename=filename, window=100)
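Note that utils.plotLearning is imported above but is not part of this commit. A minimal stand-in, assuming it only plots a running average of the episode scores with matplotlib (the implementation below is a sketch, not the author's file), could look like this:

import numpy as np
import matplotlib.pyplot as plt

def plotLearning(scores, filename, window=100):
    # running average of the previous `window` scores at each episode
    running_avg = [np.mean(scores[max(0, t - window):t + 1])
                   for t in range(len(scores))]
    plt.plot(running_avg)
    plt.title('Running average of previous %d scores' % window)
    plt.xlabel('Episode')
    plt.ylabel('Score')
    plt.savefig(filename)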
22 changes: 22 additions & 0 deletions ReinforcementLearning/PolicyGradient/reinforce/tensorflow2/networks.py
@@ -0,0 +1,22 @@
import tensorflow.keras as keras
from tensorflow.keras.layers import Dense

class PolicyGradientNetwork(keras.Model):
    def __init__(self, n_actions, fc1_dims=256, fc2_dims=256):
        super(PolicyGradientNetwork, self).__init__()
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.n_actions = n_actions

        self.fc1 = Dense(self.fc1_dims, activation='relu')
        self.fc2 = Dense(self.fc2_dims, activation='relu')
        self.pi = Dense(n_actions, activation='softmax')

    def call(self, state):
        value = self.fc1(state)
        value = self.fc2(value)

        pi = self.pi(value)

        return pi
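As a quick sanity check of the network defined above, a minimal usage sketch (the batch size of 5 and the 8-dimensional LunarLander observation are assumptions for illustration, not part of this commit):

import numpy as np
import tensorflow as tf
from networks import PolicyGradientNetwork

network = PolicyGradientNetwork(n_actions=4)
# 5 fake observations, 8 features each (LunarLander observation size)
dummy_states = tf.convert_to_tensor(np.random.randn(5, 8), dtype=tf.float32)
probs = network(dummy_states)
print(probs.shape)                    # (5, 4): one softmax distribution per state
print(tf.reduce_sum(probs, axis=1))   # each row sums to ~1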

60 changes: 60 additions & 0 deletions ReinforcementLearning/PolicyGradient/reinforce/tensorflow2/reinforce_tf2.py
@@ -0,0 +1,60 @@
import tensorflow as tf
from networks import PolicyGradientNetwork
import tensorflow_probability as tfp
from tensorflow.keras.optimizers import Adam
import numpy as np

class Agent:
    def __init__(self, alpha=0.003, gamma=0.99, n_actions=4,
                 layer1_size=256, layer2_size=256):

        self.gamma = gamma
        self.lr = alpha
        self.n_actions = n_actions
        self.state_memory = []
        self.action_memory = []
        self.reward_memory = []
        # pass the layer sizes through so the constructor arguments are not silently ignored
        self.policy = PolicyGradientNetwork(n_actions=n_actions,
                                            fc1_dims=layer1_size,
                                            fc2_dims=layer2_size)
        self.policy.compile(optimizer=Adam(learning_rate=self.lr))

    def choose_action(self, observation):
        state = tf.convert_to_tensor([observation], dtype=tf.float32)
        probs = self.policy(state)
        action_probs = tfp.distributions.Categorical(probs=probs)
        action = action_probs.sample()

        return action.numpy()[0]

    def store_transition(self, observation, action, reward):
        self.state_memory.append(observation)
        self.action_memory.append(action)
        self.reward_memory.append(reward)

    def learn(self):
        actions = tf.convert_to_tensor(self.action_memory, dtype=tf.float32)
        rewards = np.array(self.reward_memory)

        # discounted return G_t for every timestep of the episode
        G = np.zeros_like(rewards)
        for t in range(len(rewards)):
            G_sum = 0
            discount = 1
            for k in range(t, len(rewards)):
                G_sum += rewards[k] * discount
                discount *= self.gamma
            G[t] = G_sum

        with tf.GradientTape() as tape:
            loss = 0
            for idx, (g, state) in enumerate(zip(G, self.state_memory)):
                state = tf.convert_to_tensor([state], dtype=tf.float32)
                probs = self.policy(state)
                action_probs = tfp.distributions.Categorical(probs=probs)
                log_prob = action_probs.log_prob(actions[idx])
                loss += -g * tf.squeeze(log_prob)

        gradient = tape.gradient(loss, self.policy.trainable_variables)
        self.policy.optimizer.apply_gradients(zip(gradient,
                                                  self.policy.trainable_variables))

        self.state_memory = []
        self.action_memory = []
        self.reward_memory = []
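The nested loop in learn() computes the discounted returns in O(n^2) time per episode. An equivalent single-pass backward accumulation, shown here only as a sketch of an alternative and not part of this commit, would be:

# drop-in replacement for the double loop over rewards inside learn()
G = np.zeros_like(rewards, dtype=np.float64)
running_return = 0.0
for t in reversed(range(len(rewards))):
    # G[t] = rewards[t] + gamma * G[t+1]
    running_return = rewards[t] + self.gamma * running_return
    G[t] = running_return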
