1+ #define _USE_MATH_DEFINES
2+ #include < math.h>
3+ #include < cmath>
4+ #include < limits>
5+
6+ #include < c10/util/ArrayRef.h>
7+ #include < torch/torch.h>
8+
9+ #include " cpprl/distributions/normal.h"
10+ #include " third_party/doctest.h"
11+
12+ namespace cpprl
13+ {
14+ Normal::Normal (const torch::Tensor loc,
15+ const torch::Tensor scale)
16+ {
17+ auto broadcasted_tensors = torch::broadcast_tensors ({loc, scale});
18+ this ->loc = broadcasted_tensors[0 ];
19+ this ->scale = broadcasted_tensors[1 ];
20+ batch_shape = this ->loc .sizes ().vec ();
21+ event_shape = {};
22+ }
23+
24+ torch::Tensor Normal::entropy ()
25+ {
26+ return (0.5 + 0.5 * std::log (2 * M_PI) + torch::log (scale)).sum (-1 );
27+ }
28+
29+ std::vector<int64_t > Normal::extended_shape (c10::ArrayRef<int64_t > sample_shape)
30+ {
31+ std::vector<int64_t > output_shape;
32+ output_shape.insert (output_shape.end (),
33+ sample_shape.begin (),
34+ sample_shape.end ());
35+ output_shape.insert (output_shape.end (),
36+ batch_shape.begin (),
37+ batch_shape.end ());
38+ output_shape.insert (output_shape.end (),
39+ event_shape.begin (),
40+ event_shape.end ());
41+ return output_shape;
42+ }
43+
44+ torch::Tensor Normal::log_prob (torch::Tensor value)
45+ {
46+ auto variance = scale.pow (2 );
47+ auto log_scale = scale.log ();
48+ return (-(value - loc).pow (2 ) /
49+ (2 * variance) -
50+ log_scale -
51+ std::log (std::sqrt (2 * M_PI)));
52+ }
53+
54+ torch::Tensor Normal::sample (c10::ArrayRef<int64_t > sample_shape)
55+ {
56+ auto shape = extended_shape (sample_shape);
57+ auto no_grad_guard = torch::NoGradGuard ();
58+ return torch::normal (loc.expand (shape), scale.expand (shape));
59+ }
60+
61+ TEST_CASE (" Normal" )
62+ {
63+ float locs_array[] = {0 , 1 , 2 , 3 , 4 , 5 };
64+ float scales_array[] = {5 , 4 , 3 , 2 , 1 , 0 };
65+ auto locs = torch::from_blob (locs_array, {2 , 3 });
66+ auto scales = torch::from_blob (scales_array, {2 , 3 });
67+ auto dist = Normal (locs, scales);
68+
69+ SUBCASE (" Sampled tensors have correct shape" )
70+ {
71+ CHECK (dist.sample ().sizes ().vec () == std::vector<int64_t >{2 , 3 });
72+ CHECK (dist.sample ({20 }).sizes ().vec () == std::vector<int64_t >{20 , 2 , 3 });
73+ CHECK (dist.sample ({2 , 20 }).sizes ().vec () == std::vector<int64_t >{2 , 20 , 2 , 3 });
74+ CHECK (dist.sample ({1 , 2 , 3 , 4 , 5 }).sizes ().vec () == std::vector<int64_t >{1 , 2 , 3 , 4 , 5 , 2 , 3 });
75+ }
76+
77+ SUBCASE (" entropy()" )
78+ {
79+ auto entropies = dist.entropy ();
80+
81+ SUBCASE (" Returns correct values" )
82+ {
83+ INFO (" Entropies: \n "
84+ << entropies);
85+
86+ CHECK (entropies[0 ].item ().toDouble () ==
87+ doctest::Approx (8.3512 ).epsilon (1e-3 ));
88+ CHECK (entropies[1 ].item ().toDouble () ==
89+ -std::numeric_limits<float >::infinity ());
90+ }
91+
92+ SUBCASE (" Output tensor is the correct size" )
93+ {
94+ CHECK (entropies.sizes ().vec () == std::vector<int64_t >{2 });
95+ }
96+ }
97+
98+ SUBCASE (" log_prob()" )
99+ {
100+ float actions[2 ][3 ] = {{0 , 1 , 2 },
101+ {0 , 1 , 2 }};
102+ auto actions_tensor = torch::from_blob (actions, {2 , 3 });
103+ auto log_probs = dist.log_prob (actions_tensor);
104+
105+ INFO (log_probs << " \n " );
106+ SUBCASE (" Returns correct values" )
107+ {
108+ CHECK (log_probs[0 ][0 ].item ().toDouble () ==
109+ doctest::Approx (-2.5284 ).epsilon (1e-3 ));
110+ CHECK (log_probs[0 ][1 ].item ().toDouble () ==
111+ doctest::Approx (-2.3052 ).epsilon (1e-3 ));
112+ CHECK (log_probs[0 ][2 ].item ().toDouble () ==
113+ doctest::Approx (-2.0176 ).epsilon (1e-3 ));
114+ CHECK (log_probs[1 ][0 ].item ().toDouble () ==
115+ doctest::Approx (-2.7371 ).epsilon (1e-3 ));
116+ CHECK (log_probs[1 ][1 ].item ().toDouble () ==
117+ doctest::Approx (-5.4189 ).epsilon (1e-3 ));
118+ CHECK (std::isnan (log_probs[1 ][2 ].item ().toDouble ()));
119+ }
120+
121+ SUBCASE (" Output tensor is correct size" )
122+ {
123+ CHECK (log_probs.sizes ().vec () == std::vector<int64_t >{2 , 3 });
124+ }
125+ }
126+ }
127+ }
0 commit comments