Skip to content

Commit 58dc2de

Browse files
authored
Merge pull request #520 from ValeevGroup/kshitij/feature/qr_solve
Support column pivoted QR solve
2 parents ce692b5 + 272a771 commit 58dc2de

File tree

4 files changed

+76
-6
lines changed

4 files changed

+76
-6
lines changed

src/TiledArray/math/linalg/non-distributed/qr.h

+16
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,22 @@ auto householder_qr(const ArrayV& V, TiledRange q_trange = TiledRange(),
3434
}
3535
}
3636

37+
template <typename ArrayA, typename ArrayB, typename T = ArrayB::numeric_type>
38+
auto qr_solve(const ArrayA& A, const ArrayB& B,
39+
const TiledArray::detail::real_t<T> cond = 1e8,
40+
TiledRange x_trange = TiledRange()) {
41+
(void)detail::array_traits<ArrayB>{};
42+
auto& world = B.world();
43+
auto A_eig = detail::make_matrix(A);
44+
auto B_eig = detail::make_matrix(B);
45+
TA_LAPACK_ON_RANK_ZERO(qr_solve, world, A_eig, B_eig, cond);
46+
world.gop.broadcast_serializable(A_eig, 0);
47+
world.gop.broadcast_serializable(B_eig, 0);
48+
if (x_trange.rank() == 0) x_trange = B.trange();
49+
auto X = eigen_to_array<ArrayB>(world, x_trange, B_eig);
50+
return X;
51+
}
52+
3753
} // namespace TiledArray::math::linalg::non_distributed
3854

3955
#endif

src/TiledArray/math/linalg/rank-local.cpp

+23-6
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,22 @@ void cholesky_lsolve(Op transpose, Matrix<T>& A, Matrix<T>& X) {
112112
TA_LAPACK(trtrs, uplo, transpose, diag, n, nrhs, a, lda, b, ldb);
113113
}
114114

