Skip to content

Commit

Permalink
Use host view to construct multi-index
Browse files Browse the repository at this point in the history
  • Loading branch information
dannys4 committed Feb 16, 2024
1 parent fdc5f9a commit ba36559
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 23 deletions.
10 changes: 5 additions & 5 deletions MParT/MultiIndices/MultiIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include <vector>

#include <Eigen/Core>
#include <Kokkos_Core.hpp>
#include <iostream>

namespace mpart {

Expand Down Expand Up @@ -55,14 +57,12 @@ class MultiIndex {
*
* @param nzIndsIn indices of the nonzero values
* @param nzValsIn values that are nonzero
* @param numNz number of nonzero values (i.e., length of nzIndsIn and nzValsIn)
* @param lengthIn dimension of the index (numNz==lengthIn iff all values are nonzero)
*/
MultiIndex(unsigned int* nzIndsIn,
unsigned int* nzValsIn,
unsigned int numNz,
MultiIndex(Kokkos::View<unsigned int*, Kokkos::HostSpace> const& nzIndsIn,
Kokkos::View<unsigned int*, Kokkos::HostSpace> const& nzValsIn,
unsigned int lengthIn);

/** Uses a dense vector description of the multiindex, defined through a pointer,
and extracts the nonzero components.
@param[in] fullVec A pointer the memory containing the dense multiindex.
Expand Down
7 changes: 3 additions & 4 deletions src/MultiIndices/FixedMultiIndexSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,9 @@ MultiIndexSet FixedMultiIndexSet<MemorySpace>::Unfix() const
for(int term = 0; term < h_nzStarts.extent(0)-1; term++){
unsigned int start = h_nzStarts(term);
unsigned int end = h_nzStarts(term+1);
unsigned int numNz = end - start;
unsigned int* nzIndTerm = h_nzDims.data() + start;
unsigned int* nzValTerm = h_nzOrders.data() + start;
MultiIndex midx_term {nzIndTerm, nzValTerm, numNz, this->dim};
auto nzIndTerm = Kokkos::subview(h_nzDims, Kokkos::pair<unsigned int, unsigned int>(start, end));
auto nzValTerm = Kokkos::subview(h_nzOrders, Kokkos::pair<unsigned int, unsigned int>(start, end));
MultiIndex midx_term {nzIndTerm, nzValTerm, this->dim};
output.AddActive(midx_term);
}
return output;
Expand Down
31 changes: 17 additions & 14 deletions src/MultiIndices/MultiIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,21 +55,24 @@ MultiIndex::MultiIndex(std::initializer_list<unsigned int> const& indIn) : Multi
}
}

MultiIndex::MultiIndex(unsigned int* nzIndsIn,
unsigned int* nzValsIn,
unsigned int numNz,
unsigned int lengthIn) : length(lengthIn),
maxValue(0),
totalOrder(0)
{
for(unsigned int i=0; i<numNz; ++i){
if(nzValsIn[i]>0){
nzInds.push_back(nzIndsIn[i]);
nzVals.push_back(nzValsIn[i]);
maxValue = std::max<unsigned int>(maxValue, nzValsIn[i]);
totalOrder += nzValsIn[i];
MultiIndex::MultiIndex(Kokkos::View<unsigned int*, Kokkos::HostSpace> const& nzIndsIn,
Kokkos::View<unsigned int*, Kokkos::HostSpace> const& nzValsIn,
unsigned int lengthIn): length(lengthIn), maxValue(0), totalOrder(0) {
unsigned int numNz = nzIndsIn.size();
if(numNz != nzValsIn.size()){
std::stringstream ss;
ss << "MultiIndex::MultiIndex: nzIndsIn and nzValsIn must have the same number"
<< "of elements. Found " << numNz << " and " << nzValsIn.size() << " elements.";
throw std::runtime_error(ss.str().c_str());
}
for(unsigned int i=0; i<numNz; ++i){
if(nzValsIn(i)>0){
nzInds.push_back(nzIndsIn(i));
nzVals.push_back(nzValsIn(i));
maxValue = std::max<unsigned int>(maxValue, nzValsIn(i));
totalOrder += nzValsIn(i);
}
}
}
}

std::vector<unsigned int>MultiIndex::Vector() const
Expand Down

0 comments on commit ba36559

Please sign in to comment.