Skip to content

Commit 5165252

Browse files
authored
Merge pull request #2979 from stan-dev/feature/varmat-reader-simple
allow reader to read out var<matrix> types
2 parents fdf7441 + ef5898b commit 5165252

File tree

2 files changed

+205
-7
lines changed

2 files changed

+205
-7
lines changed

src/stan/io/reader.hpp

Lines changed: 115 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#ifndef STAN_IO_READER_HPP
22
#define STAN_IO_READER_HPP
33

4-
#include <stan/math/prim.hpp>
4+
#include <stan/math/rev.hpp>
55
#include <stdexcept>
66
#include <string>
77
#include <vector>
@@ -54,13 +54,20 @@ class reader {
5454
}
5555

5656
public:
57-
typedef Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> matrix_t;
58-
typedef Eigen::Matrix<T, Eigen::Dynamic, 1> vector_t;
59-
typedef Eigen::Matrix<T, 1, Eigen::Dynamic> row_vector_t;
57+
using matrix_t = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
58+
using vector_t = Eigen::Matrix<T, Eigen::Dynamic, 1>;
59+
using row_vector_t = Eigen::Matrix<T, 1, Eigen::Dynamic>;
6060

61-
typedef Eigen::Map<matrix_t> map_matrix_t;
62-
typedef Eigen::Map<vector_t> map_vector_t;
63-
typedef Eigen::Map<row_vector_t> map_row_vector_t;
61+
using map_matrix_t = Eigen::Map<matrix_t>;
62+
using map_vector_t = Eigen::Map<vector_t>;
63+
using map_row_vector_t = Eigen::Map<row_vector_t>;
64+
65+
using var_matrix_t = stan::math::var_value<
66+
Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>>;
67+
using var_vector_t
68+
= stan::math::var_value<Eigen::Matrix<double, Eigen::Dynamic, 1>>;
69+
using var_row_vector_t
70+
= stan::math::var_value<Eigen::Matrix<double, 1, Eigen::Dynamic>>;
6471

