|
| 1 | +// Copyright (C) 2016 Google Inc. |
| 2 | +// |
| 3 | +// This program is free software; you can redistribute it and/or modify |
| 4 | +// it under the terms of the GNU General Public License as published by |
| 5 | +// the Free Software Foundation; either version 2 of the License, or |
| 6 | +// (at your option) any later version. |
| 7 | +// |
| 8 | +// This program is distributed in the hope that it will be useful, |
| 9 | +// but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 10 | +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 11 | +// GNU General Public License for more details. |
| 12 | +// |
| 13 | +// You should have received a copy of the GNU General Public License along |
| 14 | +// with this program; if not, write to the Free Software Foundation, Inc., |
| 15 | +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. |
| 16 | +// |
| 17 | +//////////////////////////////////////////////////////////////////////////////// |
| 18 | + |
| 19 | +#include <errno.h> |
| 20 | +#include <getopt.h> |
| 21 | +#include <limits.h> |
| 22 | +#include <stdbool.h> |
| 23 | +#include <stdarg.h> |
| 24 | +#include <stddef.h> |
| 25 | +#include <stdio.h> |
| 26 | +#include <stdlib.h> |
| 27 | +#include <string.h> |
| 28 | + |
| 29 | +#include "public/dmlab.h" |
| 30 | + |
// Reports a fatal error: writes the printf-style formatted message, followed
// by a newline, to stderr and terminates the process with EXIT_FAILURE.
// This function never returns.
static void __attribute__((noreturn, format(printf, 1, 2)))
sys_error(const char* fmt, ...) {
  va_list args;
  va_start(args, fmt);
  vfprintf(stderr, fmt, args);
  va_end(args);
  fputs("\n", stderr);
  exit(EXIT_FAILURE);
}
| 39 | + |
// Parses the whole of str as an integer and stores it in *val.
//
// The base is auto-detected by strtol (base 0), so decimal, "0x"-prefixed
// hexadecimal and "0"-prefixed octal inputs are accepted. Returns true on
// success. Returns false, leaving *val untouched, when no digits were
// consumed, when characters remain after the number, when strtol reported an
// error through errno, or when the value does not fit in an int.
static bool parse_int(const char* str, int* val) {
  char* end = NULL;
  errno = 0;
  const long int parsed = strtol(str, &end, 0);

  if (end == str || *end != '\0') {
    return false;  // Nothing parsed, or trailing garbage.
  }
  if (errno != 0) {
    return false;  // strtol signalled a conversion error.
  }
  if (parsed < INT_MIN || parsed > INT_MAX) {
    return false;  // Representable as long but not as int.
  }

  *val = (int)parsed;
  return true;
}
| 51 | + |
// Usage text written to stdout when -h/--help is given. Every flag accepted
// by the option table in process_commandline() is described here, including
// --help itself.
static const char kUsage[] =
    "Interactive DeepMind Lab \"Game\"\n"
    "\n"
    "Usage: game --level_script <level> \\\n"
    " [--level_setting key=value [...]] \\\n"
    " [--num_episodes <N>] \\\n"
    " [--random_seed <S>]\n"
    "\n"
    " -l, --level_script: Mandatory. The level that is to be played. Levels are\n"
    " Lua scripts, and a script called \"name\" means that a\n"
    " file \"assets/game_scripts/name.lua\" is loaded.\n"
    " -s, --level_setting: Applies an opaque key-value setting. The setting is\n"
    " available to the level script. This flag may be provided\n"
    " multiple times.\n"
    " -e, --num_episodes: The number of episodes to play. Defaults to 1.\n"
    " -r, --random_seed: A seed value used for randomly generated content; using\n"
    " the same seed should result in the same content. Defaults\n"
    " to a fixed value.\n"
    " -h, --help: Prints this usage message and exits.\n"
    ;
