Skip to content

Commit

Permalink
add feature extractor
Browse files Browse the repository at this point in the history
  • Loading branch information
phuocphn committed Sep 26, 2019
1 parent 8afb45f commit 434f053
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 10 deletions.
3 changes: 3 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions a3c.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import tensorflow as tf

class A3C(object):
    """Build the global and per-worker policy networks for distributed A3C.

    The constructor creates two copies of ``CNNLSTMPolicy``:

    * a "global" copy under the ``global`` variable scope, placed through
      ``tf.train.replica_device_setter`` with ``ps_device="/job:ps"``,
      together with an int32 ``global_step`` variable;
    * a "local" copy under the ``local`` variable scope, pinned to the CPU of
      the worker identified by ``worker_task_index``.
    """

    def __init__(self, env, worker_task_index, sess=None):
        """Define the networks and the step counter.

        :param env: environment object; only ``env.observation_space.shape``
            and ``env.action_space.n`` are read.
        :param worker_task_index: task index of the worker that owns the
            local network; formatted into the worker device string.
        :param sess: currently unused by the constructor.
        """
        # This module only imports TensorFlow at the top, so the policy class
        # has to be brought into scope here; otherwise the constructor fails
        # with NameError on the first CNNLSTMPolicy(...) call.
        from models import CNNLSTMPolicy

        state_shape = env.observation_space.shape
        num_action = env.action_space.n
        worker_device = "/job:worker/task:{}/cpu:0".format(worker_task_index)

        # Global (target) network weights and the global step.
        with tf.device(device_name_or_function=tf.train.replica_device_setter(
                ps_tasks=1, ps_device="/job:ps", worker_device=worker_device)):
            with tf.variable_scope("global", reuse=None):
                self.global_network = CNNLSTMPolicy(state_shape=state_shape,
                                                    num_action=num_action)
                self.global_step = tf.get_variable(
                    name="global_step",
                    shape=[],
                    dtype=tf.int32,
                    initializer=tf.constant_initializer(0, dtype=tf.int32),
                    trainable=False)
                # TODO: add a global StateActionPredictor network here.

        # Local network weights, placed on this worker's CPU.
        with tf.device(device_name_or_function=worker_device):
            with tf.variable_scope("local", reuse=None):
                self.local_network = CNNLSTMPolicy(state_shape=state_shape,
                                                   num_action=num_action)
                # The local step is an alias of the shared global step,
                # not a separate variable.
                self.local_step = self.global_step
                # TODO: add a local StateActionPredictor network here.

    def train(self, sess):
        """Run one training iteration (not implemented yet; does nothing)."""
        pass
49 changes: 49 additions & 0 deletions models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import tensorflow as tf
import numpy as np

# Create some wrappers for simplicity
def conv2d(name, input, strides=(2, 2), filter_size=(3, 3), num_filters=32):
    """Conv2D wrapper with bias and ELU activation.

    Creates variables ``W`` and ``b`` under the variable scope ``name`` and
    returns ``elu(conv2d(input, W) + b)`` using ``SAME`` padding.

    :param name: variable scope that holds the layer's ``W`` and ``b``.
    :param input: 4-D tensor; its 4th dimension is read as the number of
        input channels, so that dimension must be statically known.
    :param strides: two-element sequence ``(stride_h, stride_w)``.
    :param filter_size: two-element sequence ``(filter_h, filter_w)``.
    :param num_filters: number of output channels.
    :return: the activated output tensor.
    """
    # TensorShape is not callable, so the static shape is read through
    # get_shape() rather than input.shape().
    in_channels = int(input.get_shape()[3])

    with tf.variable_scope(name):
        W = tf.get_variable(
            "W",
            shape=[filter_size[0], filter_size[1], in_channels, num_filters],
            dtype=tf.float32,
            # https://medium.com/@prateekvishnu/xavier-and-he-normal-he-et-al-initialization-8e3d7a087528
            initializer=tf.contrib.layers.variance_scaling_initializer(
                factor=1.0, mode='FAN_IN', uniform=True, seed=None,
                dtype=tf.float32))
        # tf.nn.bias_add needs a 1-D bias whose size matches the channel
        # dimension, so the bias follows num_filters instead of a fixed 32.
        b = tf.get_variable("b", shape=[num_filters], dtype=tf.float32,
                            initializer=tf.constant_initializer(0.0))
        conv = tf.nn.conv2d(input, filter=W,
                            strides=[1, strides[0], strides[1], 1],
                            padding="SAME")
        return tf.nn.elu(tf.nn.bias_add(conv, b))


