diff --git a/src/MultiIndices/FixedMultiIndexSet.cpp b/src/MultiIndices/FixedMultiIndexSet.cpp index 9a539366..4fe48868 100644 --- a/src/MultiIndices/FixedMultiIndexSet.cpp +++ b/src/MultiIndices/FixedMultiIndexSet.cpp @@ -244,26 +244,42 @@ std::vector FixedMultiIndexSet::IndexToMulti(un return output; } - -template -std::vector FixedMultiIndexSet::NonzeroDiagonalEntries() const -{ - assert(false); - return std::vector(); -} - -template<> -std::vector FixedMultiIndexSet::NonzeroDiagonalEntries() const -{ +std::vector CompressedNonzeroDiagonalEntries( + const Kokkos::View &nzStarts, + const Kokkos::View &nzDims, + unsigned int dim) { std::vector output; - if(!isCompressed) throw std::runtime_error("NonzeroDiagonalEntries only works for compressed multiindex sets"); for(unsigned int midx = 0; midx < nzStarts.extent(0)-1; midx++){ if(nzStarts(midx) == nzStarts(midx+1)) continue; - if(nzDims(nzStarts(midx+1)-1) == this->dim-1) output.push_back(midx); + if(nzDims(nzStarts(midx+1)-1) == dim-1) output.push_back(midx); + } + return output; +} + +std::vector UncompressedNonzeroDiagonalEntries( + const Kokkos::View &orders, + unsigned int dim) { + std::vector output; + for(unsigned int midx = 0; midx < orders.extent(0)/dim; midx++){ + bool isDiagonal = orders((midx+1)*dim-1) > 0; + if(isDiagonal) output.push_back(midx); } return output; } +template +std::vector FixedMultiIndexSet::NonzeroDiagonalEntries() const +{ + Kokkos::View h_nzStarts = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), nzStarts); + Kokkos::View h_nzDims = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), nzDims); + Kokkos::View h_nzOrders = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), nzOrders); + if(isCompressed) { + return CompressedNonzeroDiagonalEntries(h_nzStarts, h_nzDims, this->dim); + } else { + return UncompressedNonzeroDiagonalEntries(h_nzOrders, this->dim); + } +} + template int FixedMultiIndexSet::MultiToIndex(std::vector const& multi) const { diff --git a/tests/MultiIndices/Test_MultiIndexSet.cpp b/tests/MultiIndices/Test_MultiIndexSet.cpp index ac5ec68d..7ed8472a 100644 --- a/tests/MultiIndices/Test_MultiIndexSet.cpp +++ b/tests/MultiIndices/Test_MultiIndexSet.cpp @@ -499,6 +499,15 @@ TEST_CASE("Testing the MultiIndexSet class", "[MultiIndexSet]" ) { expected_size += multi.HasNonzeroEnd(); } REQUIRE( expected_size == inds.size() ); + FixedMultiIndexSet fixedSet = set.Fix(true); + FixedMultiIndexSet fixedSet2 = set.Fix(false); + std::vector inds_fixed = fixedSet.NonzeroDiagonalEntries(); + std::vector inds_fixed2 = fixedSet2.NonzeroDiagonalEntries(); + std::sort(inds.begin(), inds.end()); + std::sort(inds_fixed.begin(), inds_fixed.end()); + std::sort(inds_fixed2.begin(), inds_fixed2.end()); + REQUIRE( inds == inds_fixed ); + REQUIRE( inds == inds_fixed2 ); MultiIndexSet full_set = MultiIndexSet::CreateTotalOrder(dim, maxOrder, MultiIndexLimiter::NonzeroDiagTotalOrderLimiter(maxOrder)); inds = full_set.NonzeroDiagonalEntries(); REQUIRE( inds.size() == full_set.Size() );