115+
template <typename T>
116+
void qr_solve(Matrix<T>& A, Matrix<T>& B,
117+
const TiledArray::detail::real_t<T> cond) {
118+
integer m = A.rows();
119+
integer n = A.cols();
120+
integer nrhs = B.cols();
121+
T* a = A.data();
122+
integer lda = A.rows();
123+
T* b = B.data();
124+
integer ldb = B.rows();
125+
std::vector<integer> jpiv(n);
126+
const TiledArray::detail::real_t<T> rcond = 1 / cond;
127+
integer rank = -1;
128+
TA_LAPACK(gelsy, m, n, nrhs, a, lda, b, ldb, jpiv.data(), rcond, &rank);
129+
}
130+
115131
template <typename T>
116132
void heig(Matrix<T>& A, std::vector<TiledArray::detail::real_t<T>>& W) {
117133
auto jobz = lapack::Job::Vec;
@@ -250,7 +266,7 @@ void householder_qr(Matrix<T>& V, Matrix<T>& R) {
250266
lapack::orgqr(m, n, k, v, ldv, tau.data());
251267
}
252268

253-
#define TA_LAPACK_EXPLICIT(MATRIX, VECTOR) \
269+
#define TA_LAPACK_EXPLICIT(MATRIX, VECTOR, DOUBLE) \
254270
template void cholesky(MATRIX&); \
255271
template void cholesky_linv(MATRIX&); \
256272
template void cholesky_solve(MATRIX&, MATRIX&); \
@@ -261,11 +277,12 @@ void householder_qr(Matrix<T>& V, Matrix<T>& R) {
261277
template void lu_solve(MATRIX&, MATRIX&); \
262278
template void lu_inv(MATRIX&); \
263279
template void householder_qr<true>(MATRIX&, MATRIX&); \
264-
template void householder_qr<false>(MATRIX&, MATRIX&);
280+
template void householder_qr<false>(MATRIX&, MATRIX&); \
281+
template void qr_solve(MATRIX&, MATRIX&, DOUBLE)
265282

266-
TA_LAPACK_EXPLICIT(Matrix<double>, std::vector<double>);
267-
TA_LAPACK_EXPLICIT(Matrix<float>, std::vector<float>);
268-
TA_LAPACK_EXPLICIT(Matrix<std::complex<double>>, std::vector<double>);
269-
TA_LAPACK_EXPLICIT(Matrix<std::complex<float>>, std::vector<float>);
283+
// Explicit instantiations of the rank-local routines, one per element type.
// The third macro argument (DOUBLE) is the type used for the last parameter
// of qr_solve in the instantiation list; for the complex matrices it is the
// matching real type.
TA_LAPACK_EXPLICIT(Matrix<double>, std::vector<double>, double );
284+
TA_LAPACK_EXPLICIT(Matrix<float>, std::vector<float>, float);
285+
TA_LAPACK_EXPLICIT(Matrix<std::complex<double>>, std::vector<double>, double);
286+
TA_LAPACK_EXPLICIT(Matrix<std::complex<float>>, std::vector<float>, float);
270287

271288
} // namespace TiledArray::math::linalg::rank_local

src/TiledArray/math/linalg/rank-local.h

+4
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ void cholesky_solve(Matrix<T> &A, Matrix<T> &X);
4141
template <typename T>
4242
void cholesky_lsolve(Op transpose, Matrix<T> &A, Matrix<T> &X);
4343

44+
/// Solves `A X = B` for dense matrices; the definition in rank-local.cpp
/// forwards both matrices to the LAPACK `gelsy` driver.
///
/// @tparam T element type of the matrices
/// @param A coefficient matrix (passed by non-const reference; its storage is
///        handed to LAPACK)
/// @param B right-hand-side matrix (passed by non-const reference; its
///        storage is handed to LAPACK)
/// @param cond threshold whose reciprocal is given to `gelsy` as `rcond`;
///        defaults to 1e8
template <typename T>
45+
void qr_solve(Matrix<T> &A, Matrix<T> &B,
46+
const TiledArray::detail::real_t<T> cond = 1e8);
47+
4448
template <typename T>
4549
void heig(Matrix<T> &A, std::vector<TiledArray::detail::real_t<T>> &W);
4650

tests/linalg.cpp

+33
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,39 @@ BOOST_AUTO_TEST_CASE(cholesky_lsolve) {
753753
GlobalFixture::world->gop.fence();
754754
}
755755

756+
// Solves the system whose coefficient array and right-hand side are the same
// reference array, then checks that the result keeps the reference tiling and
// equals the identity to within N * N * machine epsilon.
BOOST_AUTO_TEST_CASE(qr_solve) {
  GlobalFixture::world->gop.fence();

  auto tiling = gen_trange(N, {128ul});

  auto reference = TA::make_array<TA::TArray<double>>(
      *GlobalFixture::world, tiling,
      [this](TA::Tensor<double>& t, TA::Range const& range) -> double {
        return this->make_ta_reference(t, range);
      });

  auto solution = non_dist::qr_solve(reference, reference);

  BOOST_CHECK(solution.trange() == reference.trange());

  // Subtract 1 from every diagonal element held by a tile, so that an exact
  // identity would become the zero array.
  TA::foreach_inplace(solution, [](TA::Tensor<double>& tile) {
    auto range = tile.range();
    auto lo = range.lobound_data();
    auto up = range.upbound_data();
    // Diagonal elements lie where the row and column index ranges overlap.
    const auto diag_begin = lo[0] > lo[1] ? lo[0] : lo[1];
    const auto diag_end = up[0] < up[1] ? up[0] : up[1];
    for (auto d = diag_begin; d < diag_end; ++d) {
      tile(d, d) -= 1.;
    }
  });

  const double epsilon = N * N * std::numeric_limits<double>::epsilon();
  const double norm = solution("i,j").norm(*GlobalFixture::world).get();

  BOOST_CHECK_SMALL(norm, epsilon);
  GlobalFixture::world->gop.fence();
}
788+
756789
BOOST_AUTO_TEST_CASE(lu_solve) {
757790
GlobalFixture::world->gop.fence();
758791

0 commit comments

Comments
 (0)