| 71 | + |
| 72 | +static void process_commandline(int argc, char** argv, EnvCApi* env_c_api, |
| 73 | + void* context, int* num_episodes, int* seed) { |
| 74 | + static struct option long_options[] = { |
| 75 | + {"help", no_argument, NULL, 'h'}, |
| 76 | + {"level_script", required_argument, NULL, 'l'}, |
| 77 | + {"level_setting", required_argument, NULL, 's'}, |
| 78 | + {"num_episodes", required_argument, NULL, 'e'}, |
| 79 | + {"random_seed", required_argument, NULL, 'r'}, |
| 80 | + {NULL, 0, NULL, 0}}; |
| 81 | + |
| 82 | + char *key, *value; |
| 83 | + |
| 84 | + for (int c; (c = getopt_long(argc, argv, "hl:s:e:r:", long_options, 0)) != -1;) { |
| 85 | + switch (c) { |
| 86 | + case 'h': |
| 87 | + fputs(kUsage, stdout); |
| 88 | + exit(EXIT_SUCCESS); |
| 89 | + case 'l': |
| 90 | + if (env_c_api->setting(context, "levelName", optarg) != 0) { |
| 91 | + sys_error("Invalid levelName flag '%s'.", optarg); |
| 92 | + } |
| 93 | + break; |
| 94 | + case 's': |
| 95 | + key = optarg; |
| 96 | + value = strchr(optarg, '='); |
| 97 | + if (value == NULL) { |
| 98 | + sys_error("Setting must be in the form 'key=value'."); |
| 99 | + } |
| 100 | + value[0] = '\0'; |
| 101 | + ++value; |
| 102 | + if (env_c_api->setting(context, key, value) != 0) { |
| 103 | + sys_error("Failed to apply setting '%s = %s'.", key, value); |
| 104 | + } |
| 105 | + break; |
| 106 | + case 'e': |
| 107 | + if (!parse_int(optarg, num_episodes) || *num_episodes <= 0) { |
| 108 | + sys_error("Failed to set num_episodes to '%s'.", optarg); |
| 109 | + } |
| 110 | + break; |
| 111 | + case 'r': |
| 112 | + if (!parse_int(optarg, seed)) { |
| 113 | + sys_error("Failed to set random_seed to '%s'.", optarg); |
| 114 | + } |
| 115 | + break; |
| 116 | + case ':': |
| 117 | + case '?': |
| 118 | + default: |
| 119 | + sys_error("Bad command-line flag. Use --help for usage instructions."); |
| 120 | + } |
| 121 | + } |
| 122 | +} |
| 123 | + |
| 124 | +int main(int argc, char** argv) { |
| 125 | + static const char kRunfiles[] = ".runfiles/org_deepmind_lab"; |
| 126 | + static EnvCApi env_c_api; |
| 127 | + static void* context; |
| 128 | + static char runfiles_path[4096]; |
| 129 | + |
| 130 | + if (sizeof(runfiles_path) < strlen(argv[0]) + sizeof(kRunfiles)) { |
| 131 | + sys_error("Runfiles directory name too long!"); |
| 132 | + } |
| 133 | + strcpy(runfiles_path, argv[0]); |
| 134 | + strcat(runfiles_path, kRunfiles); |
| 135 | + |
| 136 | + DeepMindLabLaunchParams params; |
| 137 | + params.runfiles_path = runfiles_path; |
| 138 | + if (dmlab_connect(¶ms, &env_c_api, &context) != 0) { |
| 139 | + sys_error("Failed to connect RL API"); |
| 140 | + } |
| 141 | + |
| 142 | + if (env_c_api.setting(context, "width", "640") != 0) { |
| 143 | + sys_error("Failed to apply default 'width' setting."); |
| 144 | + } |
| 145 | + |
| 146 | + if (env_c_api.setting(context, "height", "480") != 0) { |
| 147 | + sys_error("Failed to apply default 'height' setting."); |
| 148 | + } |
| 149 | + |
| 150 | + if (env_c_api.setting(context, "controls", "internal") != 0) { |
| 151 | + sys_error("Failed to apply 'controls' setting."); |
| 152 | + } |
| 153 | + |
| 154 | + if (env_c_api.setting(context, "appendCommand", " +set com_maxfps \"250\"") |
| 155 | + != 0) { |
| 156 | + sys_error("Failed to apply 'appendCommand' setting."); |
| 157 | + } |
| 158 | + |
| 159 | + int num_episodes = 1; |
| 160 | + int seed = 1; |
| 161 | + process_commandline(argc, argv, &env_c_api, context, &num_episodes, &seed); |
| 162 | + |
| 163 | + if (env_c_api.init(context) != 0) { |
| 164 | + sys_error("Failed to init RL API"); |
| 165 | + } |
| 166 | + |
| 167 | + for (int episode = 0; episode < num_episodes; ++episode, ++seed) { |
| 168 | + if (env_c_api.start(context, episode, seed) != 0) { |
| 169 | + sys_error("Failed to start environment."); |
| 170 | + } |
| 171 | + printf("Episode: %d\n", episode); |
| 172 | + double score = 0; |
| 173 | + double reward; |
| 174 | + while (env_c_api.advance(context, 1, &reward) == |
| 175 | + EnvCApi_EnvironmentStatus_Running) { |
| 176 | + if (reward != 0.0) { |
| 177 | + score += reward; |
| 178 | + printf("Score: %f\n", score); |
| 179 | + fflush(stdout); |
| 180 | + } |
| 181 | + } |
| 182 | + } |
| 183 | + |
| 184 | + env_c_api.release_context(context); |
| 185 | +} |
0 commit comments