Skip to content

Commit ffcc9b2

Browse files
authored
Merge pull request #2814 from stan-dev/feature/2704-schur-decomposition
Feature/2704 Add complex Schur decomposition
2 parents 15f1348 + 4430597 commit ffcc9b2

File tree

6 files changed

+206
-19
lines changed

6 files changed

+206
-19
lines changed

stan/math/prim/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
#include <stan/math/prim/fun/cols.hpp>
4747
#include <stan/math/prim/fun/columns_dot_product.hpp>
4848
#include <stan/math/prim/fun/columns_dot_self.hpp>
49+
#include <stan/math/prim/fun/complex_schur_decompose.hpp>
4950
#include <stan/math/prim/fun/conj.hpp>
5051
#include <stan/math/prim/fun/constants.hpp>
5152
#include <stan/math/prim/fun/copysign.hpp>
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#ifndef STAN_MATH_PRIM_FUN_COMPLEX_SCHUR_DECOMPOSE_HPP
2+
#define STAN_MATH_PRIM_FUN_COMPLEX_SCHUR_DECOMPOSE_HPP
3+
4+
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/meta.hpp>
6+
#include <stan/math/prim/err.hpp>
7+
8+
namespace stan {
9+
namespace math {
10+
11+
/**
12+
* Return the unitary matrix of the complex Schur decomposition of the
13+
* specified square matrix.
14+
*
15+
* The complex Schur decomposition of a square matrix `A` produces a
16+
* complex unitary matrix `U` and a complex upper-triangular Schur
17+
* form matrix `T` such that `A = U * T * inv(U)`. Further, the
18+
* unitary matrix's inverse is equal to its conjugate transpose,
19+
* `inv(U) = U*`, where `U*(i, j) = conj(U(j, i))`
20+
*
21+
* @tparam M type of matrix
22+
* @param m real matrix to decompose
23+
* @return complex unitary matrix of the complex Schur decomposition of the
24+
* specified matrix
25+
* @see complex_schur_decompose_t
26+
*/
27+
template <typename M, require_eigen_dense_dynamic_t<M>* = nullptr>
28+
inline Eigen::Matrix<complex_return_t<scalar_type_t<M>>, -1, -1>
29+
complex_schur_decompose_u(const M& m) {
30+
if (m.size() == 0)
31+
return m;
32+
check_square("complex_schur_decompose_u", "m", m);
33+
using MatType = Eigen::Matrix<scalar_type_t<M>, -1, -1>;
34+
// copy because ComplexSchur requires Eigen::Matrix type
35+
MatType mv = m;
36+
Eigen::ComplexSchur<MatType> cs(mv);
37+
return cs.matrixU();
38+
}
39+
40+
/**
41+
* Return the Schur form matrix of the complex Schur decomposition of the
42+
* specified square matrix.
43+
*
44+
* @tparam M type of matrix
45+
* @param m real matrix to decompose
46+
* @return Schur form matrix of the complex Schur decomposition of the
47+
* specified matrix
48+
* @see complex_schur_decompose_u for a definition of the complex
49+
* Schur decomposition
50+
*/
51+
template <typename M, require_eigen_dense_dynamic_t<M>* = nullptr>
52+
inline Eigen::Matrix<complex_return_t<scalar_type_t<M>>, -1, -1>
53+
complex_schur_decompose_t(const M& m) {
54+
if (m.size() == 0)
55+
return m;
56+
check_square("complex_schur_decompose_t", "m", m);
57+
using MatType = Eigen::Matrix<scalar_type_t<M>, -1, -1>;
58+
// copy because ComplexSchur requires Eigen::Matrix type
59+
MatType mv = m;
60+
Eigen::ComplexSchur<MatType> cs(mv, false);
61+
return cs.matrixT();
62+
}
63+
64+
} // namespace math
65+
} // namespace stan
66+
#endif
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#include <test/unit/math/test_ad.hpp>
2+
#include <stdexcept>
3+
4+
TEST(mathMixFun, complexSchurDecomposeT) {
5+
auto f = [](const auto& x) {
6+
using stan::math::complex_schur_decompose_t;
7+
return complex_schur_decompose_t(x);
8+
};
9+
for (const auto& x : stan::test::square_test_matrices(0, 2)) {
10+
stan::test::expect_ad(f, x);
11+
}
12+
13+
Eigen::MatrixXd a32(3, 2);
14+
a32 << 3, -5, 7, -7.2, 9.1, -6.3;
15+
EXPECT_THROW(f(a32), std::invalid_argument);
16+
}
17+
18+
TEST(mathMixFun, complexSchurDecomposeU) {
19+
auto f = [](const auto& x) {
20+
using stan::math::complex_schur_decompose_u;
21+
return complex_schur_decompose_u(x);
22+
};
23+
for (const auto& x : stan::test::square_test_matrices(0, 2)) {
24+
stan::test::expect_ad(f, x);
25+
}
26+
27+
Eigen::MatrixXd a32(3, 2);
28+
a32 << 3, -5, 7, -7.2, 9.1, -6.3;
29+
EXPECT_THROW(f(a32), std::invalid_argument);
30+
}
31+
32+
template <typename V>
33+
void test_complex_schur_decompose(const Eigen::MatrixXd& x) {
34+
using stan::math::complex_schur_decompose_t;
35+
using stan::math::complex_schur_decompose_u;
36+
using stan::math::get_real;
37+
using stan::math::value_of_rec;
38+
Eigen::Matrix<V, -1, -1> X = x;
39+
40+
auto T = complex_schur_decompose_t(X);
41+
auto U = complex_schur_decompose_u(X);
42+
auto X2 = U * T * U.adjoint();
43+
44+
EXPECT_MATRIX_NEAR(x, value_of_rec(get_real(X2)), 1e-8);
45+
}
46+
47+
template <typename V>
48+
void test_complex_schur_decompose_complex(const Eigen::MatrixXd& x) {
49+
using stan::math::complex_schur_decompose_t;
50+
using stan::math::complex_schur_decompose_u;
51+
using stan::math::value_of_rec;
52+
Eigen::Matrix<std::complex<V>, -1, -1> X(x.rows(), x.cols());
53+
for (int i = 0; i < x.size(); ++i)
54+
X(i) = std::complex<double>(x(i), i);
55+
auto T = complex_schur_decompose_t(X);
56+
auto U = complex_schur_decompose_u(X);
57+
auto X2 = U * T * U.adjoint();
58+
59+
EXPECT_MATRIX_COMPLEX_NEAR(value_of_rec(X), value_of_rec(X2), 1e-8);
60+
}
61+
62+
TEST(mathMixFun, complexSchurDecompose) {
63+
using d_t = double;
64+
using v_t = stan::math::var;
65+
using fd_t = stan::math::fvar<d_t>;
66+
using ffd_t = stan::math::fvar<fd_t>;
67+
using fv_t = stan::math::fvar<v_t>;
68+
using ffv_t = stan::math::fvar<fv_t>;
69+
for (const auto& x : stan::test::square_test_matrices(0, 3)) {
70+
test_complex_schur_decompose<d_t>(x);
71+
test_complex_schur_decompose<v_t>(x);
72+
test_complex_schur_decompose<fd_t>(x);
73+
test_complex_schur_decompose<ffd_t>(x);
74+
test_complex_schur_decompose<fv_t>(x);
75+
test_complex_schur_decompose<ffv_t>(x);
76+
}
77+
for (const auto& x : stan::test::square_test_matrices(0, 3)) {
78+
test_complex_schur_decompose_complex<d_t>(x);
79+
test_complex_schur_decompose_complex<v_t>(x);
80+
test_complex_schur_decompose_complex<fd_t>(x);
81+
test_complex_schur_decompose_complex<ffd_t>(x);
82+
test_complex_schur_decompose_complex<fv_t>(x);
83+
test_complex_schur_decompose_complex<ffv_t>(x);
84+
}
85+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#include <stan/math/prim.hpp>
2+
#include <test/unit/util.hpp>
3+
#include <gtest/gtest.h>
4+
5+
TEST(primFun, complex_schur_decompose) {
6+
using c_t = std::complex<double>;
7+
using stan::math::complex_schur_decompose_t;
8+
using stan::math::complex_schur_decompose_u;
9+
10+
// verify that A = U T U*
11+
12+
Eigen::MatrixXd A(3, 3);
13+
A << 0, 2, 2, 0, 0, 2, 1, 0, 1;
14+
auto A_t = complex_schur_decompose_t(A);
15+
auto A_u = complex_schur_decompose_u(A);
16+
auto A_recovered = A_u * A_t * A_u.adjoint();
17+
EXPECT_MATRIX_NEAR(A, stan::math::get_real(A_recovered), 1e-8);
18+
19+
Eigen::MatrixXcd B(3, 3);
20+
B << 0, 2, 2, 0, c_t(0, 1), 2, 1, 0, 1;
21+
22+
auto B_t = complex_schur_decompose_t(B);
23+
auto B_u = complex_schur_decompose_u(B);
24+
25+
auto B_recovered = B_u * B_t * B_u.adjoint();
26+
EXPECT_MATRIX_COMPLEX_NEAR(B, B_recovered, 1e-8);
27+
}

test/unit/math/prim/fun/fft_test.cpp

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,6 @@ TEST(primFun, fft) {
4545
EXPECT_NEAR(imag(yb[2]), 2 * -6.53589838, 1e-6);
4646
}
4747

48-
template <typename T, typename U>
49-
void expect_complex_mat_eq(const T& x, const U& y, double tol = 1e-8) {
50-
EXPECT_EQ(x.rows(), y.rows());
51-
EXPECT_EQ(x.cols(), y.cols());
52-
for (int j = 0; j < x.cols(); ++j) {
53-
for (int i = 0; i < x.rows(); ++i) {
54-
EXPECT_FLOAT_EQ(real(x(i, j)), real(y(i, j)));
55-
EXPECT_FLOAT_EQ(imag(x(i, j)), imag(y(i, j)));
56-
}
57-
}
58-
}
59-
6048
TEST(primFun, inv_fft) {
6149
using c_t = std::complex<double>;
6250
using cv_t = Eigen::Matrix<std::complex<double>, -1, 1>;
@@ -74,7 +62,7 @@ TEST(primFun, inv_fft) {
7462
cv_t x1 = inv_fft(y1);
7563
cv_t x1_expected(1);
7664
x1_expected << c_t(-3.247, 1.98555);
77-
expect_complex_mat_eq(x1_expected, x1);
65+
EXPECT_MATRIX_COMPLEX_NEAR(x1_expected, x1, 1e-8);
7866

7967
EXPECT_EQ(1, x1.size());
8068
EXPECT_EQ(real(x1[0]), -3.247);
@@ -87,7 +75,7 @@ TEST(primFun, inv_fft) {
8775
EXPECT_EQ(3, y.size());
8876
Eigen::VectorXcd x_expected(3);
8977
x_expected << c_t(1, -2), c_t(-3, 5), c_t(-7, 11);
90-
expect_complex_mat_eq(x_expected, x);
78+
EXPECT_MATRIX_COMPLEX_NEAR(x_expected, x, 1e-8);
9179
}
9280

9381
TEST(primFun, fft2) {
@@ -115,7 +103,7 @@ TEST(primFun, fft2) {
115103
cm_t y12 = fft2(x12);
116104
cm_t y12_expected(1, 2);
117105
y12_expected << c_t(-7.6, -3.7), c_t(9.6, -4.1);
118-
expect_complex_mat_eq(y12_expected, y12);
106+
EXPECT_MATRIX_COMPLEX_NEAR(y12_expected, y12, 1e-8);
119107

120108
cm_t x33(3, 3);
121109
x33 << c_t(1, 2), c_t(3, -1.4), c_t(2, 1), c_t(3, -9), c_t(2, -1.3),
@@ -127,7 +115,7 @@ TEST(primFun, fft2) {
127115
c_t(-13.29326674, 20.88153533), c_t(-13.25262794, 15.82794549),
128116
c_t(4.160254038, 5.928718708), c_t(-11.34737206, -7.72794549),
129117
c_t(4.89326674, -1.98153533);
130-
expect_complex_mat_eq(y33_expected, y33);
118+
EXPECT_MATRIX_COMPLEX_NEAR(y33_expected, y33, 1e-8);
131119
}
132120

133121
TEST(primFunFFT, invfft2) {
@@ -153,7 +141,7 @@ TEST(primFunFFT, invfft2) {
153141
x13 << c_t(-2.3, 1.82), c_t(1.18, 9.32), c_t(1.15, -14.1);
154142
cm_t y13 = inv_fft2(x13);
155143
cm_t y13copy = inv_fft(x13.row(0));
156-
expect_complex_mat_eq(y13, y13copy.transpose());
144+
EXPECT_MATRIX_COMPLEX_NEAR(y13, y13copy.transpose(), 1e-8);
157145

158146
cm_t x33(3, 3);
159147
x33 << c_t(1, 2), c_t(3, -1.4), c_t(2, 1), c_t(3, -9), c_t(2, -1.3),
@@ -182,10 +170,10 @@ TEST(primFunFFT, invfft2) {
182170

183171
// check round trips inv_fft(fft(x))
184172
cm_t x33copy = inv_fft2(y33);
185-
expect_complex_mat_eq(x33, x33copy);
173+
EXPECT_MATRIX_COMPLEX_NEAR(x33, x33copy, 1e-8);
186174

187175
// check round trip fft(inv_fft(x))
188176
cm_t z33 = inv_fft2(x33);
189177
cm_t x33copy2 = fft2(z33);
190-
expect_complex_mat_eq(x33, x33copy2);
178+
EXPECT_MATRIX_COMPLEX_NEAR(x33, x33copy2, 1e-8);
191179
}

test/unit/util.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,26 @@
120120
EXPECT_NEAR(A_eval(i), B_eval(i), DELTA); \
121121
}
122122

123+
/**
124+
* Tests for elementwise equality of the input matrices
125+
* of std::complex<double>s with the EXPECT_FLOAT_EQ macro
126+
* from GTest.
127+
*
128+
* @param A first input matrix to compare
129+
* @param B second input matrix to compare
130+
*/
131+
#define EXPECT_MATRIX_COMPLEX_NEAR(A, B, DELTA) \
132+
{ \
133+
const Eigen::MatrixXcd& A_eval = A; \
134+
const Eigen::MatrixXcd& B_eval = B; \
135+
EXPECT_EQ(A_eval.rows(), B_eval.rows()); \
136+
EXPECT_EQ(A_eval.cols(), B_eval.cols()); \
137+
for (int i = 0; i < A_eval.size(); i++) { \
138+
EXPECT_NEAR(A_eval(i).real(), B_eval(i).real(), DELTA); \
139+
EXPECT_NEAR(A_eval(i).imag(), B_eval(i).imag(), DELTA); \
140+
} \
141+
}
142+
123143
/**
124144
* Tests if given types are the same type.
125145
*

0 commit comments

Comments
 (0)