Skip to content

Commit de9fe54

Browse files
Bug correction replay
1 parent 87927cc commit de9fe54

File tree

1 file changed

+64
-22
lines changed

1 file changed

+64
-22
lines changed

Deep Q Learning/Doom/Deep Q learning with Doom.ipynb

Lines changed: 64 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@
102102
{
103103
"cell_type": "code",
104104
"execution_count": 1,
105-
"metadata": {},
105+
"metadata": {
106+
"collapsed": true
107+
},
106108
"outputs": [],
107109
"source": [
108110
"import tensorflow as tf # Deep Learning library\n",
@@ -150,7 +152,9 @@
150152
{
151153
"cell_type": "code",
152154
"execution_count": 2,
153-
"metadata": {},
155+
"metadata": {
156+
"collapsed": true
157+
},
154158
"outputs": [],
155159
"source": [
156160
"\"\"\"\n",
@@ -208,7 +212,9 @@
208212
{
209213
"cell_type": "code",
210214
"execution_count": 3,
211-
"metadata": {},
215+
"metadata": {
216+
"collapsed": true
217+
},
212218
"outputs": [],
213219
"source": [
214220
"game, possible_actions = create_environment()"
@@ -232,7 +238,9 @@
232238
{
233239
"cell_type": "code",
234240
"execution_count": 4,
235-
"metadata": {},
241+
"metadata": {
242+
"collapsed": true
243+
},
236244
"outputs": [],
237245
"source": [
238246
"\"\"\"\n",
@@ -299,7 +307,9 @@
299307
{
300308
"cell_type": "code",
301309
"execution_count": 5,
302-
"metadata": {},
310+
"metadata": {
311+
"collapsed": true
312+
},
303313
"outputs": [],
304314
"source": [
305315
"stack_size = 4 # We stack 4 frames\n",
@@ -348,7 +358,9 @@
348358
{
349359
"cell_type": "code",
350360
"execution_count": 6,
351-
"metadata": {},
361+
"metadata": {
362+
"collapsed": true
363+
},
352364
"outputs": [],
353365
"source": [
354366
"### MODEL HYPERPARAMETERS\n",
@@ -397,7 +409,9 @@
397409
{
398410
"cell_type": "code",
399411
"execution_count": 7,
400-
"metadata": {},
412+
"metadata": {
413+
"collapsed": true
414+
},
401415
"outputs": [],
402416
"source": [
403417
"class DQNetwork:\n",
@@ -517,7 +531,9 @@
517531
{
518532
"cell_type": "code",
519533
"execution_count": 8,
520-
"metadata": {},
534+
"metadata": {
535+
"collapsed": true
536+
},
521537
"outputs": [],
522538
"source": [
523539
"# Reset the graph\n",
@@ -541,7 +557,9 @@
541557
{
542558
"cell_type": "code",
543559
"execution_count": 9,
544-
"metadata": {},
560+
"metadata": {
561+
"collapsed": true
562+
},
545563
"outputs": [],
546564
"source": [
547565
"class Memory():\n",
@@ -570,7 +588,9 @@
570588
{
571589
"cell_type": "code",
572590
"execution_count": 10,
573-
"metadata": {},
591+
"metadata": {
592+
"collapsed": true
593+
},
574594
"outputs": [],
575595
"source": [
576596
"# Instantiate memory\n",
@@ -636,7 +656,9 @@
636656
{
637657
"cell_type": "code",
638658
"execution_count": 11,
639-
"metadata": {},
659+
"metadata": {
660+
"collapsed": true
661+
},
640662
"outputs": [],
641663
"source": [
642664
"# Setup TensorBoard Writer\n",
@@ -683,7 +705,9 @@
683705
{
684706
"cell_type": "code",
685707
"execution_count": 12,
686-
"metadata": {},
708+
"metadata": {
709+
"collapsed": true
710+
},
687711
"outputs": [],
688712
"source": [
689713
"\"\"\"\n",
@@ -885,7 +909,9 @@
885909
{
886910
"cell_type": "code",
887911
"execution_count": null,
888-
"metadata": {},
912+
"metadata": {
913+
"collapsed": true
914+
},
889915
"outputs": [],
890916
"source": [
891917
"with tf.Session() as sess:\n",
@@ -894,32 +920,48 @@
894920
" \n",
895921
" totalScore = 0\n",
896922
" \n",
897-
" \n",
898923
" # Load the model\n",
899924
" saver.restore(sess, \"./models/model.ckpt\")\n",
900925
" game.init()\n",
901926
" for i in range(1):\n",
902927
" \n",
928+
" done = False\n",
929+
" \n",
903930
" game.new_episode()\n",
931+
" \n",
932+
" state = game.get_state().screen_buffer\n",
933+
" state, stacked_frames = stack_frames(stacked_frames, state, True)\n",
934+
" \n",
904935
" while not game.is_episode_finished():\n",
905-
" frame = game.get_state().screen_buffer\n",
906-
" state = stack_frames(stacked_frames, frame)\n",
907936
" # Take the biggest Q value (= the best action)\n",
908937
" Qs = sess.run(DQNetwork.output, feed_dict = {DQNetwork.inputs_: state.reshape((1, *state.shape))})\n",
909-
" action = np.argmax(Qs)\n",
910-
" action = possible_actions[int(action)]\n",
911-
" game.make_action(action) \n",
938+
" \n",
939+
" # Take the biggest Q value (= the best action)\n",
940+
" choice = np.argmax(Qs)\n",
941+
" action = possible_actions[int(choice)]\n",
942+
" \n",
943+
" game.make_action(action)\n",
944+
" done = game.is_episode_finished()\n",
912945
" score = game.get_total_reward()\n",
946+
" \n",
947+
" if done:\n",
948+
" break \n",
949+
" \n",
950+
" else:\n",
951+
" print(\"else\")\n",
952+
" next_state = game.get_state().screen_buffer\n",
953+
" next_state, stacked_frames = stack_frames(stacked_frames, next_state, False)\n",
954+
" state = next_state\n",
955+
" \n",
956+
" score = game.get_total_reward()\n",
913957
" print(\"Score: \", score)\n",
914-
" totalScore += score\n",
915-
" print(\"TOTAL_SCORE\", totalScore/100.0)\n",
916958
" game.close()"
917959
]
918960
}
919961
],
920962
"metadata": {
921963
"kernelspec": {
922-
"display_name": "Python [default]",
964+
"display_name": "Python 3",
923965
"language": "python",
924966
"name": "python3"
925967
},

0 commit comments

Comments
 (0)