import numpy as np
import tensorflow.compat.v1 as tf  # TF1 graph code; the compat shim keeps it runnable under TF2
import gym
import matplotlib.pyplot as plt

tf.disable_v2_behavior()

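
# Curiosity-driven DQN: a small forward-dynamics model predicts the next state
# from (s, a), and its prediction error is added to the environment reward as
# an intrinsic "curiosity" bonus (in the spirit of Pathak et al., 2017). The
# bonus rewards visiting states the model cannot yet predict, which helps on
# sparse-reward tasks such as MountainCar.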
class CuriosityNet:
    def __init__(
            self,
            n_a,                      # number of discrete actions
            n_s,                      # dimensionality of the observation
            lr=0.01,                  # learning rate for both optimizers
            gamma=0.98,               # discount factor
            epsilon=0.95,             # probability of acting greedily (not of exploring)
            replace_target_iter=300,  # learn steps between target-net syncs
            memory_size=10000,        # replay buffer capacity, in transitions
            batch_size=128,
            output_graph=False,       # dump the graph for TensorBoard
    ):
        self.n_a = n_a
        self.n_s = n_s
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.replace_target_iter = replace_target_iter
        self.memory_size = memory_size
        self.batch_size = batch_size

        # counters: total learn() calls and total transitions stored
        self.learn_step_counter = 0
        self.memory_counter = 0

        # zero-initialized replay memory; each row holds [s, a, r, s_]
        self.memory = np.zeros((self.memory_size, n_s * 2 + 2))
        self.tfs, self.tfa, self.tfr, self.tfs_, self.dyn_train, self.dqn_train, self.q, self.int_r = \
            self._build_nets()

        t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net')
        e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='eval_net')

        with tf.variable_scope('hard_replacement'):
            self.target_replace_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]

        self.sess = tf.Session()

        if output_graph:
            tf.summary.FileWriter("logs/", self.sess.graph)

        self.sess.run(tf.global_variables_initializer())

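    # _build_nets wires the whole graph: placeholders for (s, a, r_ext, s_),
    # the dynamics net that produces the intrinsic reward, and a DQN trained
    # on the summed extrinsic + intrinsic reward.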
    def _build_nets(self):
        tfs = tf.placeholder(tf.float32, [None, self.n_s], name="s")    # current state
        tfa = tf.placeholder(tf.int32, [None, ], name="a")              # action taken
        tfr = tf.placeholder(tf.float32, [None, ], name="ext_r")        # extrinsic reward
        tfs_ = tf.placeholder(tf.float32, [None, self.n_s], name="s_")  # next state

        # forward dynamics net: its prediction error is the intrinsic reward
        dyn_s_, curiosity, dyn_train = self._build_dynamics_net(tfs, tfa, tfs_)

        # standard DQN, trained on the combined extrinsic + intrinsic reward
        total_reward = tf.add(curiosity, tfr, name="total_r")
        q, dqn_loss, dqn_train = self._build_dqn(tfs, tfa, total_reward, tfs_)
        return tfs, tfa, tfr, tfs_, dyn_train, dqn_train, q, curiosity

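    # The dynamics model f(s, a) -> s_hat is deliberately small (one 32-unit
    # hidden layer). The intrinsic reward for a transition is the squared
    # prediction error ||s_ - f(s, a)||^2: large exactly where the model is
    # still wrong, i.e. in unfamiliar parts of the state space.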
    def _build_dynamics_net(self, s, a, s_):
        with tf.variable_scope("dyn_net"):
            float_a = tf.expand_dims(tf.cast(a, dtype=tf.float32, name="float_a"), axis=1, name="2d_a")
            sa = tf.concat((s, float_a), axis=1, name="sa")  # concatenated (state, action) input
            encoded_s_ = s_  # the raw next state stands in for an encoded s_

            dyn_l = tf.layers.dense(sa, 32, activation=tf.nn.relu)
            dyn_s_ = tf.layers.dense(dyn_l, self.n_s)  # predicted next state
            with tf.name_scope("int_r"):
                squared_diff = tf.reduce_sum(tf.square(encoded_s_ - dyn_s_), axis=1)  # intrinsic reward

            # Train this net slowly (see learn()): if it fits the dynamics too
            # quickly, the prediction error, and with it the curiosity, vanishes.
            train_op = tf.train.RMSPropOptimizer(self.lr, name="dyn_opt").minimize(squared_diff)
        return dyn_s_, squared_diff, train_op

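    # Standard DQN with a separate target network: the online "eval_net" is
    # trained, while a periodically synced "target_net" supplies the bootstrap
    # value. The TD target is q_target = r_total + gamma * max_a' Q_target(s', a').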
    def _build_dqn(self, s, a, r, s_):
        with tf.variable_scope('eval_net'):
            e1 = tf.layers.dense(s, 128, tf.nn.relu)
            q = tf.layers.dense(e1, self.n_a, name="q")
        with tf.variable_scope('target_net'):
            t1 = tf.layers.dense(s_, 128, tf.nn.relu)
            q_ = tf.layers.dense(t1, self.n_a, name="q_")

        with tf.variable_scope('q_target'):
            q_target = r + self.gamma * tf.reduce_max(q_, axis=1, name="Qmax_s_")

        with tf.variable_scope('q_wrt_a'):
            # select Q(s, a) for the action actually taken in each transition
            a_indices = tf.stack([tf.range(tf.shape(a)[0], dtype=tf.int32), a], axis=1)
            q_wrt_a = tf.gather_nd(params=q, indices=a_indices)

        loss = tf.losses.mean_squared_error(labels=q_target, predictions=q_wrt_a)  # TD error
        # restrict training to eval_net; target_net changes only via hard replacement
        train_op = tf.train.RMSPropOptimizer(self.lr, name="dqn_opt").minimize(
            loss, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "eval_net"))
        return q, loss, train_op

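    # Replay memory is a fixed-size ring buffer: once full, the oldest
    # transitions are overwritten in arrival order.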
    def store_transition(self, s, a, r, s_):
        transition = np.hstack((s, [a, r], s_))
        # overwrite the oldest slot once the buffer is full
        index = self.memory_counter % self.memory_size
        self.memory[index, :] = transition
        self.memory_counter += 1

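    # Note the inverted epsilon convention: with probability epsilon (0.95 by
    # default) the agent acts greedily, and explores uniformly otherwise.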
    def choose_action(self, observation):
        # add a batch dimension before feeding the tf placeholder
        s = observation[np.newaxis, :]

        if np.random.uniform() < self.epsilon:
            # forward pass: Q values for every action in this state
            actions_value = self.sess.run(self.q, feed_dict={self.tfs: s})
            action = np.argmax(actions_value)
        else:
            action = np.random.randint(0, self.n_a)
        return action

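    # learn() trains the DQN on every call but updates the dynamics net only
    # every 1000 steps: a slowly learning predictor keeps the curiosity bonus
    # from collapsing before the agent has explored.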
    def learn(self):
        # periodically copy eval_net weights into target_net
        if self.learn_step_counter % self.replace_target_iter == 0:
            self.sess.run(self.target_replace_op)

        # sample a batch (with replacement) from the filled part of the memory
        top = self.memory_size if self.memory_counter > self.memory_size else self.memory_counter
        sample_index = np.random.choice(top, size=self.batch_size)
        batch_memory = self.memory[sample_index, :]

        bs = batch_memory[:, :self.n_s]
        ba = batch_memory[:, self.n_s].astype(int)  # actions were stored as floats
        br = batch_memory[:, self.n_s + 1]
        bs_ = batch_memory[:, -self.n_s:]
        self.sess.run(self.dqn_train, feed_dict={self.tfs: bs, self.tfa: ba, self.tfr: br, self.tfs_: bs_})
        if self.learn_step_counter % 1000 == 0:  # deliberately infrequent, see note above
            self.sess.run(self.dyn_train, feed_dict={self.tfs: bs, self.tfa: ba, self.tfs_: bs_})
        self.learn_step_counter += 1


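# MountainCar-v0 pays -1 per step until the car reaches the flag, so the
# extrinsic signal is nearly uninformative; the intrinsic bonus is what drives
# exploration here. This loop uses the pre-0.26 gym API (reset() returns the
# observation, step() returns four values).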
env = gym.make('MountainCar-v0')
env = env.unwrapped  # drop the step limit so episodes end only at the flag

dqn = CuriosityNet(n_a=3, n_s=2, lr=0.01, output_graph=False)
ep_steps = []
for epi in range(200):
    s = env.reset()
    steps = 0
    while True:
        env.render()  # comment out to train faster
        a = dqn.choose_action(s)
        s_, r, done, info = env.step(a)
        dqn.store_transition(s, a, r, s_)
        dqn.learn()
        if done:
            print('Epi: ', epi, "| steps: ", steps)
            ep_steps.append(steps)
            break
        s = s_
        steps += 1

plt.plot(ep_steps)
plt.ylabel("steps")
plt.xlabel("episode")
plt.show()