1+ #include < ATen/core/Reduction.h>
2+ #include < c10/util/ArrayRef.h>
3+ #include < spdlog/spdlog.h>
4+ #include < torch/torch.h>
5+
6+ #include " cpprl/distributions/bernoulli.h"
7+ #include " third_party/doctest.h"
8+
9+ namespace cpprl
10+ {
11+ Bernoulli::Bernoulli (const torch::Tensor *probs,
12+ const torch::Tensor *logits)
13+ {
14+ if ((probs == nullptr ) == (logits == nullptr ))
15+ {
16+ spdlog::error (" Either probs or logits is required, but not both" );
17+ throw std::exception ();
18+ }
19+
20+ if (probs != nullptr )
21+ {
22+ if (probs->dim () < 1 )
23+ {
24+ throw std::exception ();
25+ }
26+ this ->probs = *probs;
27+ // 1.21e-7 is used as the epsilon to match PyTorch's Python results as closely
28+ // as possible
29+ auto clamped_probs = this ->probs .clamp (1.21e-7 , 1 . - 1.21e-7 );
30+ this ->logits = torch::log (clamped_probs) - torch::log1p (-clamped_probs);
31+ }
32+ else
33+ {
34+ if (logits->dim () < 1 )
35+ {
36+ throw std::exception ();
37+ }
38+ this ->logits = *logits;
39+ this ->probs = torch::sigmoid (*logits);
40+ }
41+
42+ param = probs != nullptr ? *probs : *logits;
43+ batch_shape = param.sizes ().vec ();
44+ }
45+
46+ torch::Tensor Bernoulli::entropy ()
47+ {
48+ return torch::binary_cross_entropy_with_logits (logits, probs, torch::Tensor (), torch::Tensor (), Reduction::None);
49+ }
50+
51+ torch::Tensor Bernoulli::log_prob (torch::Tensor value)
52+ {
53+ auto broadcasted_tensors = torch::broadcast_tensors ({logits, value});
54+ return -torch::binary_cross_entropy_with_logits (broadcasted_tensors[0 ], broadcasted_tensors[1 ], torch::Tensor (), torch::Tensor (), Reduction::None);
55+ }
56+
57+ torch::Tensor Bernoulli::sample (c10::ArrayRef<int64_t > sample_shape)
58+ {
59+ auto ext_sample_shape = extended_shape (sample_shape);
60+ torch::NoGradGuard no_grad_guard;
61+ return torch::bernoulli (probs.expand (ext_sample_shape));
62+ }
63+
64+ TEST_CASE (" Bernoulli" )
65+ {
66+ SUBCASE (" Throws when provided both probs and logits" )
67+ {
68+ auto tensor = torch::Tensor ();
69+ CHECK_THROWS (Bernoulli (&tensor, &tensor));
70+ }
71+
72+ SUBCASE (" Sampled numbers are in the right range" )
73+ {
74+ float probabilities[] = {0.2 , 0.2 , 0.2 , 0.2 , 0.2 };
75+ auto probabilities_tensor = torch::from_blob (probabilities, {5 });
76+ auto dist = Bernoulli (&probabilities_tensor, nullptr );
77+
78+ auto output = dist.sample ({100 });
79+ auto more_than_1 = output > 1 ;
80+ auto less_than_0 = output < 0 ;
81+ CHECK (!more_than_1.any ().item ().toInt ());
82+ CHECK (!less_than_0.any ().item ().toInt ());
83+ }
84+
85+ SUBCASE (" Sampled tensors are of the right shape" )
86+ {
87+ float probabilities[] = {0.2 , 0.2 , 0.2 , 0.2 , 0.2 };
88+ auto probabilities_tensor = torch::from_blob (probabilities, {5 });
89+ auto dist = Bernoulli (&probabilities_tensor, nullptr );
90+
91+ CHECK (dist.sample ({20 }).sizes ().vec () == std::vector<int64_t >{20 , 5 });
92+ CHECK (dist.sample ({2 , 20 }).sizes ().vec () == std::vector<int64_t >{2 , 20 , 5 });
93+ CHECK (dist.sample ({1 , 2 , 3 , 4 }).sizes ().vec () == std::vector<int64_t >{1 , 2 , 3 , 4 , 5 });
94+ }
95+
96+ SUBCASE (" Multi-dimensional input probabilities are handled correctly" )
97+ {
98+ SUBCASE (" Sampled tensors are of the right shape" )
99+ {
100+ float probabilities[2 ][4 ] = {{0.5 , 0.5 , 0.0 , 0.0 },
101+ {0.25 , 0.25 , 0.25 , 0.25 }};
102+ auto probabilities_tensor = torch::from_blob (probabilities, {2 , 4 });
103+ auto dist = Bernoulli (&probabilities_tensor, nullptr );
104+
105+ CHECK (dist.sample ({20 }).sizes ().vec () == std::vector<int64_t >{20 , 2 , 4 });
106+ CHECK (dist.sample ({10 , 5 }).sizes ().vec () == std::vector<int64_t >{10 , 5 , 2 , 4 });
107+ }
108+ }
109+
110+ SUBCASE (" entropy()" )
111+ {
112+ float probabilities[2 ][2 ] = {{0.5 , 0.0 },
113+ {0.25 , 0.25 }};
114+ auto probabilities_tensor = torch::from_blob (probabilities, {2 , 2 });
115+ auto dist = Bernoulli (&probabilities_tensor, nullptr );
116+
117+ auto entropies = dist.entropy ();
118+
119+ SUBCASE (" Returns correct values" )
120+ {
121+ CHECK (entropies[0 ][0 ].item ().toDouble () ==
122+ doctest::Approx (0.6931 ).epsilon (1e-3 ));
123+ CHECK (entropies[0 ][1 ].item ().toDouble () ==
124+ doctest::Approx (0.0000 ).epsilon (1e-3 ));
125+ CHECK (entropies[1 ][0 ].item ().toDouble () ==
126+ doctest::Approx (0.5623 ).epsilon (1e-3 ));
127+ CHECK (entropies[1 ][1 ].item ().toDouble () ==
128+ doctest::Approx (0.5623 ).epsilon (1e-3 ));
129+ }
130+
131+ SUBCASE (" Output tensor is the correct size" )
132+ {
133+ CHECK (entropies.sizes ().vec () == std::vector<int64_t >{2 , 2 });
134+ }
135+ }
136+
137+ SUBCASE (" log_prob()" )
138+ {
139+ float probabilities[2 ][2 ] = {{0.5 , 0.0 },
140+ {0.25 , 0.25 }};
141+ auto probabilities_tensor = torch::from_blob (probabilities, {2 , 2 });
142+ auto dist = Bernoulli (&probabilities_tensor, nullptr );
143+
144+ float actions[2 ][2 ] = {{1 , 0 },
145+ {1 , 0 }};
146+ auto actions_tensor = torch::from_blob (actions, {2 , 2 });
147+ auto log_probs = dist.log_prob (actions_tensor);
148+
149+ INFO (log_probs << " \n " );
150+ SUBCASE (" Returns correct values" )
151+ {
152+ CHECK (log_probs[0 ][0 ].item ().toDouble () ==
153+ doctest::Approx (-0.6931 ).epsilon (1e-3 ));
154+ CHECK (log_probs[0 ][1 ].item ().toDouble () ==
155+ doctest::Approx (0.0000 ).epsilon (1e-3 ));
156+ CHECK (log_probs[1 ][0 ].item ().toDouble () ==
157+ doctest::Approx (-1.3863 ).epsilon (1e-3 ));
158+ CHECK (log_probs[1 ][1 ].item ().toDouble () ==
159+ doctest::Approx (-0.2876 ).epsilon (1e-3 ));
160+ }
161+
162+ SUBCASE (" Output tensor is correct size" )
163+ {
164+ CHECK (log_probs.sizes ().vec () == std::vector<int64_t >{2 , 2 });
165+ }
166+ }
167+ }
168+ }
0 commit comments