Skip to content
This repository was archived by the owner on Dec 28, 2023. It is now read-only.

Commit 6460f0a

Browse files
author
Isaac Poulton
authored
Add Bernoulli distribution (#10)
* Add Bernoulli distribution * Refactor Distribution::extended_shape
1 parent e0f6e67 commit 6460f0a

File tree

10 files changed

+235
-48
lines changed

10 files changed

+235
-48
lines changed

example/gym_client.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@ using namespace cpprl;
1515

1616
// Algorithm hyperparameters
1717
const std::string algorithm = "PPO";
18-
const int batch_size = 256;
18+
const int batch_size = 2048;
1919
const float clip_param = 0.2;
2020
const float discount_factor = 0.99;
21-
const float entropy_coef = 0.0;
21+
const float entropy_coef = 0.001;
2222
const float gae = 0.95;
23-
const float learning_rate = 3e-4;
23+
const float learning_rate = 2.5e-4;
2424
const int log_interval = 1;
25-
const int max_frames = 1e+7;
25+
const int max_frames = 10e+7;
2626
const int num_epoch = 10;
27-
const int num_mini_batch = 8;
27+
const int num_mini_batch = 32;
2828
const int reward_average_window_size = 10;
2929
const bool use_gae = true;
3030
const bool use_lr_decay = true;
@@ -33,12 +33,12 @@ const float value_loss_coef = 0.5;
3333
// Environment hyperparameters
3434
const float env_gamma = discount_factor; // Set to -1 to disable
3535
const std::string env_name = "BipedalWalkerHardcore-v2";
36-
const int num_envs = 8;
37-
const float render_reward_threshold = 250;
36+
const int num_envs = 16;
37+
const float render_reward_threshold = 160;
3838

3939
// Model hyperparameters
4040
const int hidden_size = 64;
41-
const bool recurrent = true;
41+
const bool recurrent = false;
4242
const bool use_cuda = false;
4343

4444
std::vector<float> flatten_vector(std::vector<float> const &input)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
#pragma once

#include <c10/util/ArrayRef.h>
#include <torch/torch.h>

#include "cpprl/distributions/distribution.h"

namespace cpprl
{
// Bernoulli distribution parameterized by either per-element probabilities
// or logits (exactly one of the two must be supplied), mirroring PyTorch's
// torch.distributions.Bernoulli.
class Bernoulli : public Distribution
{
  private:
    // probs and logits are kept in sync: whichever one the caller supplies,
    // the other is derived from it in the constructor. param aliases the
    // caller-supplied parameter tensor.
    torch::Tensor probs, logits, param;

  public:
    // Exactly one of probs/logits must be non-null; passing both or neither
    // throws. The pointees are shallow-copied tensors.
    Bernoulli(const torch::Tensor *probs, const torch::Tensor *logits);

    // Per-element entropy; same shape as the parameter tensor.
    torch::Tensor entropy();
    // Element-wise log-probability of value (broadcast against logits).
    torch::Tensor log_prob(torch::Tensor value);
    // Draws 0/1 samples with shape sample_shape + batch_shape.
    torch::Tensor sample(c10::ArrayRef<int64_t> sample_shape = {});

    // const-qualified: accessors only read the stored tensors.
    inline torch::Tensor get_logits() const { return logits; }
    inline torch::Tensor get_probs() const { return probs; }
};
}

include/cpprl/distributions/categorical.h

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,9 @@ namespace cpprl
1010
class Categorical : public Distribution
1111
{
1212
private:
13-
torch::Tensor probs;
14-
torch::Tensor logits;
15-
std::vector<int64_t> batch_shape;
16-
std::vector<int64_t> event_shape;
17-
torch::Tensor param;
13+
torch::Tensor probs, logits, param;
1814
int num_events;
1915

20-
std::vector<int64_t> extended_shape(c10::ArrayRef<int64_t> sample_shape);
21-
2216
public:
2317
Categorical(const torch::Tensor *probs, const torch::Tensor *logits);
2418

include/cpprl/distributions/distribution.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ namespace cpprl
66
{
77
class Distribution
88
{
9+
protected:
10+
std::vector<int64_t> batch_shape, event_shape;
11+
12+
std::vector<int64_t> extended_shape(c10::ArrayRef<int64_t> sample_shape);
13+
914
public:
1015
virtual ~Distribution() = 0;
1116

include/cpprl/distributions/normal.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@ class Normal : public Distribution
1111
{
1212
private:
1313
torch::Tensor loc, scale;
14-
std::vector<int64_t> batch_shape, event_shape;
15-
16-
std::vector<int64_t> extended_shape(c10::ArrayRef<int64_t> sample_shape);
1714

1815
public:
1916
Normal(const torch::Tensor loc, const torch::Tensor scale);

src/distributions/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
# Register the distribution sources with the main library target.
target_sources(cpprl
    PRIVATE
        ${CMAKE_CURRENT_LIST_DIR}/bernoulli.cpp
        ${CMAKE_CURRENT_LIST_DIR}/categorical.cpp
        ${CMAKE_CURRENT_LIST_DIR}/distribution.cpp
        ${CMAKE_CURRENT_LIST_DIR}/normal.cpp
)

# The test binary compiles the same sources: the doctest cases live inline
# in the .cpp files above.
if (CPPRL_BUILD_TESTS)
    target_sources(cpprl_tests
        PRIVATE
            ${CMAKE_CURRENT_LIST_DIR}/bernoulli.cpp
            ${CMAKE_CURRENT_LIST_DIR}/categorical.cpp
            ${CMAKE_CURRENT_LIST_DIR}/distribution.cpp
            ${CMAKE_CURRENT_LIST_DIR}/normal.cpp
    )
endif (CPPRL_BUILD_TESTS)

src/distributions/bernoulli.cpp

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
#include <stdexcept>

#include <ATen/core/Reduction.h>
#include <c10/util/ArrayRef.h>
#include <spdlog/spdlog.h>
#include <torch/torch.h>

#include "cpprl/distributions/bernoulli.h"
#include "third_party/doctest.h"

namespace cpprl
{
// Builds the distribution from exactly one of probs/logits. The other
// parameterization is derived so both accessors are always valid.
// Throws std::invalid_argument (derives from std::exception, so existing
// catch sites are unaffected) on misuse instead of a messageless
// std::exception.
Bernoulli::Bernoulli(const torch::Tensor *probs,
                     const torch::Tensor *logits)
{
    if ((probs == nullptr) == (logits == nullptr))
    {
        spdlog::error("Either probs or logits is required, but not both");
        throw std::invalid_argument("Either probs or logits is required, but not both");
    }

    if (probs != nullptr)
    {
        if (probs->dim() < 1)
        {
            spdlog::error("Probs tensor must have at least one dimension");
            throw std::invalid_argument("Probs tensor must have at least one dimension");
        }
        this->probs = *probs;
        // 1.21e-7 is used as the epsilon to match PyTorch's Python results as
        // closely as possible
        auto clamped_probs = this->probs.clamp(1.21e-7, 1. - 1.21e-7);
        // logit(p) = log(p) - log(1 - p); log1p keeps precision near p == 0.
        this->logits = torch::log(clamped_probs) - torch::log1p(-clamped_probs);
    }
    else
    {
        if (logits->dim() < 1)
        {
            spdlog::error("Logits tensor must have at least one dimension");
            throw std::invalid_argument("Logits tensor must have at least one dimension");
        }
        this->logits = *logits;
        this->probs = torch::sigmoid(*logits);
    }

    // param aliases whichever tensor the caller supplied; its sizes define
    // the batch shape used by Distribution::extended_shape.
    param = probs != nullptr ? *probs : *logits;
    batch_shape = param.sizes().vec();
}

torch::Tensor Bernoulli::entropy()
{
    // BCE-with-logits of probs against their own logits is exactly the
    // per-element Bernoulli entropy, returned unreduced (same shape as probs).
    return torch::binary_cross_entropy_with_logits(
        logits, probs, torch::Tensor(), torch::Tensor(), Reduction::None);
}

torch::Tensor Bernoulli::log_prob(torch::Tensor value)
{
    // Broadcast value against logits so callers may pass a smaller or
    // sample-shaped tensor; negate BCE to get the log-probability.
    auto broadcasted_tensors = torch::broadcast_tensors({logits, value});
    return -torch::binary_cross_entropy_with_logits(
        broadcasted_tensors[0], broadcasted_tensors[1],
        torch::Tensor(), torch::Tensor(), Reduction::None);
}

torch::Tensor Bernoulli::sample(c10::ArrayRef<int64_t> sample_shape)
{
    // Sampling must not be recorded on the autograd tape.
    auto ext_sample_shape = extended_shape(sample_shape);
    torch::NoGradGuard no_grad_guard;
    return torch::bernoulli(probs.expand(ext_sample_shape));
}

TEST_CASE("Bernoulli")
{
    SUBCASE("Throws when provided both probs and logits")
    {
        auto tensor = torch::Tensor();
        CHECK_THROWS(Bernoulli(&tensor, &tensor));
    }

    SUBCASE("Sampled numbers are in the right range")
    {
        float probabilities[] = {0.2, 0.2, 0.2, 0.2, 0.2};
        auto probabilities_tensor = torch::from_blob(probabilities, {5});
        auto dist = Bernoulli(&probabilities_tensor, nullptr);

        auto output = dist.sample({100});
        auto more_than_1 = output > 1;
        auto less_than_0 = output < 0;
        CHECK(!more_than_1.any().item().toInt());
        CHECK(!less_than_0.any().item().toInt());
    }

    SUBCASE("Sampled tensors are of the right shape")
    {
        float probabilities[] = {0.2, 0.2, 0.2, 0.2, 0.2};
        auto probabilities_tensor = torch::from_blob(probabilities, {5});
        auto dist = Bernoulli(&probabilities_tensor, nullptr);

        CHECK(dist.sample({20}).sizes().vec() == std::vector<int64_t>{20, 5});
        CHECK(dist.sample({2, 20}).sizes().vec() == std::vector<int64_t>{2, 20, 5});
        CHECK(dist.sample({1, 2, 3, 4}).sizes().vec() == std::vector<int64_t>{1, 2, 3, 4, 5});
    }

    SUBCASE("Multi-dimensional input probabilities are handled correctly")
    {
        SUBCASE("Sampled tensors are of the right shape")
        {
            float probabilities[2][4] = {{0.5, 0.5, 0.0, 0.0},
                                         {0.25, 0.25, 0.25, 0.25}};
            auto probabilities_tensor = torch::from_blob(probabilities, {2, 4});
            auto dist = Bernoulli(&probabilities_tensor, nullptr);

            CHECK(dist.sample({20}).sizes().vec() == std::vector<int64_t>{20, 2, 4});
            CHECK(dist.sample({10, 5}).sizes().vec() == std::vector<int64_t>{10, 5, 2, 4});
        }
    }

    SUBCASE("entropy()")
    {
        float probabilities[2][2] = {{0.5, 0.0},
                                     {0.25, 0.25}};
        auto probabilities_tensor = torch::from_blob(probabilities, {2, 2});
        auto dist = Bernoulli(&probabilities_tensor, nullptr);

        auto entropies = dist.entropy();

        SUBCASE("Returns correct values")
        {
            // Reference values from torch.distributions.Bernoulli.entropy().
            CHECK(entropies[0][0].item().toDouble() ==
                  doctest::Approx(0.6931).epsilon(1e-3));
            CHECK(entropies[0][1].item().toDouble() ==
                  doctest::Approx(0.0000).epsilon(1e-3));
            CHECK(entropies[1][0].item().toDouble() ==
                  doctest::Approx(0.5623).epsilon(1e-3));
            CHECK(entropies[1][1].item().toDouble() ==
                  doctest::Approx(0.5623).epsilon(1e-3));
        }

        SUBCASE("Output tensor is the correct size")
        {
            CHECK(entropies.sizes().vec() == std::vector<int64_t>{2, 2});
        }
    }

    SUBCASE("log_prob()")
    {
        float probabilities[2][2] = {{0.5, 0.0},
                                     {0.25, 0.25}};
        auto probabilities_tensor = torch::from_blob(probabilities, {2, 2});
        auto dist = Bernoulli(&probabilities_tensor, nullptr);

        float actions[2][2] = {{1, 0},
                               {1, 0}};
        auto actions_tensor = torch::from_blob(actions, {2, 2});
        auto log_probs = dist.log_prob(actions_tensor);

        INFO(log_probs << "\n");
        SUBCASE("Returns correct values")
        {
            // Reference values from torch.distributions.Bernoulli.log_prob().
            CHECK(log_probs[0][0].item().toDouble() ==
                  doctest::Approx(-0.6931).epsilon(1e-3));
            CHECK(log_probs[0][1].item().toDouble() ==
                  doctest::Approx(0.0000).epsilon(1e-3));
            CHECK(log_probs[1][0].item().toDouble() ==
                  doctest::Approx(-1.3863).epsilon(1e-3));
            CHECK(log_probs[1][1].item().toDouble() ==
                  doctest::Approx(-0.2876).epsilon(1e-3));
        }

        SUBCASE("Output tensor is correct size")
        {
            CHECK(log_probs.sizes().vec() == std::vector<int64_t>{2, 2});
        }
    }
}
}

src/distributions/categorical.cpp

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <c10/util/ArrayRef.h>
2+
#include <spdlog/spdlog.h>
23
#include <torch/torch.h>
34

45
#include "cpprl/distributions/categorical.h"
@@ -11,6 +12,7 @@ Categorical::Categorical(const torch::Tensor *probs,
1112
{
1213
if ((probs == nullptr) == (logits == nullptr))
1314
{
15+
spdlog::error("Either probs or logits is required, but not both");
1416
throw std::exception();
1517
}
1618

@@ -51,21 +53,6 @@ torch::Tensor Categorical::entropy()
5153
return -p_log_p.sum(-1);
5254
}
5355

54-
std::vector<int64_t> Categorical::extended_shape(c10::ArrayRef<int64_t> sample_shape)
55-
{
56-
std::vector<int64_t> output_shape;
57-
output_shape.insert(output_shape.end(),
58-
sample_shape.begin(),
59-
sample_shape.end());
60-
output_shape.insert(output_shape.end(),
61-
batch_shape.begin(),
62-
batch_shape.end());
63-
output_shape.insert(output_shape.end(),
64-
event_shape.begin(),
65-
event_shape.end());
66-
return output_shape;
67-
}
68-
6956
torch::Tensor Categorical::log_prob(torch::Tensor value)
7057
{
7158
value = value.to(torch::kLong).unsqueeze(-1);

src/distributions/distribution.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
#include <cstdint>
#include <vector>
#include <ctype.h>

#include "cpprl/distributions/distribution.h"

namespace cpprl
{
// Returns the concatenation sample_shape + batch_shape + event_shape —
// the full size of a tensor produced by sample(sample_shape). Shared by
// every concrete distribution (Bernoulli, Categorical, Normal).
std::vector<int64_t> Distribution::extended_shape(c10::ArrayRef<int64_t> sample_shape)
{
    std::vector<int64_t> output_shape;
    // Reserve once so the three inserts below never reallocate.
    output_shape.reserve(sample_shape.size() +
                         batch_shape.size() +
                         event_shape.size());
    output_shape.insert(output_shape.end(),
                        sample_shape.begin(),
                        sample_shape.end());
    output_shape.insert(output_shape.end(),
                        batch_shape.begin(),
                        batch_shape.end());
    output_shape.insert(output_shape.end(),
                        event_shape.begin(),
                        event_shape.end());
    return output_shape;
}
}

src/distributions/normal.cpp

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,6 @@ torch::Tensor Normal::entropy()
2626
return (0.5 + 0.5 * std::log(2 * M_PI) + torch::log(scale)).sum(-1);
2727
}
2828

29-
std::vector<int64_t> Normal::extended_shape(c10::ArrayRef<int64_t> sample_shape)
30-
{
31-
std::vector<int64_t> output_shape;
32-
output_shape.insert(output_shape.end(),
33-
sample_shape.begin(),
34-
sample_shape.end());
35-
output_shape.insert(output_shape.end(),
36-
batch_shape.begin(),
37-
batch_shape.end());
38-
output_shape.insert(output_shape.end(),
39-
event_shape.begin(),
40-
event_shape.end());
41-
return output_shape;
42-
}
43-
4429
torch::Tensor Normal::log_prob(torch::Tensor value)
4530
{
4631
auto variance = scale.pow(2);

0 commit comments

Comments
 (0)