"""This is a simple implementation of [Exploration by Random Network Distillation](https://arxiv.org/abs/1810.12894)"""

import numpy as np
import tensorflow as tf
import gym
import matplotlib.pyplot as plt


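# How RND is wired up below (a brief summary of the paper's idea as used here):
#   * "random_net": a fixed, randomly initialized network that encodes the next
#     state s_ into a 1000-dim feature vector and is never trained;
#   * "predictor": a trained network that tries to reproduce that encoding;
#   * the predictor's squared error is added to the environment reward as an
#     intrinsic reward, so rarely visited states (large error) look attractive
#     while familiar states (small error) do not.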
class CuriosityNet:
    def __init__(
            self,
            n_a,
            n_s,
            lr=0.01,
            gamma=0.95,
            epsilon=1.,
            replace_target_iter=300,
            memory_size=10000,
            batch_size=128,
            output_graph=False,
    ):
        self.n_a = n_a
        self.n_s = n_s
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.replace_target_iter = replace_target_iter
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.s_encode_size = 1000   # a large random encoding keeps the predictor's job hard

        # total learning step
        self.learn_step_counter = 0
        self.memory_counter = 0

        # initialize zero memory [s, a, r, s_]
        self.memory = np.zeros((self.memory_size, n_s * 2 + 2))
        self.tfs, self.tfa, self.tfr, self.tfs_, self.pred_train, self.dqn_train, self.q = \
            self._build_nets()

        t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net')
        e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='eval_net')

        with tf.variable_scope('hard_replacement'):
            self.target_replace_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]

        self.sess = tf.Session()

        if output_graph:
            tf.summary.FileWriter("logs/", self.sess.graph)

        self.sess.run(tf.global_variables_initializer())

    def _build_nets(self):
        tfs = tf.placeholder(tf.float32, [None, self.n_s], name="s")    # input state
        tfa = tf.placeholder(tf.int32, [None, ], name="a")              # input action
        tfr = tf.placeholder(tf.float32, [None, ], name="ext_r")        # extrinsic reward
        tfs_ = tf.placeholder(tf.float32, [None, self.n_s], name="s_")  # input next state

        # fixed random net
        with tf.variable_scope("random_net"):
            rand_encode_s_ = tf.layers.dense(tfs_, self.s_encode_size)

        # predictor
        ri, pred_train = self._build_predictor(tfs_, rand_encode_s_)

        # normal RL model
        q, dqn_loss, dqn_train = self._build_dqn(tfs, tfa, ri, tfr, tfs_)
        return tfs, tfa, tfr, tfs_, pred_train, dqn_train, q

    def _build_predictor(self, s_, rand_encode_s_):
        with tf.variable_scope("predictor"):
            net = tf.layers.dense(s_, 128, tf.nn.relu)
            out = tf.layers.dense(net, self.s_encode_size)

        with tf.name_scope("int_r"):
            # intrinsic reward: prediction error against the fixed random encoding
            ri = tf.reduce_sum(tf.square(rand_encode_s_ - out), axis=1)
        # only the "predictor" variables are optimized, so the random net stays fixed
        train_op = tf.train.RMSPropOptimizer(self.lr, name="predictor_opt").minimize(
            ri, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "predictor"))

        return ri, train_op

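    # Per-transition intrinsic reward, as in the RND paper:
    #     r_int(s') = || predictor(s') - random_net(s') ||^2
    # It is large for states the predictor has rarely seen and shrinks as those
    # states become familiar, which is exactly the error minimized above.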
    def _build_dqn(self, s, a, ri, re, s_):
        with tf.variable_scope('eval_net'):
            e1 = tf.layers.dense(s, 128, tf.nn.relu)
            q = tf.layers.dense(e1, self.n_a, name="q")
        with tf.variable_scope('target_net'):
            t1 = tf.layers.dense(s_, 128, tf.nn.relu)
            q_ = tf.layers.dense(t1, self.n_a, name="q_")

        with tf.variable_scope('q_target'):
            q_target = re + ri + self.gamma * tf.reduce_max(q_, axis=1, name="Qmax_s_")

        with tf.variable_scope('q_wrt_a'):
            a_indices = tf.stack([tf.range(tf.shape(a)[0], dtype=tf.int32), a], axis=1)
            q_wrt_a = tf.gather_nd(params=q, indices=a_indices)

        loss = tf.losses.mean_squared_error(labels=q_target, predictions=q_wrt_a)   # TD error
        train_op = tf.train.RMSPropOptimizer(self.lr, name="dqn_opt").minimize(
            loss, var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "eval_net"))
        return q, loss, train_op

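    # The DQN target above folds both reward streams into one scalar:
    #     y = r_ext + r_int + gamma * max_a' Q_target(s', a')
    # so the value function itself learns to seek out novel states; apart from the
    # added r_int term this is a standard DQN with a periodically synced target net.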
    def store_transition(self, s, a, r, s_):
        transition = np.hstack((s, [a, r], s_))
        # replace the old memory with new memory
        index = self.memory_counter % self.memory_size
        self.memory[index, :] = transition
        self.memory_counter += 1

    def choose_action(self, observation):
        # add a batch dimension before feeding into the tf placeholder
        s = observation[np.newaxis, :]

        # with the default epsilon=1., actions are always greedy w.r.t. Q;
        # exploration comes from the intrinsic (curiosity) reward instead
        if np.random.uniform() < self.epsilon:
            # forward feed the observation and get the q value for every action
            actions_value = self.sess.run(self.q, feed_dict={self.tfs: s})
            action = np.argmax(actions_value)
        else:
            action = np.random.randint(0, self.n_a)
        return action

    def learn(self):
        # check to replace target parameters
        if self.learn_step_counter % self.replace_target_iter == 0:
            self.sess.run(self.target_replace_op)

        # sample a batch from memory (with replacement)
        top = self.memory_size if self.memory_counter > self.memory_size else self.memory_counter
        sample_index = np.random.choice(top, size=self.batch_size)
        batch_memory = self.memory[sample_index, :]

        bs, ba, br, bs_ = batch_memory[:, :self.n_s], batch_memory[:, self.n_s].astype(int), \
            batch_memory[:, self.n_s + 1], batch_memory[:, -self.n_s:]
        self.sess.run(self.dqn_train, feed_dict={self.tfs: bs, self.tfa: ba, self.tfr: br, self.tfs_: bs_})
        if self.learn_step_counter % 100 == 0:   # delay predictor training so the curiosity signal decays slowly
            self.sess.run(self.pred_train, feed_dict={self.tfs_: bs_})
        self.learn_step_counter += 1


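# MountainCar gives an extrinsic reward of -1 on every step until the goal is
# reached, so the extrinsic signal alone says nothing about which states are
# promising; the intrinsic reward supplies the exploration pressure here.
# env.unwrapped removes the default 200-step time limit, so an episode only
# ends when the car actually reaches the flag.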
env = gym.make('MountainCar-v0')
env = env.unwrapped

dqn = CuriosityNet(n_a=3, n_s=2, lr=0.01, output_graph=False)
ep_steps = []
for epi in range(200):
    s = env.reset()
    steps = 0
    while True:
        # env.render()
        a = dqn.choose_action(s)
        s_, r, done, info = env.step(a)
        dqn.store_transition(s, a, r, s_)
        dqn.learn()
        if done:
            print('Epi: ', epi, "| steps: ", steps)
            ep_steps.append(steps)
            break
        s = s_
        steps += 1

plt.plot(ep_steps)
plt.ylabel("steps")
plt.xlabel("episode")
plt.show()