Skip to content

Commit a8dcf8e

Browse files
committed
DeepMind hook system backend
1 parent 63e6a93 commit a8dcf8e

17 files changed

Lines changed: 3202 additions & 0 deletions

deepmind/engine/BUILD

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# Description:
2+
# Hooks and callbacks to enable gameplay modification.
3+
4+
licenses(["restricted"]) # GPLv2
5+
6+
cc_library(
7+
name = "context",
8+
srcs = ["context.cc"],
9+
hdrs = ["context.h"],
10+
visibility = ["//:__pkg__"],
11+
deps = [
12+
":lua_maze_generation",
13+
":lua_random",
14+
":lua_text_level_maker",
15+
"//deepmind/include:context_headers",
16+
"//deepmind/lua",
17+
"//deepmind/lua:bind",
18+
"//deepmind/lua:call",
19+
"//deepmind/lua:class",
20+
"//deepmind/lua:n_results_or",
21+
"//deepmind/lua:push_script",
22+
"//deepmind/lua:read",
23+
"//deepmind/lua:table_ref",
24+
"//deepmind/lua:vm",
25+
"//deepmind/tensor:lua_tensor",
26+
],
27+
)
28+
29+
cc_library(
30+
name = "callbacks",
31+
srcs = ["callbacks.cc"],
32+
data = ["//:non_pk3_assets"],
33+
visibility = ["//:__pkg__"],
34+
deps = [
35+
":context",
36+
":lua_maze_generation",
37+
":lua_random",
38+
":lua_text_level_maker",
39+
"//deepmind/level_generation/text_level:lua_bindings",
40+
"//deepmind/lua:vm",
41+
"//deepmind/tensor:lua_tensor",
42+
],
43+
)
44+
45+
cc_test(
46+
name = "callbacks_test",
47+
size = "small",
48+
srcs = ["callbacks_test.cc"],
49+
deps = [
50+
":callbacks",
51+
"//deepmind/include:context_headers",
52+
"//deepmind/support:test_srcdir",
53+
"@googletest//:gtest_main",
54+
],
55+
)
56+
57+
cc_library(
58+
name = "lua_maze_generation",
59+
srcs = ["lua_maze_generation.cc"],
60+
hdrs = ["lua_maze_generation.h"],
61+
deps = [
62+
"//deepmind/level_generation/text_level:char_grid",
63+
"//deepmind/level_generation/text_maze_generation:algorithm",
64+
"//deepmind/level_generation/text_maze_generation:text_maze",
65+
"//deepmind/lua",
66+
"//deepmind/lua:bind",
67+
"//deepmind/lua:call",
68+
"//deepmind/lua:class",
69+
"//deepmind/lua:n_results_or",
70+
"//deepmind/lua:push",
71+
"//deepmind/lua:read",
72+
"//deepmind/lua:table_ref",
73+
],
74+
)
75+
76+
cc_test(
77+
name = "lua_maze_generation_test",
78+
size = "small",
79+
srcs = ["lua_maze_generation_test.cc"],
80+
deps = [
81+
":lua_maze_generation",
82+
"//deepmind/lua:call",
83+
"//deepmind/lua:n_results_or_test_util",
84+
"//deepmind/lua:push_script",
85+
"//deepmind/lua:vm_test_util",
86+
"@googletest//:gtest_main",
87+
],
88+
)
89+
90+
cc_library(
91+
name = "lua_random",
92+
srcs = ["lua_random.cc"],
93+
hdrs = ["lua_random.h"],
94+
deps = [
95+
"//deepmind/lua",
96+
"//deepmind/lua:class",
97+
"//deepmind/lua:n_results_or",
98+
"//deepmind/lua:push",
99+
"//deepmind/lua:read",
100+
],
101+
)
102+
103+
cc_test(
104+
name = "lua_random_test",
105+
size = "small",
106+
srcs = ["lua_random_test.cc"],
107+
deps = [
108+
":lua_random",
109+
"//deepmind/lua:call",
110+
"//deepmind/lua:n_results_or_test_util",
111+
"//deepmind/lua:push_script",
112+
"//deepmind/lua:vm_test_util",
113+
"@googletest//:gtest_main",
114+
],
115+
)
116+
117+
cc_library(
118+
name = "lua_text_level_maker",
119+
srcs = ["lua_text_level_maker.cc"],
120+
hdrs = ["lua_text_level_maker.h"],
121+
visibility = ["//deepmind:__subpackages__"],
122+
deps = [
123+
":lua_random",
124+
"//deepmind/level_generation:compile_map",
125+
"//deepmind/level_generation/text_level:lua_bindings",
126+
"//deepmind/lua",
127+
"//deepmind/lua:class",
128+
"//deepmind/lua:n_results_or",
129+
"//deepmind/lua:push",
130+
"//deepmind/lua:read",
131+
"//deepmind/lua:table_ref",
132+
],
133+
)

