From ba36559d8923700fa38ebf23107f8867971870cd Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Fri, 16 Feb 2024 13:08:51 -0500 Subject: [PATCH] Use host view to construct multi-index --- MParT/MultiIndices/MultiIndex.h | 10 ++++---- src/MultiIndices/FixedMultiIndexSet.cpp | 7 +++--- src/MultiIndices/MultiIndex.cpp | 31 ++++++++++++++----------- 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/MParT/MultiIndices/MultiIndex.h b/MParT/MultiIndices/MultiIndex.h index 2633033c..c8d3b336 100644 --- a/MParT/MultiIndices/MultiIndex.h +++ b/MParT/MultiIndices/MultiIndex.h @@ -6,6 +6,8 @@ #include #include +#include +#include namespace mpart { @@ -55,14 +57,12 @@ class MultiIndex { * * @param nzIndsIn indices of the nonzero values * @param nzValsIn values that are nonzero - * @param numNz number of nonzero values (i.e., length of nzIndsIn and nzValsIn) * @param lengthIn dimension of the index (numNz==lengthIn iff all values are nonzero) */ - MultiIndex(unsigned int* nzIndsIn, - unsigned int* nzValsIn, - unsigned int numNz, + MultiIndex(Kokkos::View const& nzIndsIn, + Kokkos::View const& nzValsIn, unsigned int lengthIn); - + /** Uses a dense vector description of the multiindex, defined through a pointer, and extracts the nonzero components. @param[in] fullVec A pointer the memory containing the dense multiindex. diff --git a/src/MultiIndices/FixedMultiIndexSet.cpp b/src/MultiIndices/FixedMultiIndexSet.cpp index 4fe48868..a96cb3c1 100644 --- a/src/MultiIndices/FixedMultiIndexSet.cpp +++ b/src/MultiIndices/FixedMultiIndexSet.cpp @@ -347,10 +347,9 @@ MultiIndexSet FixedMultiIndexSet::Unfix() const for(int term = 0; term < h_nzStarts.extent(0)-1; term++){ unsigned int start = h_nzStarts(term); unsigned int end = h_nzStarts(term+1); - unsigned int numNz = end - start; - unsigned int* nzIndTerm = h_nzDims.data() + start; - unsigned int* nzValTerm = h_nzOrders.data() + start; - MultiIndex midx_term {nzIndTerm, nzValTerm, numNz, this->dim}; + auto nzIndTerm = Kokkos::subview(h_nzDims, Kokkos::pair(start, end)); + auto nzValTerm = Kokkos::subview(h_nzOrders, Kokkos::pair(start, end)); + MultiIndex midx_term {nzIndTerm, nzValTerm, this->dim}; output.AddActive(midx_term); } return output; diff --git a/src/MultiIndices/MultiIndex.cpp b/src/MultiIndices/MultiIndex.cpp index 867fd4cd..f0fa3f97 100644 --- a/src/MultiIndices/MultiIndex.cpp +++ b/src/MultiIndices/MultiIndex.cpp @@ -55,21 +55,24 @@ MultiIndex::MultiIndex(std::initializer_list const& indIn) : Multi } } -MultiIndex::MultiIndex(unsigned int* nzIndsIn, - unsigned int* nzValsIn, - unsigned int numNz, - unsigned int lengthIn) : length(lengthIn), - maxValue(0), - totalOrder(0) -{ - for(unsigned int i=0; i0){ - nzInds.push_back(nzIndsIn[i]); - nzVals.push_back(nzValsIn[i]); - maxValue = std::max(maxValue, nzValsIn[i]); - totalOrder += nzValsIn[i]; +MultiIndex::MultiIndex(Kokkos::View const& nzIndsIn, + Kokkos::View const& nzValsIn, + unsigned int lengthIn): length(lengthIn), maxValue(0), totalOrder(0) { + unsigned int numNz = nzIndsIn.size(); + if(numNz != nzValsIn.size()){ + std::stringstream ss; + ss << "MultiIndex::MultiIndex: nzIndsIn and nzValsIn must have the same number" + << "of elements. Found " << numNz << " and " << nzValsIn.size() << " elements."; + throw std::runtime_error(ss.str().c_str()); + } + for(unsigned int i=0; i0){ + nzInds.push_back(nzIndsIn(i)); + nzVals.push_back(nzValsIn(i)); + maxValue = std::max(maxValue, nzValsIn(i)); + totalOrder += nzValsIn(i); + } } - } } std::vectorMultiIndex::Vector() const