
Commit eaf04ac

committed: upload curiosity model
1 parent b9ee04b commit eaf04ac

File tree

3 files changed: +168 -0 lines


README.md

Lines changed: 7 additions & 0 deletions
@@ -33,6 +33,7 @@ In these tutorials for reinforcement learning, it covers from the basic RL algor
 * [A3C](https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/tree/master/contents/10_A3C)
 * [Dyna-Q](https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/tree/master/contents/11_Dyna_Q)
 * [Proximal Policy Optimization (PPO)](https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/tree/master/contents/12_Proximal_Policy_Optimization)
+* [Curiosity Model](/contents/Curiosity_Model)
 * [Some of my experiments](https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/tree/master/experiments)
 * [2D Car](https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/tree/master/experiments/2D_car)
 * [Robot arm](https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/tree/master/experiments/Robot_arm)
@@ -82,6 +83,12 @@ In these tutorials for reinforcement learning, it covers from the basic RL algor
 <img class="course-image" src="https://morvanzhou.github.io/static/results/reinforcement-learning/6-4-3.png">
 </a>

+### [Curiosity Model](/contents/Curiosity_Model)
+
+<a href="/contents/Curiosity_Model">
+<img class="course-image" src="/contents/Curiosity_Model/Curiosity.png">
+</a>
+
 # Donation

 *If this does help you, please consider donating to support me for better tutorials. Any contribution is greatly appreciated!*
contents/Curiosity_Model/Curiosity.png

123 KB (binary image added)

contents/Curiosity_Model/Curiosity.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@

import numpy as np
import tensorflow as tf
import gym
import matplotlib.pyplot as plt


class CuriosityNet:
    def __init__(
            self,
            n_a,
            n_s,
            lr=0.01,
            gamma=0.98,
            epsilon=0.95,
            replace_target_iter=300,
            memory_size=10000,
            batch_size=128,
            output_graph=False,
    ):
        self.n_a = n_a
        self.n_s = n_s
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.replace_target_iter = replace_target_iter
        self.memory_size = memory_size
        self.batch_size = batch_size

        # total learning step
        self.learn_step_counter = 0
        self.memory_counter = 0

        # initialize zero memory [s, a, r, s_]
        self.memory = np.zeros((self.memory_size, n_s * 2 + 2))
        self.tfs, self.tfa, self.tfr, self.tfs_, self.dyn_train, self.dqn_train, self.q, self.int_r = \
            self._build_nets()

        t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net')
        e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='eval_net')

        with tf.variable_scope('hard_replacement'):
            self.target_replace_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]

        self.sess = tf.Session()

        if output_graph:
            tf.summary.FileWriter("logs/", self.sess.graph)

        self.sess.run(tf.global_variables_initializer())

    def _build_nets(self):
        tfs = tf.placeholder(tf.float32, [None, self.n_s], name="s")    # input State
        tfa = tf.placeholder(tf.int32, [None, ], name="a")              # input Action
        tfr = tf.placeholder(tf.float32, [None, ], name="ext_r")        # extrinsic reward
        tfs_ = tf.placeholder(tf.float32, [None, self.n_s], name="s_")  # input Next State

        # dynamics net
        dyn_s_, curiosity, dyn_train = self._build_dynamics_net(tfs, tfa, tfs_)

        # normal RL model
        total_reward = tf.add(curiosity, tfr, name="total_r")
        q, dqn_loss, dqn_train = self._build_dqn(tfs, tfa, total_reward, tfs_)
        return tfs, tfa, tfr, tfs_, dyn_train, dqn_train, q, curiosity

    def _build_dynamics_net(self, s, a, s_):
        with tf.variable_scope("dyn_net"):
            float_a = tf.expand_dims(tf.cast(a, dtype=tf.float32, name="float_a"), axis=1, name="2d_a")
            sa = tf.concat((s, float_a), axis=1, name="sa")
            encoded_s_ = s_     # here we use s_ as the encoded s_

            dyn_l = tf.layers.dense(sa, 32, activation=tf.nn.relu)
            dyn_s_ = tf.layers.dense(dyn_l, self.n_s)   # predicted s_
            with tf.name_scope("int_r"):
                squared_diff = tf.reduce_sum(tf.square(encoded_s_ - dyn_s_), axis=1)  # intrinsic reward

            # It is better to reduce the learning rate in order to stay curious
            train_op = tf.train.RMSPropOptimizer(self.lr, name="dyn_opt").minimize(squared_diff)
            return dyn_s_, squared_diff, train_op

    def _build_dqn(self, s, a, r, s_):
        with tf.variable_scope('eval_net'):
            e1 = tf.layers.dense(s, 128, tf.nn.relu)
            q = tf.layers.dense(e1, self.n_a, name="q")
        with tf.variable_scope('target_net'):
            t1 = tf.layers.dense(s_, 128, tf.nn.relu)
            q_ = tf.layers.dense(t1, self.n_a, name="q_")

        with tf.variable_scope('q_target'):
            q_target = r + self.gamma * tf.reduce_max(q_, axis=1, name="Qmax_s_")

        with tf.variable_scope('q_wrt_a'):
            a_indices = tf.stack([tf.range(tf.shape(a)[0], dtype=tf.int32), a], axis=1)
            q_wrt_a = tf.gather_nd(params=q, indices=a_indices)

        loss = tf.losses.mean_squared_error(labels=q_target, predictions=q_wrt_a)   # TD error
        train_op = tf.train.RMSPropOptimizer(self.lr, name="dqn_opt").minimize(
            loss, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "eval_net"))
        return q, loss, train_op

    def store_transition(self, s, a, r, s_):
        transition = np.hstack((s, [a, r], s_))
        # replace the old memory with new memory
        index = self.memory_counter % self.memory_size
        self.memory[index, :] = transition
        self.memory_counter += 1

    def choose_action(self, observation):
        # add a batch dimension when feeding into the tf placeholder
        s = observation[np.newaxis, :]

        if np.random.uniform() < self.epsilon:
            # forward feed the observation and get q value for every action
            actions_value = self.sess.run(self.q, feed_dict={self.tfs: s})
            action = np.argmax(actions_value)
        else:
            action = np.random.randint(0, self.n_a)
        return action

    def learn(self):
        # check to replace target parameters
        if self.learn_step_counter % self.replace_target_iter == 0:
            self.sess.run(self.target_replace_op)

        # sample batch memory from all memory
        top = self.memory_size if self.memory_counter > self.memory_size else self.memory_counter
        sample_index = np.random.choice(top, size=self.batch_size)
        batch_memory = self.memory[sample_index, :]

        bs, ba, br, bs_ = batch_memory[:, :self.n_s], batch_memory[:, self.n_s], \
            batch_memory[:, self.n_s + 1], batch_memory[:, -self.n_s:]
        self.sess.run(self.dqn_train, feed_dict={self.tfs: bs, self.tfa: ba, self.tfr: br, self.tfs_: bs_})
        # update the dynamics net only occasionally so its prediction error (the curiosity bonus) fades slowly
        if self.learn_step_counter % 1000 == 0:
            self.sess.run(self.dyn_train, feed_dict={self.tfs: bs, self.tfa: ba, self.tfs_: bs_})
        self.learn_step_counter += 1


env = gym.make('MountainCar-v0')
env = env.unwrapped

dqn = CuriosityNet(n_a=3, n_s=2, lr=0.01, output_graph=False)
ep_steps = []
# training loop: store each transition, learn at every step, record episode lengths
for epi in range(200):
    s = env.reset()
    steps = 0
    while True:
        env.render()
        a = dqn.choose_action(s)
        s_, r, done, info = env.step(a)
        dqn.store_transition(s, a, r, s_)
        dqn.learn()
        if done:
            print('Epi: ', epi, "| steps: ", steps)
            ep_steps.append(steps)
            break
        s = s_
        steps += 1

plt.plot(ep_steps)
plt.ylabel("steps")
plt.xlabel("episode")
plt.show()
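Note on the intrinsic reward: the curiosity bonus in Curiosity.py is simply the dynamics net's squared prediction error on the next state, added to the extrinsic reward before the DQN update. Below is a minimal NumPy sketch of that computation; the array values are made up for illustration and are not part of the commit.

import numpy as np

# Made-up batch of 3 MountainCar transitions; each state is [position, velocity].
s_next_true = np.array([[-0.52, 0.010], [-0.48, 0.021], [-0.45, 0.030]])   # real next states
s_next_pred = np.array([[-0.50, 0.012], [-0.49, 0.015], [-0.40, 0.010]])   # dynamics-net predictions
ext_r = np.array([-1.0, -1.0, -1.0])                                       # MountainCar's per-step extrinsic reward

# Intrinsic reward: squared prediction error summed over state dimensions,
# mirroring tf.reduce_sum(tf.square(encoded_s_ - dyn_s_), axis=1) above.
int_r = np.sum(np.square(s_next_true - s_next_pred), axis=1)

# The DQN learns from the combined signal, as in total_r = curiosity + ext_r.
total_r = int_r + ext_r
print("intrinsic:", int_r, "total:", total_r)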

0 commit comments