deepmind/engine/callbacks.cc

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Copyright (C) 2016 Google Inc.
2+
//
3+
// This program is free software; you can redistribute it and/or modify
4+
// it under the terms of the GNU General Public License as published by
5+
// the Free Software Foundation; either version 2 of the License, or
6+
// (at your option) any later version.
7+
//
8+
// This program is distributed in the hope that it will be useful,
9+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
10+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11+
// GNU General Public License for more details.
12+
//
13+
// You should have received a copy of the GNU General Public License along
14+
// with this program; if not, write to the Free Software Foundation, Inc.,
15+
// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
16+
//
17+
////////////////////////////////////////////////////////////////////////////////
18+
19+
#include <algorithm>
20+
#include <cstring>
21+
#include <iostream>
22+
#include <utility>
23+
24+
#include "deepmind/engine/context.h"
25+
#include "deepmind/engine/lua_maze_generation.h"
26+
#include "deepmind/engine/lua_random.h"
27+
#include "deepmind/engine/lua_text_level_maker.h"
28+
#include "deepmind/include/deepmind_context.h"
29+
#include "deepmind/level_generation/text_level/lua_bindings.h"
30+
#include "deepmind/lua/vm.h"
31+
#include "deepmind/tensor/lua_tensor.h"
32+
33+
namespace lua = deepmind::lab::lua;
34+
namespace tensor = deepmind::lab::tensor;
35+
using deepmind::lab::Context;
36+
using deepmind::lab::LuaMazeGeneration;
37+
using deepmind::lab::LuaRandom;
38+
using deepmind::lab::LuaSnippetEmitter;
39+
using deepmind::lab::LuaTextLevelMaker;
40+
41+
extern "C" {
42+
43+
int dmlab_create_context(const char* runfiles_path, DeepmindContext* ctx) {
44+
lua::Vm lua_vm = lua::CreateVm();
45+
lua_State* L = lua_vm.get();
46+
tensor::LuaTensorRegister(L);
47+
LuaMazeGeneration::Register(L);
48+
LuaRandom::Register(L);
49+
LuaTextLevelMaker::Register(L);
50+
LuaSnippetEmitter::Register(L);
51+
52+
ctx->userdata =
53+
new Context(std::move(lua_vm), runfiles_path, &ctx->calls, &ctx->hooks);
54+
return 0;
55+
}
56+
57+
void dmlab_release_context(DeepmindContext* ctx) {
58+
delete static_cast<Context*>(ctx->userdata);
59+
}
60+
61+
} // extern "C"

