
Commit 4b31159

committed May 22, 2018
1 parent ca19b66 commit 4b31159

File tree: 15 files changed (+669, -936 lines)


OPENAI/Alg/a2c.py

Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
import os.path as osp
import time
import joblib
import numpy as np
import tensorflow as tf
from baselines import logger

from baselines.common import set_global_seeds, explained_variance
from baselines.common.runners import AbstractEnvRunner
from baselines.common import tf_util

from baselines.a2c.utils import discount_with_dones
from baselines.a2c.utils import Scheduler, make_path, find_trainable_variables
from baselines.a2c.utils import cat_entropy, mse


class Model(object):

    def __init__(self, policy, ob_space, ac_space, nenvs, nsteps,
                 ent_coef=0.01, vf_coef=0.5, max_grad_norm=0.5, lr=7e-4,
                 alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lrschedule='linear'):

        sess = tf_util.make_session()
        nbatch = nenvs*nsteps

        # placeholders: actions, advantages, returns, and the scheduled learning rate
        A = tf.placeholder(tf.int32, [nbatch])
        ADV = tf.placeholder(tf.float32, [nbatch])
        R = tf.placeholder(tf.float32, [nbatch])
        LR = tf.placeholder(tf.float32, [])

        step_model = policy(sess, ob_space, ac_space, nenvs, 1, reuse=False)
        train_model = policy(sess, ob_space, ac_space, nenvs*nsteps, nsteps, reuse=True)

        # A2C loss: policy gradient + value regression, minus an entropy bonus
        neglogpac = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_model.pi, labels=A)
        pg_loss = tf.reduce_mean(ADV * neglogpac)
        vf_loss = tf.reduce_mean(mse(tf.squeeze(train_model.vf), R))
        entropy = tf.reduce_mean(cat_entropy(train_model.pi))
        loss = pg_loss - entropy*ent_coef + vf_loss*vf_coef

        params = find_trainable_variables("model")
        grads = tf.gradients(loss, params)
        if max_grad_norm is not None:
            grads, grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
        grads = list(zip(grads, params))
        trainer = tf.train.RMSPropOptimizer(learning_rate=LR, decay=alpha, epsilon=epsilon)
        _train = trainer.apply_gradients(grads)

        lr = Scheduler(v=lr, nvalues=total_timesteps, schedule=lrschedule)

        def train(obs, states, rewards, masks, actions, values):
            advs = rewards - values
            # advance the learning-rate schedule once per batch element
            for step in range(len(obs)):
                cur_lr = lr.value()
            td_map = {train_model.X: obs, A: actions, ADV: advs, R: rewards, LR: cur_lr}
            if states is not None:
                td_map[train_model.S] = states
                td_map[train_model.M] = masks
            policy_loss, value_loss, policy_entropy, _ = sess.run(
                [pg_loss, vf_loss, entropy, _train],
                td_map
            )
            return policy_loss, value_loss, policy_entropy

        def save(save_path):
            ps = sess.run(params)
            make_path(osp.dirname(save_path))
            joblib.dump(ps, save_path)

        def load(load_path):
            loaded_params = joblib.load(load_path)
            restores = []
            for p, loaded_p in zip(params, loaded_params):
                restores.append(p.assign(loaded_p))
            sess.run(restores)

        self.train = train
        self.train_model = train_model
        self.step_model = step_model
        self.step = step_model.step
        self.value = step_model.value
        self.initial_state = step_model.initial_state
        self.save = save
        self.load = load
        tf.global_variables_initializer().run(session=sess)


class Runner(AbstractEnvRunner):

    def __init__(self, env, model, nsteps=5, gamma=0.99):
        super().__init__(env=env, model=model, nsteps=nsteps)
        self.gamma = gamma

    def run(self):
        # collect nsteps of experience from each parallel env
        mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [], [], [], [], []
        mb_states = self.states
        for n in range(self.nsteps):
            actions, values, states, _ = self.model.step(self.obs, self.states, self.dones)
            mb_obs.append(np.copy(self.obs))
            mb_actions.append(actions)
            mb_values.append(values)
            mb_dones.append(self.dones)
            obs, rewards, dones, _ = self.env.step(actions)
            # print(rewards.shape): (env_num,)
            self.states = states
            self.dones = dones
            for n, done in enumerate(dones):
                if done:
                    self.obs[n] = self.obs[n]*0
            self.obs = obs
            mb_rewards.append(rewards)
        mb_dones.append(self.dones)

        # batch of steps to batch of rollouts
        mb_obs = np.asarray(mb_obs, dtype=np.uint8).swapaxes(1, 0).reshape(self.batch_ob_shape)
        mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
        # print(mb_rewards.shape): (num_env, nsteps)
        avg_reward = mb_rewards.sum(axis=1).mean()
        max_reward = mb_rewards.sum(axis=1).max()
        min_reward = mb_rewards.sum(axis=1).min()

        mb_actions = np.asarray(mb_actions, dtype=np.int32).swapaxes(1, 0)
        mb_values = np.asarray(mb_values, dtype=np.float32).swapaxes(1, 0)
        mb_dones = np.asarray(mb_dones, dtype=np.bool).swapaxes(1, 0)
        mb_masks = mb_dones[:, :-1]
        mb_dones = mb_dones[:, 1:]
        last_values = self.model.value(self.obs, self.states, self.dones).tolist()

        # discount/bootstrap off value fn
        for n, (rewards, dones, value) in enumerate(zip(mb_rewards, mb_dones, last_values)):
            rewards = rewards.tolist()
            dones = dones.tolist()
            if dones[-1] == 0:
                rewards = discount_with_dones(rewards+[value], dones+[0], self.gamma)[:-1]
            else:
                rewards = discount_with_dones(rewards, dones, self.gamma)
            # print(len(rewards)): nsteps, each entry is the accumulated discounted reward at time t
            mb_rewards[n] = rewards

        # print(mb_rewards.shape): (nenv, ntimesteps), each env's accumulated discounted reward at each timestep
        avg_dis_reward = mb_rewards.mean(axis=0)[0]

        mb_rewards = mb_rewards.flatten()
        # print(mb_rewards.shape): (nenv*ntimesteps,)
        mb_actions = mb_actions.flatten()
        mb_values = mb_values.flatten()
        mb_masks = mb_masks.flatten()

        info = dict()
        info['avg_reward'] = avg_reward
        info['avg_dis_reward'] = avg_dis_reward
        info['max_reward'] = max_reward
        info['min_reward'] = min_reward
        return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values, info


def learn(policy, env, seed, nsteps=5, total_timesteps=int(80e6), vf_coef=0.5, ent_coef=0.01,
          max_grad_norm=0.5, lr=7e-4, lrschedule='linear', epsilon=1e-5, alpha=0.99, gamma=0.99,
          log_interval=1, save_interval=10, save_path="./a2c", load_path=None):
    tf.reset_default_graph()
    set_global_seeds(seed)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space

    model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nenvs=nenvs, nsteps=nsteps,
                  ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, lr=lr,
                  alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps, lrschedule=lrschedule)

    if load_path is not None:
        model.load(load_path)
    runner = Runner(env, model, nsteps=nsteps, gamma=gamma)

    nbatch = nenvs*nsteps
    tstart = time.time()
    for update in range(1, total_timesteps//nbatch+1):
        obs, states, rewards, masks, actions, values, info = runner.run()
        policy_loss, value_loss, policy_entropy = model.train(obs, states, rewards, masks, actions, values)
        nseconds = time.time()-tstart
        fps = int((update*nbatch)/nseconds)
        if update % log_interval == 0 or update == 1:
            ev = explained_variance(values, rewards)
            logger.record_tabular("nupdates", update)
            logger.record_tabular("total_timesteps", update*nbatch)
            logger.record_tabular("fps", fps)
            logger.record_tabular("policy_entropy", float(policy_entropy))
            logger.record_tabular("value_loss", float(value_loss))
            logger.record_tabular("explained_variance", float(ev))
            for key in info.keys():
                logger.record_tabular(key, info[key])
            logger.dump_tabular()
        if update % save_interval == 0 or update == 1:
            model.save(save_path+"a2c_"+str(update)+".pkl")
    env.close()
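
For reference, a minimal driver for the learn() entry point above could look like the sketch below. It is not part of this commit: the CnnPolicy import path varies across baselines versions, and the Alg.a2c import assumes the OPENAI directory is on PYTHONPATH; any policy callable with the (sess, ob_space, ac_space, nbatch, nsteps, reuse) signature and a vectorized env exposing num_envs would do.

# Hypothetical driver for OPENAI/Alg/a2c.py -- not part of this commit.
from baselines.common.cmd_util import make_atari_env
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from baselines.ppo2.policies import CnnPolicy  # assumed path; the policies module moved between baselines versions

from Alg.a2c import learn  # assumed import path for this repository's module


def main():
    # 16 parallel emulators with 4-frame stacking, the usual A2C Atari setup
    env = VecFrameStack(make_atari_env('BreakoutNoFrameskip-v4', num_env=16, seed=0), 4)
    learn(policy=CnnPolicy, env=env, seed=0, nsteps=5,
          total_timesteps=int(1e6), log_interval=100,
          save_interval=1000, save_path='./a2c/')


if __name__ == '__main__':
    main()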

OPENAI/Env/Atari/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
from gym.envs.registration import register

register(
    # id='fooNoFrameskip-v0',
    id='carnivalRam20-v0',
    entry_point='Env.Atari.atari:AtariEnv',
    kwargs={'game': 'carnival', 'obs_type': 'ram', 'mask_num': 20, 'frameskip': 1},
    max_episode_steps=10000,
    nondeterministic=False,
)
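
Once the package is imported, the registered id is available through gym's standard factory. A small usage sketch, assuming the Env.Atari import path implied by the entry_point above:

# Hypothetical usage of the registered id -- not part of this commit.
import gym
import Env.Atari  # importing the package runs the register() call above

env = gym.make('carnivalRam20-v0')  # Carnival with RAM observations and the project's mask_num option
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
env.close()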
