|
102 | 102 | {
|
103 | 103 | "cell_type": "code",
|
104 | 104 | "execution_count": 1,
|
105 |
| - "metadata": {}, |
| 105 | + "metadata": { |
| 106 | + "collapsed": true |
| 107 | + }, |
106 | 108 | "outputs": [],
|
107 | 109 | "source": [
|
108 | 110 | "import tensorflow as tf # Deep Learning library\n",
|
|
150 | 152 | {
|
151 | 153 | "cell_type": "code",
|
152 | 154 | "execution_count": 2,
|
153 |
| - "metadata": {}, |
| 155 | + "metadata": { |
| 156 | + "collapsed": true |
| 157 | + }, |
154 | 158 | "outputs": [],
|
155 | 159 | "source": [
|
156 | 160 | "\"\"\"\n",
|
|
208 | 212 | {
|
209 | 213 | "cell_type": "code",
|
210 | 214 | "execution_count": 3,
|
211 |
| - "metadata": {}, |
| 215 | + "metadata": { |
| 216 | + "collapsed": true |
| 217 | + }, |
212 | 218 | "outputs": [],
|
213 | 219 | "source": [
|
214 | 220 | "game, possible_actions = create_environment()"
|
|
232 | 238 | {
|
233 | 239 | "cell_type": "code",
|
234 | 240 | "execution_count": 4,
|
235 |
| - "metadata": {}, |
| 241 | + "metadata": { |
| 242 | + "collapsed": true |
| 243 | + }, |
236 | 244 | "outputs": [],
|
237 | 245 | "source": [
|
238 | 246 | "\"\"\"\n",
|
|
299 | 307 | {
|
300 | 308 | "cell_type": "code",
|
301 | 309 | "execution_count": 5,
|
302 |
| - "metadata": {}, |
| 310 | + "metadata": { |
| 311 | + "collapsed": true |
| 312 | + }, |
303 | 313 | "outputs": [],
|
304 | 314 | "source": [
|
305 | 315 | "stack_size = 4 # We stack 4 frames\n",
|
|
348 | 358 | {
|
349 | 359 | "cell_type": "code",
|
350 | 360 | "execution_count": 6,
|
351 |
| - "metadata": {}, |
| 361 | + "metadata": { |
| 362 | + "collapsed": true |
| 363 | + }, |
352 | 364 | "outputs": [],
|
353 | 365 | "source": [
|
354 | 366 | "### MODEL HYPERPARAMETERS\n",
|
|
397 | 409 | {
|
398 | 410 | "cell_type": "code",
|
399 | 411 | "execution_count": 7,
|
400 |
| - "metadata": {}, |
| 412 | + "metadata": { |
| 413 | + "collapsed": true |
| 414 | + }, |
401 | 415 | "outputs": [],
|
402 | 416 | "source": [
|
403 | 417 | "class DQNetwork:\n",
|
|
517 | 531 | {
|
518 | 532 | "cell_type": "code",
|
519 | 533 | "execution_count": 8,
|
520 |
| - "metadata": {}, |
| 534 | + "metadata": { |
| 535 | + "collapsed": true |
| 536 | + }, |
521 | 537 | "outputs": [],
|
522 | 538 | "source": [
|
523 | 539 | "# Reset the graph\n",
|
|
541 | 557 | {
|
542 | 558 | "cell_type": "code",
|
543 | 559 | "execution_count": 9,
|
544 |
| - "metadata": {}, |
| 560 | + "metadata": { |
| 561 | + "collapsed": true |
| 562 | + }, |
545 | 563 | "outputs": [],
|
546 | 564 | "source": [
|
547 | 565 | "class Memory():\n",
|
|
570 | 588 | {
|
571 | 589 | "cell_type": "code",
|
572 | 590 | "execution_count": 10,
|
573 |
| - "metadata": {}, |
| 591 | + "metadata": { |
| 592 | + "collapsed": true |
| 593 | + }, |
574 | 594 | "outputs": [],
|
575 | 595 | "source": [
|
576 | 596 | "# Instantiate memory\n",
|
|
636 | 656 | {
|
637 | 657 | "cell_type": "code",
|
638 | 658 | "execution_count": 11,
|
639 |
| - "metadata": {}, |
| 659 | + "metadata": { |
| 660 | + "collapsed": true |
| 661 | + }, |
640 | 662 | "outputs": [],
|
641 | 663 | "source": [
|
642 | 664 | "# Setup TensorBoard Writer\n",
|
|
683 | 705 | {
|
684 | 706 | "cell_type": "code",
|
685 | 707 | "execution_count": 12,
|
686 |
| - "metadata": {}, |
| 708 | + "metadata": { |
| 709 | + "collapsed": true |
| 710 | + }, |
687 | 711 | "outputs": [],
|
688 | 712 | "source": [
|
689 | 713 | "\"\"\"\n",
|
|
885 | 909 | {
|
886 | 910 | "cell_type": "code",
|
887 | 911 | "execution_count": null,
|
888 |
| - "metadata": {}, |
| 912 | + "metadata": { |
| 913 | + "collapsed": true |
| 914 | + }, |
889 | 915 | "outputs": [],
|
890 | 916 | "source": [
|
891 | 917 | "with tf.Session() as sess:\n",
|
|
894 | 920 | " \n",
|
895 | 921 | " totalScore = 0\n",
|
896 | 922 | " \n",
|
897 |
| - " \n", |
898 | 923 | " # Load the model\n",
|
899 | 924 | " saver.restore(sess, \"./models/model.ckpt\")\n",
|
900 | 925 | " game.init()\n",
|
901 | 926 | " for i in range(1):\n",
|
902 | 927 | " \n",
|
| 928 | + " done = False\n", |
| 929 | + " \n", |
903 | 930 | " game.new_episode()\n",
|
| 931 | + " \n", |
| 932 | + " state = game.get_state().screen_buffer\n", |
| 933 | + " state, stacked_frames = stack_frames(stacked_frames, state, True)\n", |
| 934 | + " \n", |
904 | 935 | " while not game.is_episode_finished():\n",
|
905 |
| - " frame = game.get_state().screen_buffer\n", |
906 |
| - " state = stack_frames(stacked_frames, frame)\n", |
907 | 936 | " # Take the biggest Q value (= the best action)\n",
|
908 | 937 | " Qs = sess.run(DQNetwork.output, feed_dict = {DQNetwork.inputs_: state.reshape((1, *state.shape))})\n",
|
909 |
| - " action = np.argmax(Qs)\n", |
910 |
| - " action = possible_actions[int(action)]\n", |
911 |
| - " game.make_action(action) \n", |
| 938 | + " \n", |
| 939 | + " # Take the biggest Q value (= the best action)\n", |
| 940 | + " choice = np.argmax(Qs)\n", |
| 941 | + " action = possible_actions[int(choice)]\n", |
| 942 | + " \n", |
| 943 | + " game.make_action(action)\n", |
| 944 | + " done = game.is_episode_finished()\n", |
912 | 945 | " score = game.get_total_reward()\n",
|
| 946 | + " \n", |
| 947 | + " if done:\n", |
| 948 | + " break \n", |
| 949 | + " \n", |
| 950 | + " else:\n", |
| 951 | + " print(\"else\")\n", |
| 952 | + " next_state = game.get_state().screen_buffer\n", |
| 953 | + " next_state, stacked_frames = stack_frames(stacked_frames, next_state, False)\n", |
| 954 | + " state = next_state\n", |
| 955 | + " \n", |
| 956 | + " score = game.get_total_reward()\n", |
913 | 957 | " print(\"Score: \", score)\n",
|
914 |
| - " totalScore += score\n", |
915 |
| - " print(\"TOTAL_SCORE\", totalScore/100.0)\n", |
916 | 958 | " game.close()"
|
917 | 959 | ]
|
918 | 960 | }
|
919 | 961 | ],
|
920 | 962 | "metadata": {
|
921 | 963 | "kernelspec": {
|
922 |
| - "display_name": "Python [default]", |
| 964 | + "display_name": "Python 3", |
923 | 965 | "language": "python",
|
924 | 966 | "name": "python3"
|
925 | 967 | },
|
|
0 commit comments