def normalized_columns_initializer(std=1.0):
    """Return an initializer producing random weights with scaled columns.

    The returned callable draws standard-normal samples of the requested
    shape, divides them by the square root of the sum of squares taken along
    axis 0, multiplies by ``std``, and wraps the result in ``tf.constant``.

    :param std: scale applied after normalising along axis 0.
    :return: a function ``(shape, dtype=None, partition_info=None)``; the
        ``dtype`` and ``partition_info`` arguments are accepted but ignored.
    """
    def _initializer(shape, dtype=None, partition_info=None):
        samples = np.random.randn(*shape).astype(np.float32)
        axis0_norm = np.sqrt(np.square(samples).sum(axis=0, keepdims=True))
        samples *= std / axis0_norm
        return tf.constant(samples)

    return _initializer

class CNNLSTMPolicy(object):
    """Convolutional feature extractor followed by a linear output layer.

    Four stride-2 ``conv2d`` layers (32 filters each) are applied to the
    input placeholder; the result is flattened into ``self.output`` and fed
    through a fully connected layer producing ``self.logits`` with
    ``num_action`` columns.

    NOTE(review): despite the class name, no LSTM layer is built here yet.
    """

    def __init__(self, state_shape, num_action):
        """Build the placeholder, the convolution stack and the output layer.

        :param state_shape: shape of a single observation, without the batch
            dimension; any sequence (list or tuple) is accepted.
        :param num_action: number of columns of ``self.logits``.
        """
        # list(...) lets callers pass a tuple such as
        # env.observation_space.shape; list + tuple would raise TypeError.
        self.input = tf.placeholder(dtype=tf.float32,
                                    shape=[None] + list(state_shape),
                                    name="input")

        # Four convolution layers stacked together.
        features = self.input
        for idx in range(4):
            features = conv2d(name=f"layer_{idx}", input=features,
                              strides=[2, 2], filter_size=[3, 3],
                              num_filters=32)

        # Flatten everything except the batch dimension.
        flat_dim = int(np.prod(features.get_shape().as_list()[1:]))
        self.output = tf.reshape(features, [-1, flat_dim])

        # Last fully connected layer. It must consume the flattened features:
        # the weight matrix is sized by flat_dim and multiplied with
        # self.output, not with the 4-D convolution output.
        w = tf.get_variable("value_function/w", [flat_dim, num_action],
                            initializer=normalized_columns_initializer(1.0))
        b = tf.get_variable("value_function/b", [num_action],
                            initializer=tf.constant_initializer(0.0))
        self.logits = tf.matmul(self.output, w) + b
14 changes: 4 additions & 10 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import tensorflow as tf
import argparse
import gym
from somewhere import DoomGame, A3CICM

# Get user-provided parameters from args
parser = argparse.ArgumentParser()
Expand All @@ -16,10 +14,7 @@
LOG_DIR = '/tmp/doom'
ENV_ID = 'doom'
NUM_WORKERS = 20
PS_PORT = 12222
TASK = 2
DISCARD_LIVING_REWARD = True # Remove all negative reward (in doom: it is living reward)
FEATURE_EXTRACTOR = 'universe'
TOTAL_TRAINING_STEP = 100 # this is total step and is used for all workers.

# Every job must map to a list (or dict) of task addresses, even when the job
# has a single task; a bare string is rejected by ClusterSpec with TypeError.
cluster = tf.train.ClusterSpec({"ps": ["localhost:12200"], "worker": ["localhost:12300", "localhost:12301"]})
Expand All @@ -33,11 +28,9 @@
# ConfigProto's thread-pool fields are `intra_op_parallelism_threads` and
# `inter_op_parallelism_threads`; an unknown keyword makes the constructor fail.
server = tf.train.Server(server_or_cluster_def=cluster, job_name=JOB_NAME, task_index=TASK_INDEX,
                         config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=2))

env = gym.make_env(env_id=ENV_ID)
env = None
trainer = A3CICM(env=env,
worker_task_index=TASK_INDEX, #specify which machine (worker) will be used to train agent.
feature_extractor=FEATURE_EXTRACTOR,
discard_living_reward=DISCARD_LIVING_REWARD)
worker_task_index=TASK_INDEX) #specify which machine (worker) will be used to train agent.

# Variables with the name starts with `local...` will not be saved in the checkpoints
# only save target network and related variables.
Expand All @@ -57,12 +50,13 @@
config=tf.ConfigProto(device_filters=["/job:ps", f"/job:worker/task:{TASK_INDEX}/cpu:0"])) as sess:

saver.restore(sess=sess, save_path=tf.train.latest_checkpoint(PRETRAIN_MODEL_PATH))
sess.run(tf.global_variables_initializer())
sess.run(trainer.sync_weight_from_target_network)
current_training_step = sess.run(trainer.training_step) # training_step is put in parameter server.
print (f"Worker: {JOB_NAME + ':'+ TASK_INDEX } in training step: {current_training_step}")

while not supervisor.should_stop() and current_training_step < TOTAL_TRAINING_STEP:
trainer.train(session=sess)
trainer.train(sess=sess)
current_training_step = sess.run(trainer.training_step)

# Ask for all the services to stop.
Expand Down

0 comments on commit 434f053

Please sign in to comment.