forked from google-deepmind/lab
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcallbacks_test.cc
More file actions
103 lines (87 loc) · 4.03 KB
/
callbacks_test.cc
File metadata and controls
103 lines (87 loc) · 4.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
// Copyright (C) 2016 Google Inc.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License along
// with this program; if not, write to the Free Software Foundation, Inc.,
// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
//
////////////////////////////////////////////////////////////////////////////////
#include <tuple>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "deepmind/include/deepmind_context.h"
#include "deepmind/support/test_srcdir.h"
namespace {
using ::deepmind::lab::TestSrcDir;
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
TEST(DeepmindCallbackTest, CreateAndDestroyContext) {
DeepmindContext ctx{};
const char arg0[] = "dmlab";
ASSERT_EQ(0, dmlab_create_context(TestSrcDir().c_str(), &ctx));
ASSERT_EQ(0, ctx.hooks.set_script_name(ctx.userdata, "tests/callbacks_test"));
ctx.hooks.add_setting(ctx.userdata, "command", "hello");
ASSERT_EQ(0, ctx.hooks.init(ctx.userdata));
ASSERT_EQ(0, ctx.hooks.start(ctx.userdata, 0, 0));
const char* cmd_line = ctx.hooks.replace_command_line(ctx.userdata, arg0);
EXPECT_THAT(cmd_line, ::testing::HasSubstr("hello"));
EXPECT_STREQ("lt_chasm_1", ctx.hooks.next_map(ctx.userdata));
EXPECT_STREQ("lt_chasm_2", ctx.hooks.next_map(ctx.userdata));
EXPECT_EQ(1, ctx.hooks.run_lua_snippet(
ctx.userdata, "return (...):commandLine('') and 1 or 0"));
dmlab_release_context(&ctx);
}
TEST(DeepmindCallbackTest, CustomObservations) {
DeepmindContext ctx{};
const char callbacks_test[] = "tests/callbacks_test";
const char order[] = "Find Apples!";
ASSERT_EQ(0, dmlab_create_context(TestSrcDir().c_str(), &ctx));
ctx.hooks.add_setting(ctx.userdata, "order", order);
ASSERT_EQ(0, ctx.hooks.set_script_name(ctx.userdata, callbacks_test));
ASSERT_EQ(0, ctx.hooks.init(ctx.userdata));
ASSERT_EQ(3, ctx.hooks.custom_observation_count(ctx.userdata));
EnvCApi_ObservationSpec spec;
EXPECT_STREQ("LOCATION", ctx.hooks.custom_observation_name(ctx.userdata, 0));
ctx.hooks.custom_observation_spec(ctx.userdata, 0, &spec);
EXPECT_EQ(1, spec.dims);
EXPECT_EQ(EnvCApi_ObservationDoubles, spec.type);
EXPECT_EQ(3, spec.shape[0]);
EXPECT_STREQ("ORDER", ctx.hooks.custom_observation_name(ctx.userdata, 1));
ctx.hooks.custom_observation_spec(ctx.userdata, 1, &spec);
EXPECT_EQ(1, spec.dims);
EXPECT_EQ(EnvCApi_ObservationBytes, spec.type);
EXPECT_EQ(0, spec.shape[0]);
EXPECT_STREQ("EPISODE", ctx.hooks.custom_observation_name(ctx.userdata, 2));
ctx.hooks.custom_observation_spec(ctx.userdata, 2, &spec);
EXPECT_EQ(1, spec.dims);
EXPECT_EQ(EnvCApi_ObservationDoubles, spec.type);
EXPECT_EQ(1, spec.shape[0]);
const int episode = 10;
ASSERT_EQ(0, ctx.hooks.start(ctx.userdata, episode, 0));
EnvCApi_Observation obs;
ctx.hooks.custom_observation(ctx.userdata, 0, &obs);
ASSERT_EQ(1, obs.spec.dims);
ASSERT_EQ(EnvCApi_ObservationDoubles, obs.spec.type);
EXPECT_THAT(std::make_tuple(obs.payload.doubles, obs.spec.shape[0]),
ElementsAre(10.0, 20.0, 30.0));
ctx.hooks.custom_observation(ctx.userdata, 1, &obs);
ASSERT_EQ(1, obs.spec.dims);
ASSERT_EQ(EnvCApi_ObservationBytes, obs.spec.type);
EXPECT_THAT(std::make_tuple(obs.payload.bytes, obs.spec.shape[0]),
ElementsAreArray(order, sizeof(order) - 1));
ctx.hooks.custom_observation(ctx.userdata, 2, &obs);
ASSERT_EQ(1, obs.spec.dims);
ASSERT_EQ(EnvCApi_ObservationDoubles, obs.spec.type);
EXPECT_THAT(std::make_tuple(obs.payload.doubles, obs.spec.shape[0]),
ElementsAre(episode));
dmlab_release_context(&ctx);
}
} // namespace