@@ -217,7 +217,7 @@ def forward(self, X):
217
217
return self .f (a )
218
218
219
219
class DQN :
220
- def __init__ (self , K , conv_layer_sizes , hidden_layer_sizes , gamma ):
220
+ def __init__ (self , K , conv_layer_sizes , hidden_layer_sizes ):
221
221
self .K = K
222
222
223
223
# inputs and targets
@@ -253,7 +253,7 @@ def __init__(self, K, conv_layer_sizes, hidden_layer_sizes, gamma):
253
253
# build fully connected layers
254
254
self .layers = []
255
255
M1 = flattened_ouput_size
256
- # print("flattened_ouput_size:", flattened_ouput_size)
256
+ print ("flattened_ouput_size:" , flattened_ouput_size )
257
257
for M2 in hidden_layer_sizes :
258
258
layer = HiddenLayer (M1 , M2 )
259
259
self .layers .append (layer )
@@ -284,6 +284,7 @@ def __init__(self, K, conv_layer_sizes, hidden_layer_sizes, gamma):
284
284
# compile functions
285
285
self .train_op = theano .function (
286
286
inputs = [X , G , actions ],
287
+ outputs = cost ,
287
288
updates = updates ,
288
289
allow_input_downcast = True
289
290
)
@@ -305,7 +306,7 @@ def predict(self, X):
305
306
return self .predict_op (X )
306
307
307
308
def update (self , states , actions , targets ):
308
- self .train_op (states , targets , actions )
309
+ return self .train_op (states , targets , actions )
309
310
310
311
def sample_action (self , x , eps ):
311
312
if np .random .random () < eps :
@@ -434,13 +435,11 @@ def smooth(x):
434
435
K = K ,
435
436
conv_layer_sizes = conv_layer_sizes ,
436
437
hidden_layer_sizes = hidden_layer_sizes ,
437
- gamma = gamma ,
438
438
)
439
439
target_model = DQN (
440
440
K = K ,
441
441
conv_layer_sizes = conv_layer_sizes ,
442
442
hidden_layer_sizes = hidden_layer_sizes ,
443
- gamma = gamma ,
444
443
)
445
444
446
445
@@ -451,6 +450,7 @@ def smooth(x):
451
450
452
451
action = np .random .choice (K )
453
452
obs , reward , done , _ = env .step (action )
453
+ obs_small = downsample_image (obs )
454
454
experience_replay_buffer .add_experience (action , obs_small , reward , done )
455
455
456
456
if done :
0 commit comments