Skip to content

Commit

Permalink
Remove total order from separable and nonzerodiag limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
dannys4 committed Feb 15, 2024
1 parent 4bc943d commit 8a0e210
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 38 deletions.
40 changes: 15 additions & 25 deletions MParT/MultiIndices/MultiIndexLimiter.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,50 +27,40 @@ namespace MultiIndexLimiter{

};

/** @class SeparableTotalOrderLimiter
@brief Same as TotalOrder without cross-terms
@details This limter only allows terms that satisfy
\f$\|\mathbf{j}\|_1\leq p_U\f$, where \f$\mathbf{j}\f$
is the multiindex, and \f$p_U\f$ is a nonnegative integer passed to the
constructor of this class.
/** @class Separable
@brief Restricts multi-indices to refuse cross-terms
@details This limiter only allows terms that satisfy
\f$\mathbf{j}_d = 0\f$ *or* \f$\mathbf{j}_d=\|\mathbf{j}\|_1\f$,
where \f$\mathbf{j}\in\mathbb{N}_0^d\f$ is the multiindex.
*/
class SeparableTotalOrder{
class Separable{

public:

SeparableTotalOrder(unsigned int totalOrderIn) : totalOrder(totalOrderIn){};
Separable(){};

bool operator()(MultiIndex const& multi){
unsigned int sum = multi.Sum();
return (sum <= totalOrder) && (!multi.HasNonzeroEnd() || sum == multi.Get(multi.Length()-1));
return !multi.HasNonzeroEnd() || sum == multi.Get(multi.Length()-1);
};

private:
const unsigned int totalOrder;

};

/** @class NonzeroDiagTotalOrder
@brief Same as TotalOrder, except without any term that has nonzero diagonal entries
@details This limter only allows terms that satisfy
\f$\|\mathbf{j}\|_1\leq p_U\f$, where \f$\mathbf{j}\f$
is the multiindex, and \f$p_U\f$ is a nonnegative integer passed to the
constructor of this class.
/** @class NonzeroDiag
@brief Restricts acceptable to any term that has nonzero diagonal entries
@details This limiter only allows terms that satisfy
\f$\mathbf{j}_d \neq 0\f$ when the multi-index \f$\mathbf{j}\in\mathbb{N}_0^d\f$.
*/
class NonzeroDiagTotalOrder{
class NonzeroDiag{

public:

NonzeroDiagTotalOrder(unsigned int totalOrderIn) : totalOrder(totalOrderIn){};
NonzeroDiag() {};

bool operator()(MultiIndex const& multi){
unsigned int sum = multi.Sum();
return (sum <= totalOrder) && (multi.HasNonzeroEnd());
return (multi.HasNonzeroEnd());
};

private:
const unsigned int totalOrder;

};


Expand Down
4 changes: 2 additions & 2 deletions bindings/julia/src/MultiIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ void mpart::binding::MultiIndexWrapper(jlcxx::Module &mod) {
});

mod.method("CreateTotalOrder", [](unsigned int length, unsigned int maxOrder){return MultiIndexSet::CreateTotalOrder(length, maxOrder, MultiIndexLimiter::None()); });
mod.method("CreateSeparableTotalOrder", [](unsigned int length, unsigned int maxOrder){return MultiIndexSet::CreateTotalOrder(length, maxOrder, MultiIndexLimiter::SeparableTotalOrder(maxOrder)); });
mod.method("CreateNonzeroDiagTotalOrder", [](unsigned int length, unsigned int maxOrder){return MultiIndexSet::CreateTotalOrder(length, maxOrder, MultiIndexLimiter::NonzeroDiagTotalOrder(maxOrder)); });
mod.method("CreateSeparableTotalOrder", [](unsigned int length, unsigned int maxOrder){return MultiIndexSet::CreateTotalOrder(length, maxOrder, MultiIndexLimiter::Separable()); });
mod.method("CreateNonzeroDiagTotalOrder", [](unsigned int length, unsigned int maxOrder){return MultiIndexSet::CreateTotalOrder(length, maxOrder, MultiIndexLimiter::NonzeroDiag()); });

mod.set_override_module(jl_base_module);
mod.method("sum", [](MultiIndex const& idx){ return idx.Sum(); });
Expand Down
4 changes: 2 additions & 2 deletions bindings/matlab/src/MultiIndexSet_mex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ MEX_DEFINE(MultiIndexSet_newTotalOrder) (int nlhs, mxArray* plhs[],
const std::string limiter_type = input.get<std::string>(2);
MultiIndexSet::LimiterType limiter;
if(limiter_type == "separable")
limiter = MultiIndexLimiter::SeparableTotalOrder(order);
limiter = MultiIndexLimiter::Separable();
else if(limiter_type == "nonzeroDiag")
limiter = MultiIndexLimiter::NonzeroDiagTotalOrder(order);
limiter = MultiIndexLimiter::NonzeroDiag();
else
limiter = MultiIndexLimiter::None();
MultiIndexSet toCreate = MultiIndexSet::CreateTotalOrder(dim, order, limiter);
Expand Down
4 changes: 2 additions & 2 deletions bindings/python/src/MultiIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ void mpart::binding::MultiIndexWrapper(py::module &m)
.def("Size", &MultiIndexSet::Size, "Retrieves the number of elements in this MultiIndexSet")

.def_static("CreateTotalOrder", &MultiIndexSet::CreateTotalOrder, py::arg("length"), py::arg("maxOrder"), py::arg("limiter")=MultiIndexLimiter::None())
.def_static("CreateSeparableTotalOrder", [](unsigned int length, unsigned int maxOrder){return MultiIndexSet::CreateTotalOrder(length, maxOrder, MultiIndexLimiter::SeparableTotalOrder(maxOrder));}, py::arg("length"), py::arg("maxOrder"))
.def_static("CreateNonzeroDiagTotalOrder", [](unsigned int length, unsigned int maxOrder){return MultiIndexSet::CreateTotalOrder(length, maxOrder, MultiIndexLimiter::NonzeroDiagTotalOrder(maxOrder));}, py::arg("length"), py::arg("maxOrder"))
.def_static("CreateSeparableTotalOrder", [](unsigned int length, unsigned int maxOrder){return MultiIndexSet::CreateTotalOrder(length, maxOrder, MultiIndexLimiter::Separable());}, py::arg("length"), py::arg("maxOrder"))
.def_static("CreateNonzeroDiagTotalOrder", [](unsigned int length, unsigned int maxOrder){return MultiIndexSet::CreateTotalOrder(length, maxOrder, MultiIndexLimiter::NonzeroDiag());}, py::arg("length"), py::arg("maxOrder"))

.def_static("CreateTensorProduct", &MultiIndexSet::CreateTensorProduct, py::arg("length"), py::arg("maxOrder"), py::arg("limiter")=MultiIndexLimiter::None())

Expand Down
2 changes: 1 addition & 1 deletion src/MapFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ std::shared_ptr<ConditionalMapBase<MemorySpace>> CreateSigmoidExpansionTemplate(
return CreateSigmoidExpansionTemplate<MemorySpace, OffdiagEval, Rectifier, SigmoidType, EdgeType>(
mset_offdiag, mset, centers, edgeWidth);
}
MultiIndexSet mset = MultiIndexSet::CreateTotalOrder(inputDim, totalOrder, MultiIndexLimiter::NonzeroDiagTotalOrder(totalOrder));
MultiIndexSet mset = MultiIndexSet::CreateTotalOrder(inputDim, totalOrder, MultiIndexLimiter::NonzeroDiag());
FixedMultiIndexSet<MemorySpace> fmset_diag = mset.Fix(true).ToDevice<MemorySpace>();
FixedMultiIndexSet<MemorySpace> fmset_offdiag {inputDim-1, totalOrder};
return CreateSigmoidExpansionTemplate<MemorySpace, OffdiagEval, Rectifier, SigmoidType, EdgeType>(
Expand Down
6 changes: 3 additions & 3 deletions tests/MultiIndices/Test_MultiIndexSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ TEST_CASE("Testing the MultiIndexSet class", "[MultiIndexSet]" ) {

REQUIRE( mset.Size()==((maxOrder+1)*(maxOrder+2)/2));

MultiIndexSet mset_sep = MultiIndexSet::CreateTotalOrder(dim+1, maxOrder, MultiIndexLimiter::SeparableTotalOrder(maxOrder));
MultiIndexSet mset_sep = MultiIndexSet::CreateTotalOrder(dim+1, maxOrder, MultiIndexLimiter::Separable());
REQUIRE( mset_sep.Size()== ((maxOrder+1)*(maxOrder+2)/2) + maxOrder);
REQUIRE( mset_sep.NonzeroDiagonalEntries().size() == maxOrder);

Expand Down Expand Up @@ -487,7 +487,7 @@ TEST_CASE("Testing the MultiIndexSet class", "[MultiIndexSet]" ) {

SECTION("NonzeroDiagonalEntries") {
unsigned int dim = 2, maxOrder = 3;
MultiIndexSet set = MultiIndexSet::CreateTotalOrder(dim, maxOrder, MultiIndexLimiter::SeparableTotalOrder(maxOrder));
MultiIndexSet set = MultiIndexSet::CreateTotalOrder(dim, maxOrder, MultiIndexLimiter::Separable());
std::vector<unsigned int> inds = set.NonzeroDiagonalEntries();
for(unsigned int ind: inds) {
MultiIndex multi = set.at(ind);
Expand All @@ -499,7 +499,7 @@ TEST_CASE("Testing the MultiIndexSet class", "[MultiIndexSet]" ) {
expected_size += multi.HasNonzeroEnd();
}
REQUIRE( expected_size == inds.size() );
MultiIndexSet full_set = MultiIndexSet::CreateTotalOrder(dim, maxOrder, MultiIndexLimiter::NonzeroDiagTotalOrder(maxOrder));
MultiIndexSet full_set = MultiIndexSet::CreateTotalOrder(dim, maxOrder, MultiIndexLimiter::NonzeroDiag());
inds = full_set.NonzeroDiagonalEntries();
REQUIRE( inds.size() == full_set.Size() );
for(int i = 0; i < full_set.Size(); ++i) {
Expand Down
6 changes: 3 additions & 3 deletions tests/Test_RectifiedMultivariateExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ TEST_CASE("RectifiedMultivariateExpansion, Unrectified", "[RMVE_NoRect]") {
BasisEvaluator<BasisHomogeneity::Homogeneous, T> basis_eval_offdiag;
BasisEvaluator<BasisHomogeneity::OffdiagHomogeneous, Kokkos::pair<T, T>, Identity> basis_eval_diag{dim};
FixedMultiIndexSet<MemorySpace> fmset_offdiag(dim-1, maxOrder);
auto limiter = MultiIndexLimiter::NonzeroDiagTotalOrder(maxOrder);
auto limiter = MultiIndexLimiter::NonzeroDiag();
FixedMultiIndexSet<MemorySpace> fmset_diag = MultiIndexSet::CreateTotalOrder(dim, maxOrder, limiter).Fix(true);
MultivariateExpansionWorker<OffdiagEval_T, MemorySpace> worker_off(fmset_offdiag, basis_eval_offdiag);
MultivariateExpansionWorker<DiagEval_T, MemorySpace> worker_diag(fmset_diag, basis_eval_diag);
Expand Down Expand Up @@ -179,7 +179,7 @@ TEMPLATE_TEST_CASE("Single Sigmoid RectifiedMultivariateExpansion","[single_sigm
unsigned int maxOrder = 4;
FixedMultiIndexSet<MemorySpace> fmset_offdiag (dim-1, maxOrder);
unsigned int sigmoid_order = 4; // const, linear, left ET, right ET, sigmoid
auto limiter = MultiIndexLimiter::NonzeroDiagTotalOrder(sigmoid_order);
auto limiter = MultiIndexLimiter::NonzeroDiag();
MultiIndexSet mset_diag = MultiIndexSet::CreateTotalOrder(dim, sigmoid_order, limiter);
FixedMultiIndexSet<MemorySpace> fmset_diag = mset_diag.Fix(true);
// Setup expansion
Expand Down Expand Up @@ -282,7 +282,7 @@ TEMPLATE_TEST_CASE("Multiple Sigmoid RectifiedMultivariateExpansion","[multi_sig
unsigned int maxOrder = 4;
unsigned int dim = 3;
FixedMultiIndexSet<MemorySpace> fmset_offdiag(dim-1, maxOrder);
auto limiter = MultiIndexLimiter::NonzeroDiagTotalOrder(maxOrder);
auto limiter = MultiIndexLimiter::NonzeroDiag();
MultiIndexSet mset_diag = MultiIndexSet::CreateTotalOrder(dim, maxOrder, limiter);
FixedMultiIndexSet<MemorySpace> fmset_diag = mset_diag.Fix(true);

Expand Down

0 comments on commit 8a0e210

Please sign in to comment.