Skip to content

Commit 272a771

Browse files
committed
unittest for qr_solve
1 parent 3a81d01 commit 272a771

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

tests/linalg.cpp

+33
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,39 @@ BOOST_AUTO_TEST_CASE(cholesky_lsolve) {
753753
GlobalFixture::world->gop.fence();
754754
}
755755

756+
BOOST_AUTO_TEST_CASE(qr_solve) {
757+
GlobalFixture::world->gop.fence();
758+
759+
auto trange = gen_trange(N, {128ul});
760+
761+
auto ref_ta = TA::make_array<TA::TArray<double>>(
762+
*GlobalFixture::world, trange,
763+
[this](TA::Tensor<double>& t, TA::Range const& range) -> double {
764+
return this->make_ta_reference(t, range);
765+
});
766+
767+
auto iden = non_dist::qr_solve(ref_ta, ref_ta);
768+
769+
BOOST_CHECK(iden.trange() == ref_ta.trange());
770+
771+
TA::foreach_inplace(iden, [](TA::Tensor<double>& tile) {
772+
auto range = tile.range();
773+
auto lo = range.lobound_data();
774+
auto up = range.upbound_data();
775+
for (auto m = lo[0]; m < up[0]; ++m)
776+
for (auto n = lo[1]; n < up[1]; ++n)
777+
if (m == n) {
778+
tile(m, n) -= 1.;
779+
}
780+
});
781+
782+
double epsilon = N * N * std::numeric_limits<double>::epsilon();
783+
double norm = iden("i,j").norm(*GlobalFixture::world).get();
784+
785+
BOOST_CHECK_SMALL(norm, epsilon);
786+
GlobalFixture::world->gop.fence();
787+
}
788+
756789
BOOST_AUTO_TEST_CASE(lu_solve) {
757790
GlobalFixture::world->gop.fence();
758791

0 commit comments

Comments
 (0)