Skip to content

Commit

Permalink
Add facilities for NonzeroDiagonalEntries of fmset on device and unco…
Browse files Browse the repository at this point in the history
…mpressed
  • Loading branch information
dannys4 committed Feb 16, 2024
1 parent e815a31 commit fdc5f9a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 13 deletions.
42 changes: 29 additions & 13 deletions src/MultiIndices/FixedMultiIndexSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,26 +244,42 @@ std::vector<unsigned int> FixedMultiIndexSet<Kokkos::HostSpace>::IndexToMulti(un
return output;
}


template<typename MemorySpace>
std::vector<unsigned int> FixedMultiIndexSet<MemorySpace>::NonzeroDiagonalEntries() const
{
assert(false);
return std::vector<unsigned int>();
}

template<>
std::vector<unsigned int> FixedMultiIndexSet<Kokkos::HostSpace>::NonzeroDiagonalEntries() const
{
std::vector<unsigned int> CompressedNonzeroDiagonalEntries(
const Kokkos::View<unsigned int*, Kokkos::HostSpace> &nzStarts,
const Kokkos::View<unsigned int*, Kokkos::HostSpace> &nzDims,
unsigned int dim) {
std::vector<unsigned int> output;
if(!isCompressed) throw std::runtime_error("NonzeroDiagonalEntries only works for compressed multiindex sets");
for(unsigned int midx = 0; midx < nzStarts.extent(0)-1; midx++){
if(nzStarts(midx) == nzStarts(midx+1)) continue;
if(nzDims(nzStarts(midx+1)-1) == this->dim-1) output.push_back(midx);
if(nzDims(nzStarts(midx+1)-1) == dim-1) output.push_back(midx);
}
return output;
}

std::vector<unsigned int> UncompressedNonzeroDiagonalEntries(
const Kokkos::View<unsigned int*, Kokkos::HostSpace> &orders,
unsigned int dim) {
std::vector<unsigned int> output;
for(unsigned int midx = 0; midx < orders.extent(0)/dim; midx++){
bool isDiagonal = orders((midx+1)*dim-1) > 0;
if(isDiagonal) output.push_back(midx);
}
return output;
}

template<typename MemorySpace>
std::vector<unsigned int> FixedMultiIndexSet<MemorySpace>::NonzeroDiagonalEntries() const
{
Kokkos::View<unsigned int*, Kokkos::HostSpace> h_nzStarts = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), nzStarts);
Kokkos::View<unsigned int*, Kokkos::HostSpace> h_nzDims = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), nzDims);
Kokkos::View<unsigned int*, Kokkos::HostSpace> h_nzOrders = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), nzOrders);
if(isCompressed) {
return CompressedNonzeroDiagonalEntries(h_nzStarts, h_nzDims, this->dim);
} else {
return UncompressedNonzeroDiagonalEntries(h_nzOrders, this->dim);
}
}

template<typename MemorySpace>
int FixedMultiIndexSet<MemorySpace>::MultiToIndex(std::vector<unsigned int> const& multi) const
{
Expand Down
9 changes: 9 additions & 0 deletions tests/MultiIndices/Test_MultiIndexSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,15 @@ TEST_CASE("Testing the MultiIndexSet class", "[MultiIndexSet]" ) {
expected_size += multi.HasNonzeroEnd();
}
REQUIRE( expected_size == inds.size() );
FixedMultiIndexSet<Kokkos::HostSpace> fixedSet = set.Fix(true);
FixedMultiIndexSet<Kokkos::HostSpace> fixedSet2 = set.Fix(false);
std::vector<unsigned int> inds_fixed = fixedSet.NonzeroDiagonalEntries();
std::vector<unsigned int> inds_fixed2 = fixedSet2.NonzeroDiagonalEntries();
std::sort(inds.begin(), inds.end());
std::sort(inds_fixed.begin(), inds_fixed.end());
std::sort(inds_fixed2.begin(), inds_fixed2.end());
REQUIRE( inds == inds_fixed );
REQUIRE( inds == inds_fixed2 );
MultiIndexSet full_set = MultiIndexSet::CreateTotalOrder(dim, maxOrder, MultiIndexLimiter::NonzeroDiagTotalOrderLimiter(maxOrder));
inds = full_set.NonzeroDiagonalEntries();
REQUIRE( inds.size() == full_set.Size() );
Expand Down

0 comments on commit fdc5f9a

Please sign in to comment.