@@ -20,54 +20,47 @@ class normal_lowrank : public base_family {
2020 const int dimension_;
2121 const int rank_;
2222
23- void validate_mean (const char * function,
24- const Eigen::VectorXd& mu) {
23+ void validate_mean (const char * function, const Eigen::VectorXd& mu) {
2524 stan::math::check_not_nan (function, " Mean vector" , mu);
26- stan::math::check_size_match (function,
27- " Dimension of input vector" , mu. size () ,
28- " Dimension of current vector " , dimension ());
25+ stan::math::check_size_match (function, " Dimension of input vector " ,
26+ mu. size (), " Dimension of current vector" ,
27+ dimension ());
2928 }
3029
31- void validate_factor (const char * function,
32- const Eigen::MatrixXd& B) {
30+ void validate_factor (const char * function, const Eigen::MatrixXd& B) {
3331 stan::math::check_not_nan (function, " Low rank factor" , B);
34- stan::math::check_size_match (function,
35- " Dimension of mean vector" , dimension (),
36- " Dimension of low-rank factor" , B.rows ());
37- stan::math::check_size_match (function,
38- " Rank of factor" , B.cols (),
32+ stan::math::check_size_match (function, " Dimension of mean vector" ,
33+ dimension (), " Dimension of low-rank factor" ,
34+ B.rows ());
35+ stan::math::check_size_match (function, " Rank of factor" , B.cols (),
3936 " Rank of approximation" , rank ());
4037 }
4138
42- void validate_noise (const char *function,
43- const Eigen::VectorXd& log_d) {
39+ void validate_noise (const char * function, const Eigen::VectorXd& log_d) {
4440 stan::math::check_not_nan (function, " log std vector" , log_d);
45- stan::math::check_size_match (function,
46- " Dimension of mean vector" , dimension () ,
47- " Dimension of log std vector " , log_d.size ());
41+ stan::math::check_size_match (function, " Dimension of mean vector " ,
42+ dimension (), " Dimension of log std vector" ,
43+ log_d.size ());
4844 }
4945
5046 public:
5147 explicit normal_lowrank (const Eigen::VectorXd& mu, size_t rank)
52- : mu_(mu),
53- B_(Eigen::MatrixXd::Zero(mu.size(), rank)),
54- log_d_(Eigen::VectorXd::Zero(mu.size())),
55- dimension_(mu.size()),
56- rank_(rank) {
57- }
48+ : mu_(mu),
49+ B_(Eigen::MatrixXd::Zero(mu.size(), rank)),
50+ log_d_(Eigen::VectorXd::Zero(mu.size())),
51+ dimension_(mu.size()),
52+ rank_(rank) {}
5853
5954 explicit normal_lowrank (size_t dimension, size_t rank)
60- : mu_(Eigen::VectorXd::Zero(dimension)),
61- B_(Eigen::MatrixXd::Zero(dimension, rank)),
62- log_d_(Eigen::VectorXd::Zero(dimension)),
63- dimension_(dimension),
64- rank_(rank) {
65- }
55+ : mu_(Eigen::VectorXd::Zero(dimension)),
56+ B_(Eigen::MatrixXd::Zero(dimension, rank)),
57+ log_d_(Eigen::VectorXd::Zero(dimension)),
58+ dimension_(dimension),
59+ rank_(rank) {}
6660
67- explicit normal_lowrank (const Eigen::VectorXd& mu,
68- const Eigen::MatrixXd& B,
61+ explicit normal_lowrank (const Eigen::VectorXd& mu, const Eigen::MatrixXd& B,
6962 const Eigen::VectorXd& log_d)
70- : mu_(mu), B_(B), log_d_(log_d), dimension_(mu.size()), rank_(B.cols()) {
63+ : mu_(mu), B_(B), log_d_(log_d), dimension_(mu.size()), rank_(B.cols()) {
7164 static const char * function = " stan::variational::normal_lowrank" ;
7265 validate_mean (function, mu);
7366 validate_factor (function, B);
@@ -123,8 +116,8 @@ class normal_lowrank : public base_family {
123116 = " stan::variational::normal_lowrank::operator=" ;
124117 stan::math::check_size_match (function, " Dimension of lhs" , dimension (),
125118 " Dimension of rhs" , rhs.dimension ());
126- stan::math::check_size_match (function, " Rank of lhs" , rank (),
127- " Rank of rhs " , rhs.rank ());
119+ stan::math::check_size_match (function, " Rank of lhs" , rank (), " Rank of rhs " ,
120+ rhs.rank ());
128121 mu_ = rhs.mu ();
129122 B_ = rhs.B ();
130123 log_d_ = rhs.log_d ();
@@ -136,8 +129,8 @@ class normal_lowrank : public base_family {
136129 = " stan::variational::normal_lowrank::operator+=" ;
137130 stan::math::check_size_match (function, " Dimension of lhs" , dimension (),
138131 " Dimension of rhs" , rhs.dimension ());
139- stan::math::check_size_match (function, " Rank of lhs" , rank (),
140- " Rank of rhs " , rhs.rank ());
132+ stan::math::check_size_match (function, " Rank of lhs" , rank (), " Rank of rhs " ,
133+ rhs.rank ());
141134 mu_ += rhs.mu ();
142135 B_ += rhs.B ();
143136 log_d_ += rhs.log_d ();
@@ -150,8 +143,8 @@ class normal_lowrank : public base_family {
150143
151144 stan::math::check_size_match (function, " Dimension of lhs" , dimension (),
152145 " Dimension of rhs" , rhs.dimension ());
153- stan::math::check_size_match (function, " Rank of lhs" , rank (),
154- " Rank of rhs " , rhs.rank ());
146+ stan::math::check_size_match (function, " Rank of lhs" , rank (), " Rank of rhs " ,
147+ rhs.rank ());
155148 mu_.array () /= rhs.mu ().array ();
156149 B_.array () /= rhs.B ().array ();
157150 log_d_.array () /= rhs.log_d ().array ();
@@ -179,24 +172,29 @@ class normal_lowrank : public base_family {
179172 // Determinant by the matrix determinant lemma
180173 // det(D^2 + B.B^T) = det(I + B^T.D^-2.B) * det(D^2)
181174 // where D^2 is diagonal and so can be computed accordingly
182- result
183- += 0.5 * log (
184- (Eigen::MatrixXd::Identity (r, r) +
185- B_.transpose () *
186- log_d_.array ().exp ().square ().matrix ().asDiagonal ().inverse () *
187- B_).determinant ());
175+ result += 0.5
176+ * log ((Eigen::MatrixXd::Identity (r, r)
177+ + B_.transpose ()
178+ * log_d_.array ()
179+ .exp ()
180+ .square ()
181+ .matrix ()
182+ .asDiagonal ()
183+ .inverse ()
184+ * B_)
185+ .determinant ());
188186 for (int d = 0 ; d < dimension (); ++d) {
189187 result += log_d_ (d);
190188 }
191189 return result;
192190 }
193191
194192 Eigen::VectorXd transform (const Eigen::VectorXd& eta) const {
195- static const char * function =
196- " stan::variational::normal_lowrank::transform" ;
197- stan::math::check_size_match (function,
198- " Dimension of input vector " , eta.size (),
199- " Sum of dimension and rank " , dimension () + rank ());
193+ static const char * function
194+ = " stan::variational::normal_lowrank::transform" ;
195+ stan::math::check_size_match (function, " Dimension of input vector " ,
196+ eta.size (), " Sum of dimension and rank " ,
197+ dimension () + rank ());
200198 stan::math::check_not_nan (function, " Input vector" , eta);
201199 Eigen::VectorXd z = eta.head (rank ());
202200 Eigen::VectorXd eps = eta.tail (dimension ());
@@ -238,14 +236,11 @@ class normal_lowrank : public base_family {
238236 }
239237
240238 template <class M , class BaseRNG >
241- void calc_grad (normal_lowrank& elbo_grad,
242- M& m,
243- Eigen::VectorXd& cont_params,
244- int n_monte_carlo_grad,
245- BaseRNG& rng,
239+ void calc_grad (normal_lowrank& elbo_grad, M& m, Eigen::VectorXd& cont_params,
240+ int n_monte_carlo_grad, BaseRNG& rng,
246241 callbacks::logger& logger) const {
247- static const char * function =
248- " stan::variational::normal_lowrank::calc_grad" ;
242+ static const char * function
243+ = " stan::variational::normal_lowrank::calc_grad" ;
249244
250245 stan::math::check_size_match (function, " Dimension of elbo_grad" ,
251246 elbo_grad.dimension (),
@@ -255,8 +250,8 @@ class normal_lowrank : public base_family {
255250 cont_params.size ());
256251
257252 stan::math::check_size_match (function, " Rank of elbo_grad" ,
258- elbo_grad.rank (),
259- " Rank of variational q " , rank ());
253+ elbo_grad.rank (), " Rank of variational q " ,
254+ rank ());
260255
261256 Eigen::VectorXd mu_grad = Eigen::VectorXd::Zero (dimension ());
262257 Eigen::MatrixXd B_grad = Eigen::MatrixXd::Zero (dimension (), rank ());
@@ -279,7 +274,7 @@ class normal_lowrank : public base_family {
279274
280275 // Naive Monte Carlo integration
281276 static const int n_retries = 10 ;
282- for (int i = 0 , n_monte_carlo_drop = 0 ; i < n_monte_carlo_grad; ) {
277+ for (int i = 0 , n_monte_carlo_drop = 0 ; i < n_monte_carlo_grad;) {
283278 // Draw from standard normal and transform to real-coordinate space
284279 for (int d = 0 ; d < dimension () + rank (); ++d) {
285280 eta (d) = stan::math::normal_rng (0 , 1 , rng);
@@ -313,8 +308,9 @@ class normal_lowrank : public base_family {
313308 const char * name = " The number of dropped evaluations" ;
314309 const char * msg1 = " has reached its maximum amount (" ;
315310 int y = n_retries * n_monte_carlo_grad;
316- const char * msg2 = " ). Your model may be either severely "
317- " ill-conditioned or misspecified." ;
311+ const char * msg2
312+ = " ). Your model may be either severely "
313+ " ill-conditioned or misspecified." ;
318314 stan::math::domain_error (function, name, y, msg1, msg2);
319315 }
320316 }
@@ -337,8 +333,7 @@ class normal_lowrank : public base_family {
337333 }
338334};
339335
340- inline normal_lowrank operator +(normal_lowrank lhs,
341- const normal_lowrank& rhs) {
336+ inline normal_lowrank operator +(normal_lowrank lhs, const normal_lowrank& rhs) {
342337 return lhs += rhs;
343338}
344339
@@ -351,8 +346,7 @@ inline normal_lowrank operator+(normal_lowrank lhs,
351346 * @return Elementwise division of the specified approximations.
352347 * @throw std::domain_error If the dimensionalities do not match.
353348 */
354- inline normal_lowrank operator /(normal_lowrank lhs,
355- const normal_lowrank& rhs) {
349+ inline normal_lowrank operator /(normal_lowrank lhs, const normal_lowrank& rhs) {
356350 return lhs /= rhs;
357351}
358352
0 commit comments