6572
/**
6673
* Construct a variable reader using the specified vectors
@@ -186,6 +193,33 @@ class reader {
186193
return vector_t();
187194
return map_vector_t(&scalar_ptr_increment(m), m);
188195
}
196+
197+
/**
198+
* Return a `var_value` with inner type column vector with specified
199+
* dimensionality made up of the next scalars.
200+
*
201+
* @param m Number of rows in the vector to read.
202+
* @return Column vector made up of the next scalars.
203+
*/
204+
template <typename T_ = T, require_st_var<T_> * = nullptr>
205+
inline var_vector_t var_vector(size_t m) {
206+
if (m == 0)
207+
return var_vector_t(Eigen::VectorXd(0));
208+
return stan::math::to_var_value(map_vector_t(&scalar_ptr_increment(m), m));
209+
}
210+
211+
/**
212+
* Return a column vector of specified dimensionality made up of
213+
* the next scalars.
214+
*
215+
* @param m Number of rows in the vector to read.
216+
* @return Column vector made up of the next scalars.
217+
*/
218+
template <typename T_ = T, require_st_arithmetic<T_> * = nullptr>
219+
inline vector_t var_vector(size_t m) {
220+
return this->vector(m);
221+
}
222+
189223
/**
190224
* Return a column vector of specified dimensionality made up of
191225
* the next scalars. The constraint is a no-op.
@@ -225,6 +259,33 @@ class reader {
225259
return map_row_vector_t(&scalar_ptr_increment(m), m);
226260
}
227261

262+
/**
263+
* Return a `var_value` with inner type as a row vector with specified
264+
* dimensionality made up of the next scalars.
265+
*
266+
* @param m Number of rows in the vector to read.
267+
* @return Column vector made up of the next scalars.
268+
*/
269+
template <typename T_ = T, require_st_var<T_> * = nullptr>
270+
inline var_row_vector_t var_row_vector(size_t m) {
271+
if (m == 0)
272+
return var_row_vector_t(Eigen::RowVectorXd(0));
273+
return stan::math::to_var_value(
274+
map_row_vector_t(&scalar_ptr_increment(m), m));
275+
}
276+
277+
/**
278+
* Return a row vector of specified dimensionality made up of
279+
* the next scalars.
280+
*
281+
* @param m Number of rows in the vector to read.
282+
* @return Column vector made up of the next scalars.
283+
*/
284+
template <typename T_ = T, require_st_arithmetic<T_> * = nullptr>
285+
inline row_vector_t var_row_vector(size_t m) {
286+
return this->row_vector(m);
287+
}
288+
228289
/**
229290
* Return a row vector of specified dimensionality made up of
230291
* the next scalars. The constraint is a no-op.
@@ -276,6 +337,53 @@ class reader {
276337
return map_matrix_t(&scalar_ptr_increment(m * n), m, n);
277338
}
278339

340+
/**
341+
* Return a `var_value` with inner type matrix with the specified
342+
* dimensionality made up of the next scalars arranged in column-major order.
343+
*
344+
* Row-major reading means that if a matrix of <code>m=2</code>
345+
* rows and <code>n=3</code> columns is read and the next
346+
* scalar values are <code>1,2,3,4,5,6</code>, the result is
347+
*
348+
* <pre>
349+
* a = 1 4
350+
* 2 5
351+
* 3 6</pre>
352+
*
353+
* @param m Number of rows.
354+
* @param n Number of columns.
355+
* @return Eigen::Matrix made up of the next scalars.
356+
*/
357+
template <typename T_ = T, require_st_var<T_> * = nullptr>
358+
inline var_matrix_t var_matrix(size_t m, size_t n) {
359+
if (m == 0 || n == 0)
360+
return var_matrix_t(Eigen::MatrixXd(0, 0));
361+
return stan::math::to_var_value(
362+
map_matrix_t(&scalar_ptr_increment(m * n), m, n));
363+
}
364+
365+
/**
366+
* Return a matrix of the specified dimensionality made up of
367+
* the next scalars arranged in column-major order.
368+
*
369+
* Row-major reading means that if a matrix of <code>m=2</code>
370+
* rows and <code>n=3</code> columns is read and the next
371+
* scalar values are <code>1,2,3,4,5,6</code>, the result is
372+
*
373+
* <pre>
374+
* a = 1 4
375+
* 2 5
376+
* 3 6</pre>
377+
*
378+
* @param m Number of rows.
379+
* @param n Number of columns.
380+
* @return Eigen::Matrix made up of the next scalars.
381+
*/
382+
template <typename T_ = T, require_st_arithmetic<T_> * = nullptr>
383+
inline matrix_t var_matrix(size_t m, size_t n) {
384+
return this->matrix(m, n);
385+
}
386+
279387
/**
280388
* Return a matrix of the specified dimensionality made up of
281389
* the next scalars arranged in column-major order. The

src/test/unit/io/reader_test.cpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1429,3 +1429,93 @@ TEST(IoReader, UnitVectorThrows) {
14291429
EXPECT_THROW(reader.unit_vector_constrain(x), std::invalid_argument);
14301430
EXPECT_THROW(reader.unit_vector_constrain(x, lp), std::invalid_argument);
14311431
}
1432+
1433+
TEST(IoReader, var_vector) {
1434+
using stan::math::var;
1435+
using stan::math::var_value;
1436+
std::vector<var> theta{0, 1, 2, 3, 4};
1437+
std::vector<int> theta_i;
1438+
stan::io::reader<stan::math::var> reader(theta, theta_i);
1439+
auto vec_x = reader.var_vector(5);
1440+
EXPECT_TRUE((stan::is_var_vector<decltype(vec_x)>::value));
1441+
for (int i = 0; i < 5; ++i) {
1442+
EXPECT_EQ(vec_x.val()(i), i);
1443+
}
1444+
auto vec_x_empty = reader.var_vector(0);
1445+
stan::math::recover_memory();
1446+
}
1447+
1448+
TEST(IoReader, var_vector_double) {
1449+
using stan::math::var;
1450+
using stan::math::var_value;
1451+
std::vector<double> theta{0, 1, 2, 3, 4};
1452+
std::vector<int> theta_i;
1453+
stan::io::reader<double> reader(theta, theta_i);
1454+
auto vec_x = reader.var_vector(5);
1455+
EXPECT_TRUE(
1456+
(stan::is_eigen_vector<decltype(vec_x)>::value
1457+
&& std::is_arithmetic<stan::value_type_t<decltype(vec_x)>>::value));
1458+
for (int i = 0; i < 5; ++i) {
1459+
EXPECT_EQ(vec_x.val()(i), i);
1460+
}
1461+
}
1462+
1463+
TEST(IoReader, var_row_vector) {
1464+
using stan::math::var;
1465+
using stan::math::var_value;
1466+
std::vector<var> theta{0, 1, 2, 3, 4};
1467+
std::vector<int> theta_i;
1468+
stan::io::reader<stan::math::var> reader(theta, theta_i);
1469+
auto vec_x = reader.var_row_vector(5);
1470+
EXPECT_TRUE((stan::is_var_row_vector<decltype(vec_x)>::value));
1471+
for (int i = 0; i < 5; ++i) {
1472+
EXPECT_EQ(vec_x.val()(i), i);
1473+
}
1474+
auto vec_x_empty = reader.var_row_vector(0);
1475+
}
1476+
1477+
TEST(IoReader, var_row_vector_double) {
1478+
using stan::math::var;
1479+
using stan::math::var_value;
1480+
std::vector<double> theta{0, 1, 2, 3, 4};
1481+
std::vector<int> theta_i;
1482+
stan::io::reader<double> reader(theta, theta_i);
1483+
auto vec_x = reader.var_row_vector(5);
1484+
EXPECT_TRUE(
1485+
(stan::is_eigen_row_vector<decltype(vec_x)>::value
1486+
&& std::is_arithmetic<stan::value_type_t<decltype(vec_x)>>::value));
1487+
for (int i = 0; i < 5; ++i) {
1488+
EXPECT_EQ(vec_x.val()(i), i);
1489+
}
1490+
auto vec_x_empty = reader.var_row_vector(0);
1491+
}
1492+
1493+
TEST(IoReader, var_matrix) {
1494+
using stan::math::var;
1495+
using stan::math::var_value;
1496+
std::vector<var> theta{0, 1, 2, 3, 4, 5, 6, 7, 8};
1497+
std::vector<int> theta_i;
1498+
stan::io::reader<stan::math::var> reader(theta, theta_i);
1499+
auto mat_x = reader.var_matrix(3, 3);
1500+
EXPECT_TRUE((stan::is_var_matrix<decltype(mat_x)>::value));
1501+
for (int i = 0; i < 9; ++i) {
1502+
EXPECT_EQ(mat_x.val()(i), i);
1503+
}
1504+
auto mat_x_empty = reader.var_matrix(0, 0);
1505+
}
1506+
1507+
TEST(IoReader, var_matrix_double) {
1508+
using stan::math::var;
1509+
using stan::math::var_value;
1510+
std::vector<double> theta{0, 1, 2, 3, 4, 5, 6, 7, 8};
1511+
std::vector<int> theta_i;
1512+
stan::io::reader<double> reader(theta, theta_i);
1513+
auto mat_x = reader.var_matrix(3, 3);
1514+
EXPECT_TRUE(
1515+
(stan::is_eigen_dense_dynamic<decltype(mat_x)>::value
1516+
&& std::is_arithmetic<stan::value_type_t<decltype(mat_x)>>::value));
1517+
for (int i = 0; i < 9; ++i) {
1518+
EXPECT_EQ(mat_x.val()(i), i);
1519+
}
1520+
auto mat_x_empty = reader.var_matrix(0, 0);
1521+
}

0 commit comments

Comments
 (0)