Skip to content

Commit

Permalink
[C++ API] Remove virtual forward and implement Sequential based on An…
Browse files Browse the repository at this point in the history
…y(Module) (pytorch#7508)

* Remove virtual forward

* Rebase
  • Loading branch information
goldsborough authored May 24, 2018
1 parent 1078491 commit b121640
Show file tree
Hide file tree
Showing 29 changed files with 1,489 additions and 231 deletions.
1 change: 1 addition & 0 deletions .clang-tidy
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Checks: '
,-cert-err58-cpp
,-modernize-make-unique
,-cppcoreguidelines-owning-memory
,-readability-named-parameter
'
WarningsAsErrors: ''
HeaderFilterRegex: 'torch/csrc/'
Expand Down
319 changes: 319 additions & 0 deletions test/cpp/api/any.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,319 @@
#include <catch.hpp>

#include <torch/torch.h>

#include <torch/nn/modules/any.h>

#include <algorithm>
#include <string>

using namespace torch;
using namespace torch::nn;
using namespace torch::detail;

using Catch::Contains;
using Catch::StartsWith;

TEST_CASE("any-module") {
SECTION("int()") {
struct M : nn::Module {
int forward() {
return 123;
}
};
AnyModule any(M{});
REQUIRE(any.forward().get<int>() == 123);
}
SECTION("int(int)") {
struct M : nn::Module {
int forward(int x) {
return x;
}
};
AnyModule any(M{});
REQUIRE(any.forward(5).get<int>() == 5);
}
SECTION("const char*(const char*)") {
struct M : nn::Module {
const char* forward(const char* x) {
return x;
}
};
AnyModule any(M{});
REQUIRE(any.forward("hello").get<const char*>() == std::string("hello"));
}

SECTION("string(int, const double)") {
struct M : nn::Module {
std::string forward(int x, const double f) {
return std::to_string(static_cast<int>(x + f));
}
};
AnyModule any(M{});
int x = 4;
REQUIRE(any.forward(x, 3.14).get<std::string>() == std::string("7"));
}

SECTION("Variable(string, const string&, string&&)") {
struct M : nn::Module {
autograd::Variable forward(
std::string a,
const std::string& b,
std::string&& c) {
const auto s = a + b + c;
return autograd::make_variable(
at::ones(at::CPU(at::kFloat), {static_cast<int64_t>(s.size())}));
}
};
AnyModule any(M{});
REQUIRE(
any.forward(std::string("a"), std::string("ab"), std::string("abc"))
.get<autograd::Variable>()
.sum()
.toCInt() == 6);
}
SECTION("wrong argument type") {
struct M : nn::Module {
int forward(float x) {
return x;
}
};
AnyModule any(M{});
REQUIRE_THROWS_WITH(
any.forward(5.0),
StartsWith("Expected argument #0 to be of type float, "
"but received value of type double"));
}
SECTION("wrong number of arguments") {
struct M : nn::Module {
int forward(int a, int b) {
return a + b;
}
};
AnyModule any(M{});
REQUIRE_THROWS_WITH(
any.forward(),
Contains("M's forward() method expects 2 arguments, but received 0"));
REQUIRE_THROWS_WITH(
any.forward(5),
Contains("M's forward() method expects 2 arguments, but received 1"));
REQUIRE_THROWS_WITH(
any.forward(1, 2, 3),
Contains("M's forward() method expects 2 arguments, but received 3"));
}
SECTION("get()") {
struct M : nn::Module {
explicit M(int value_) : nn::Module("M"), value(value_) {}
int value;
int forward(float x) {
return x;
}
};
AnyModule any(M{5});

SECTION("good cast") {
REQUIRE(any.get<M>().value == 5);
}

SECTION("bad cast") {
struct N : nn::Module {};
REQUIRE_THROWS_WITH(any.get<N>(), StartsWith("Attempted to cast module"));
}
}
SECTION("ptr()") {
struct M : nn::Module {
explicit M(int value_) : nn::Module("M"), value(value_) {}
int value;
int forward(float x) {
return x;
}
};
AnyModule any(M{5});

SECTION("base class cast") {
auto ptr = any.ptr();
REQUIRE(ptr != nullptr);
REQUIRE(ptr->name() == "M");
}

SECTION("good downcast") {
auto ptr = any.ptr<M>();
REQUIRE(ptr != nullptr);
REQUIRE(ptr->value == 5);
}

SECTION("bad downcast") {
struct N : nn::Module {};
REQUIRE_THROWS_WITH(any.ptr<N>(), StartsWith("Attempted to cast module"));
}
}
SECTION("default state is empty") {
struct M : nn::Module {
explicit M(int value_) : value(value_) {}
int value;
int forward(float x) {
return x;
}
};
AnyModule any;
REQUIRE(any.is_empty());
any = std::make_shared<M>(5);
REQUIRE(!any.is_empty());
REQUIRE(any.get<M>().value == 5);
}
SECTION("all methods throw for empty AnyModule") {
struct M : nn::Module {
int forward(int x) {
return x;
}
};
AnyModule any;
REQUIRE(any.is_empty());
REQUIRE_THROWS_WITH(
any.get<M>(), StartsWith("Cannot call get() on an empty AnyModule"));
REQUIRE_THROWS_WITH(
any.ptr<M>(), StartsWith("Cannot call ptr() on an empty AnyModule"));
REQUIRE_THROWS_WITH(
any.ptr(), StartsWith("Cannot call ptr() on an empty AnyModule"));
REQUIRE_THROWS_WITH(
any.type_info(),
StartsWith("Cannot call type_info() on an empty AnyModule"));
REQUIRE_THROWS_WITH(
any.forward<int>(5),
StartsWith("Cannot call forward() on an empty AnyModule"));
}
SECTION("can move assign differentm modules") {
struct M : nn::Module {
std::string forward(int x) {
return std::to_string(x);
}
};
struct N : nn::Module {
int forward(float x) {
return 3 + x;
}
};
AnyModule any;
REQUIRE(any.is_empty());
any = std::make_shared<M>();
REQUIRE(!any.is_empty());
REQUIRE(any.forward(5).get<std::string>() == "5");
any = std::make_shared<N>();
REQUIRE(!any.is_empty());
REQUIRE(any.forward(5.0f).get<int>() == 8);
}
SECTION("has reference semantics") {
Sequential first(
Linear(2, 3).build(), Linear(4, 4).build(), Linear(4, 5).build());
Sequential second(first);

REQUIRE(first.size() == second.size());
REQUIRE(std::equal(first.begin(), first.end(), second.begin()));
}
}

namespace torch {
namespace nn {
struct TestValue {
template <typename T>
explicit TestValue(T&& value) : value_(std::forward<T>(value)) {}
AnyModule::Value operator()() {
return std::move(value_);
}
AnyModule::Value value_;
};
template <typename T>
AnyModule::Value make_value(T&& value) {
return TestValue(std::forward<T>(value))();
}
} // namespace nn
} // namespace torch

TEST_CASE("any-value") {
SECTION("gets the correct value for the right type") {
SECTION("int") {
auto value = make_value(5);
// const and non-const types have the same typeid()
REQUIRE(value.try_get<int>() != nullptr);
REQUIRE(value.try_get<const int>() != nullptr);
REQUIRE(value.get<int>() == 5);
}
SECTION("const int") {
auto value = make_value(5);
REQUIRE(value.try_get<const int>() != nullptr);
REQUIRE(value.try_get<int>() != nullptr);
REQUIRE(value.get<const int>() == 5);
}
SECTION("const char*") {
auto value = make_value("hello");
REQUIRE(value.try_get<const char*>() != nullptr);
REQUIRE(value.get<const char*>() == std::string("hello"));
}
SECTION("std::string") {
auto value = make_value(std::string("hello"));
REQUIRE(value.try_get<std::string>() != nullptr);
REQUIRE(value.get<std::string>() == "hello");
}
SECTION("pointers") {
std::string s("hello");
std::string* p = &s;
auto value = make_value(p);
REQUIRE(value.try_get<std::string*>() != nullptr);
REQUIRE(*value.get<std::string*>() == "hello");
}
SECTION("references") {
std::string s("hello");
const std::string& t = s;
auto value = make_value(t);
REQUIRE(value.try_get<std::string>() != nullptr);
REQUIRE(value.get<std::string>() == "hello");
}
}
SECTION("try_get returns nullptr for the wrong type") {
auto value = make_value(5);
REQUIRE(value.try_get<int>() != nullptr);
REQUIRE(value.try_get<float>() == nullptr);
REQUIRE(value.try_get<long>() == nullptr);
REQUIRE(value.try_get<std::string>() == nullptr);
}
SECTION("get throws for the wrong type") {
auto value = make_value(5);
REQUIRE(value.try_get<int>() != nullptr);
REQUIRE_THROWS_WITH(
value.get<float>(),
StartsWith("Attempted to cast Value to float, "
"but its actual type is int"));
REQUIRE_THROWS_WITH(
value.get<long>(),
StartsWith("Attempted to cast Value to long, "
"but its actual type is int"));
}
SECTION("move is allowed") {
auto value = make_value(5);
SECTION("construction") {
auto copy = make_value(std::move(value));
REQUIRE(copy.try_get<int>() != nullptr);
REQUIRE(copy.get<int>() == 5);
}
SECTION("assignment") {
auto copy = make_value(10);
copy = std::move(value);
REQUIRE(copy.try_get<int>() != nullptr);
REQUIRE(copy.get<int>() == 5);
}
}
SECTION("type_info is correct") {
SECTION("int") {
auto value = make_value(5);
REQUIRE(value.type_info().hash_code() == typeid(int).hash_code());
}
SECTION("const char") {
auto value = make_value("hello");
REQUIRE(value.type_info().hash_code() == typeid(const char*).hash_code());
}
SECTION("std::string") {
auto value = make_value(std::string("hello"));
REQUIRE(value.type_info().hash_code() == typeid(std::string).hash_code());
}
}
}
Loading

0 comments on commit b121640

Please sign in to comment.