Commit 3d1fb65

add RND
1 parent fe91ee4 commit 3d1fb65


2 files changed: 166 additions & 0 deletions


contents/Curiosity_Model/Curiosity.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+"""This is a simple implementation of [Large-Scale Study of Curiosity-Driven Learning](https://arxiv.org/abs/1808.04355)"""
+
 import numpy as np
 import tensorflow as tf
 import gym
Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
"""This is a simple implementation of [Exploration by Random Network Distillation](https://arxiv.org/abs/1810.12894)"""

import numpy as np
import tensorflow as tf
import gym
import matplotlib.pyplot as plt

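# How RND works, in brief (see the paper linked above): a fixed, randomly
# initialized "random_net" maps each next state to an embedding, and a trainable
# "predictor" net is trained to reproduce that embedding. The predictor's error
# is small for familiar states and large for novel ones, so it is used as an
# intrinsic reward that encourages exploration.
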
class CuriosityNet:
    def __init__(
            self,
            n_a,
            n_s,
            lr=0.01,
            gamma=0.95,
            epsilon=1.,
            replace_target_iter=300,
            memory_size=10000,
            batch_size=128,
            output_graph=False,
    ):
        self.n_a = n_a
        self.n_s = n_s
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.replace_target_iter = replace_target_iter
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.s_encode_size = 1000  # a large random embedding gives the predictor a hard job to learn

        # total learning step
        self.learn_step_counter = 0
        self.memory_counter = 0

        # initialize zero memory [s, a, r, s_]
        self.memory = np.zeros((self.memory_size, n_s * 2 + 2))
        self.tfs, self.tfa, self.tfr, self.tfs_, self.pred_train, self.dqn_train, self.q = \
            self._build_nets()

        t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net')
        e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='eval_net')

        with tf.variable_scope('hard_replacement'):
            self.target_replace_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]

        self.sess = tf.Session()

        if output_graph:
            tf.summary.FileWriter("logs/", self.sess.graph)

        self.sess.run(tf.global_variables_initializer())

    def _build_nets(self):
        tfs = tf.placeholder(tf.float32, [None, self.n_s], name="s")    # input State
        tfa = tf.placeholder(tf.int32, [None, ], name="a")              # input Action
        tfr = tf.placeholder(tf.float32, [None, ], name="ext_r")        # extrinsic reward
        tfs_ = tf.placeholder(tf.float32, [None, self.n_s], name="s_")  # input Next State

        # fixed random net: never trained, it only provides the prediction target
        with tf.variable_scope("random_net"):
            rand_encode_s_ = tf.layers.dense(tfs_, self.s_encode_size)

        # predictor
        ri, pred_train = self._build_predictor(tfs_, rand_encode_s_)

        # normal RL model
        q, dqn_loss, dqn_train = self._build_dqn(tfs, tfa, ri, tfr, tfs_)
        return tfs, tfa, tfr, tfs_, pred_train, dqn_train, q

    def _build_predictor(self, s_, rand_encode_s_):
        with tf.variable_scope("predictor"):
            net = tf.layers.dense(s_, 128, tf.nn.relu)
            out = tf.layers.dense(net, self.s_encode_size)

        with tf.name_scope("int_r"):
            ri = squared_diff = tf.reduce_sum(tf.square(rand_encode_s_ - out), axis=1)  # intrinsic reward = prediction error
        train_op = tf.train.RMSPropOptimizer(self.lr, name="predictor_opt").minimize(
            squared_diff, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "predictor"))

        return ri, train_op

    def _build_dqn(self, s, a, ri, re, s_):
        with tf.variable_scope('eval_net'):
            e1 = tf.layers.dense(s, 128, tf.nn.relu)
            q = tf.layers.dense(e1, self.n_a, name="q")
        with tf.variable_scope('target_net'):
            t1 = tf.layers.dense(s_, 128, tf.nn.relu)
            q_ = tf.layers.dense(t1, self.n_a, name="q_")

        with tf.variable_scope('q_target'):
            # the Q target is built from the sum of extrinsic and intrinsic rewards
            q_target = re + ri + self.gamma * tf.reduce_max(q_, axis=1, name="Qmax_s_")

        with tf.variable_scope('q_wrt_a'):
            a_indices = tf.stack([tf.range(tf.shape(a)[0], dtype=tf.int32), a], axis=1)
            q_wrt_a = tf.gather_nd(params=q, indices=a_indices)

        loss = tf.losses.mean_squared_error(labels=q_target, predictions=q_wrt_a)  # TD error
        train_op = tf.train.RMSPropOptimizer(self.lr, name="dqn_opt").minimize(
            loss, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "eval_net"))
        return q, loss, train_op

    def store_transition(self, s, a, r, s_):
        transition = np.hstack((s, [a, r], s_))
        # replace the old memory with new memory
        index = self.memory_counter % self.memory_size
        self.memory[index, :] = transition
        self.memory_counter += 1

    def choose_action(self, observation):
        # add a batch dimension before feeding the observation into the tf placeholder
        s = observation[np.newaxis, :]

        if np.random.uniform() < self.epsilon:
            # forward the observation and get the q value of every action
            # (with the default epsilon=1. the policy is always greedy w.r.t. Q;
            # exploration is driven by the intrinsic reward instead)
            actions_value = self.sess.run(self.q, feed_dict={self.tfs: s})
            action = np.argmax(actions_value)
        else:
            action = np.random.randint(0, self.n_a)
        return action

    def learn(self):
        # check to replace target parameters
        if self.learn_step_counter % self.replace_target_iter == 0:
            self.sess.run(self.target_replace_op)

        # sample batch memory from all memory
        top = self.memory_size if self.memory_counter > self.memory_size else self.memory_counter
        sample_index = np.random.choice(top, size=self.batch_size)
        batch_memory = self.memory[sample_index, :]

        # actions are stored as floats in the memory buffer; cast them back to int for the tf.int32 placeholder
        bs, ba, br, bs_ = batch_memory[:, :self.n_s], batch_memory[:, self.n_s].astype(int), \
            batch_memory[:, self.n_s + 1], batch_memory[:, -self.n_s:]
        self.sess.run(self.dqn_train, feed_dict={self.tfs: bs, self.tfa: ba, self.tfr: br, self.tfs_: bs_})
        if self.learn_step_counter % 100 == 0:  # delay predictor training so the intrinsic reward does not vanish too quickly (stay curious)
            self.sess.run(self.pred_train, feed_dict={self.tfs_: bs_})
        self.learn_step_counter += 1

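# Demo on MountainCar-v0: the extrinsic reward is -1 on every step until the car
# reaches the goal, so before the first success the intrinsic (prediction-error)
# reward is essentially the only signal that differentiates states.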
env = gym.make('MountainCar-v0')
env = env.unwrapped

dqn = CuriosityNet(n_a=3, n_s=2, lr=0.01, output_graph=False)
ep_steps = []
for epi in range(200):
    s = env.reset()
    steps = 0
    while True:
        # env.render()
        a = dqn.choose_action(s)
        s_, r, done, info = env.step(a)
        dqn.store_transition(s, a, r, s_)
        dqn.learn()
        if done:
            print('Epi: ', epi, "| steps: ", steps)
            ep_steps.append(steps)
            break
        s = s_
        steps += 1

plt.plot(ep_steps)
plt.ylabel("steps")
plt.xlabel("episode")
plt.show()
