Skip to content
This repository was archived by the owner on Dec 28, 2023. It is now read-only.

Commit 9eda405

Browse files
author
Isaac Poulton
committed
Add Bernoulli output layer
1 parent 485e5b9 commit 9eda405

File tree

4 files changed

+52
-6
lines changed

4 files changed

+52
-6
lines changed

CMakeLists.txt

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@ add_custom_target(
3737

3838
# Dependencies
3939
## PyTorch
40-
find_package(Torch REQUIRED)
41-
if (TORCH_CXX_FLAGS)
42-
set(CMAKE_CXX_FLAGS ${TORCH_CXX_FLAGS})
43-
endif()
40+
if (NOT TORCH_FOUND)
41+
find_package(Torch REQUIRED)
42+
if (TORCH_CXX_FLAGS)
43+
set(CMAKE_CXX_FLAGS ${TORCH_CXX_FLAGS})
44+
endif()
45+
endif (NOT TORCH_FOUND)
4446

4547
# Define targets
4648
add_library(cpprl STATIC "")

include/cpprl/model/output_layers.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,17 @@ class OutputLayer : public nn::Module
2020

2121
inline OutputLayer::~OutputLayer() {}
2222

23+
class BernoulliOutput : public OutputLayer
24+
{
25+
private:
26+
nn::Linear linear;
27+
28+
public:
29+
BernoulliOutput(unsigned int num_inputs, unsigned int num_outputs);
30+
31+
std::unique_ptr<Distribution> forward(torch::Tensor x);
32+
};
33+
2334
class CategoricalOutput : public OutputLayer
2435
{
2536
private:

src/model/output_layers.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "cpprl/model/output_layers.h"
66
#include "cpprl/model/model_utils.h"
77
#include "cpprl/distributions/distribution.h"
8+
#include "cpprl/distributions/bernoulli.h"
89
#include "cpprl/distributions/categorical.h"
910
#include "cpprl/distributions/normal.h"
1011
#include "third_party/doctest.h"
@@ -13,6 +14,20 @@ using namespace torch;
1314

1415
namespace cpprl
1516
{
17+
BernoulliOutput::BernoulliOutput(unsigned int num_inputs,
18+
unsigned int num_outputs)
19+
: linear(num_inputs, num_outputs)
20+
{
21+
register_module("linear", linear);
22+
init_weights(linear->named_parameters(), 0.01, 0);
23+
}
24+
25+
std::unique_ptr<Distribution> BernoulliOutput::forward(torch::Tensor x)
26+
{
27+
x = linear(x);
28+
return std::make_unique<Bernoulli>(nullptr, &x);
29+
}
30+
1631
CategoricalOutput::CategoricalOutput(unsigned int num_inputs,
1732
unsigned int num_outputs)
1833
: linear(num_inputs, num_outputs)
@@ -43,6 +58,24 @@ std::unique_ptr<Distribution> NormalOutput::forward(torch::Tensor x)
4358
return std::make_unique<Normal>(loc, scale);
4459
}
4560

61+
TEST_CASE("BernoulliOutput")
62+
{
63+
auto output_layer = BernoulliOutput(3, 5);
64+
65+
SUBCASE("Output distribution has correct output shape")
66+
{
67+
float input_array[2][3] = {{0, 1, 2}, {3, 4, 5}};
68+
auto input_tensor = torch::from_blob(input_array,
69+
{2, 3},
70+
TensorOptions(torch::kFloat));
71+
auto dist = output_layer.forward(input_tensor);
72+
73+
auto output = dist->sample();
74+
75+
CHECK(output.sizes().vec() == std::vector<int64_t>{2, 5});
76+
}
77+
}
78+
4679
TEST_CASE("CategoricalOutput")
4780
{
4881
auto output_layer = CategoricalOutput(3, 5);

src/model/policy.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ PolicyImpl::PolicyImpl(ActionSpace action_space, std::shared_ptr<NNBase> base)
2929
}
3030
else if (action_space.type == "MultiBinary")
3131
{
32-
// num_outputs = action_space.shape[0];
33-
// self.dist = Bernoulli(self.base.output_size, num_outputs)
32+
output_layer = std::make_shared<BernoulliOutput>(
33+
base->get_output_size(), num_outputs);
3434
}
3535
else
3636
{

0 commit comments

Comments
 (0)