deepmind/engine/callbacks_test.cc

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// Copyright (C) 2016 Google Inc.
2+
//
3+
// This program is free software; you can redistribute it and/or modify
4+
// it under the terms of the GNU General Public License as published by
5+
// the Free Software Foundation; either version 2 of the License, or
6+
// (at your option) any later version.
7+
//
8+
// This program is distributed in the hope that it will be useful,
9+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
10+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11+
// GNU General Public License for more details.
12+
//
13+
// You should have received a copy of the GNU General Public License along
14+
// with this program; if not, write to the Free Software Foundation, Inc.,
15+
// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
16+
//
17+
////////////////////////////////////////////////////////////////////////////////
18+
19+
#include "gmock/gmock.h"
20+
#include "gtest/gtest.h"
21+
#include "deepmind/include/deepmind_context.h"
22+
#include "deepmind/support/test_srcdir.h"
23+
24+
namespace {
25+
26+
using ::deepmind::lab::TestSrcDir;
27+
using ::testing::ElementsAre;
28+
using ::testing::ElementsAreArray;
29+
30+
TEST(DeepmindCallbackTest, CreateAndDestroyContext) {
31+
DeepmindContext ctx{};
32+
const char arg0[] = "dmlab";
33+
ASSERT_EQ(0, dmlab_create_context(TestSrcDir().c_str(), &ctx));
34+
ASSERT_EQ(0, ctx.hooks.set_script_name(ctx.userdata, "tests/callbacks_test"));
35+
36+
ctx.hooks.add_setting(ctx.userdata, "command", "hello");
37+
ASSERT_EQ(0, ctx.hooks.init(ctx.userdata));
38+
ASSERT_EQ(0, ctx.hooks.start(ctx.userdata, 0, 0));
39+
40+
const char* cmd_line = ctx.hooks.replace_command_line(ctx.userdata, arg0);
41+
EXPECT_THAT(cmd_line, ::testing::HasSubstr("hello"));
42+
EXPECT_STREQ("lt_chasm_1", ctx.hooks.next_map(ctx.userdata));
43+
EXPECT_STREQ("lt_chasm_2", ctx.hooks.next_map(ctx.userdata));
44+
EXPECT_EQ(1, ctx.hooks.run_lua_snippet(
45+
ctx.userdata, "return (...):commandLine('') and 1 or 0"));
46+
dmlab_release_context(&ctx);
47+
}
48+
49+
TEST(DeepmindCallbackTest, CustomObservations) {
50+
DeepmindContext ctx{};
51+
const char callbacks_test[] = "tests/callbacks_test";
52+
const char order[] = "Find Apples!";
53+
ASSERT_EQ(0, dmlab_create_context(TestSrcDir().c_str(), &ctx));
54+
ctx.hooks.add_setting(ctx.userdata, "order", order);
55+
ASSERT_EQ(0, ctx.hooks.set_script_name(ctx.userdata, callbacks_test));
56+
ASSERT_EQ(0, ctx.hooks.init(ctx.userdata));
57+
ASSERT_EQ(3, ctx.hooks.custom_observation_count(ctx.userdata));
58+
EnvCApi_ObservationSpec spec;
59+
60+
EXPECT_STREQ("LOCATION", ctx.hooks.custom_observation_name(ctx.userdata, 0));
61+
ctx.hooks.custom_observation_spec(ctx.userdata, 0, &spec);
62+
EXPECT_EQ(1, spec.dims);
63+
EXPECT_EQ(EnvCApi_ObservationDoubles, spec.type);
64+
EXPECT_EQ(3, spec.shape[0]);
65+
66+
EXPECT_STREQ("ORDER", ctx.hooks.custom_observation_name(ctx.userdata, 1));
67+
ctx.hooks.custom_observation_spec(ctx.userdata, 1, &spec);
68+
EXPECT_EQ(1, spec.dims);
69+
EXPECT_EQ(EnvCApi_ObservationBytes, spec.type);
70+
EXPECT_EQ(0, spec.shape[0]);
71+
72+
EXPECT_STREQ("EPISODE", ctx.hooks.custom_observation_name(ctx.userdata, 2));
73+
ctx.hooks.custom_observation_spec(ctx.userdata, 2, &spec);
74+
EXPECT_EQ(1, spec.dims);
75+
EXPECT_EQ(EnvCApi_ObservationDoubles, spec.type);
76+
EXPECT_EQ(1, spec.shape[0]);
77+
78+
const int episode = 10;
79+
ASSERT_EQ(0, ctx.hooks.start(ctx.userdata, episode, 0));
80+
81+
EnvCApi_Observation obs;
82+
ctx.hooks.custom_observation(ctx.userdata, 0, &obs);
83+
ASSERT_EQ(1, obs.spec.dims);
84+
ASSERT_EQ(EnvCApi_ObservationDoubles, obs.spec.type);
85+
EXPECT_THAT(std::make_tuple(obs.payload.doubles, obs.spec.shape[0]),
86+
ElementsAre(10.0, 20.0, 30.0));
87+
88+
ctx.hooks.custom_observation(ctx.userdata, 1, &obs);
89+
ASSERT_EQ(1, obs.spec.dims);
90+
ASSERT_EQ(EnvCApi_ObservationBytes, obs.spec.type);
91+
EXPECT_THAT(std::make_tuple(obs.payload.bytes, obs.spec.shape[0]),
92+
ElementsAreArray(order, sizeof(order) - 1));
93+
94+
ctx.hooks.custom_observation(ctx.userdata, 2, &obs);
95+
ASSERT_EQ(1, obs.spec.dims);
96+
ASSERT_EQ(EnvCApi_ObservationDoubles, obs.spec.type);
97+
EXPECT_THAT(std::make_tuple(obs.payload.doubles, obs.spec.shape[0]),
98+
ElementsAre(episode));
99+
100+
dmlab_release_context(&ctx);
101+
}
102+
103+
} // namespace

0 commit comments

Comments
 (0)