diff --git a/MParT/MultiIndices/MultiIndexLimiter.h b/MParT/MultiIndices/MultiIndexLimiter.h index 67e9ce22..959dccae 100644 --- a/MParT/MultiIndices/MultiIndexLimiter.h +++ b/MParT/MultiIndices/MultiIndexLimiter.h @@ -27,50 +27,40 @@ namespace MultiIndexLimiter{ }; - /** @class SeparableTotalOrderLimiter - @brief Same as TotalOrder without cross-terms - @details This limter only allows terms that satisfy - \f$\|\mathbf{j}\|_1\leq p_U\f$, where \f$\mathbf{j}\f$ - is the multiindex, and \f$p_U\f$ is a nonnegative integer passed to the - constructor of this class. + /** @class Separable + @brief Restricts multi-indices to refuse cross-terms + @details This limiter only allows terms that satisfy + \f$\mathbf{j}_d = 0\f$ *or* \f$\mathbf{j}_d=\|\mathbf{j}\|_1\f$, + where \f$\mathbf{j}\in\mathbb{N}_0^d\f$ is the multiindex. */ - class SeparableTotalOrder{ + class Separable{ public: - SeparableTotalOrder(unsigned int totalOrderIn) : totalOrder(totalOrderIn){}; + Separable(){}; bool operator()(MultiIndex const& multi){ unsigned int sum = multi.Sum(); - return (sum <= totalOrder) && (!multi.HasNonzeroEnd() || sum == multi.Get(multi.Length()-1)); + return !multi.HasNonzeroEnd() || sum == multi.Get(multi.Length()-1); }; - private: - const unsigned int totalOrder; - }; - /** @class NonzeroDiagTotalOrder - @brief Same as TotalOrder, except without any term that has nonzero diagonal entries - @details This limter only allows terms that satisfy - \f$\|\mathbf{j}\|_1\leq p_U\f$, where \f$\mathbf{j}\f$ - is the multiindex, and \f$p_U\f$ is a nonnegative integer passed to the - constructor of this class. + /** @class NonzeroDiag + @brief Restricts acceptable to any term that has nonzero diagonal entries + @details This limiter only allows terms that satisfy + \f$\mathbf{j}_d \neq 0\f$ when the multi-index \f$\mathbf{j}\in\mathbb{N}_0^d\f$. */ - class NonzeroDiagTotalOrder{ + class NonzeroDiag{ public: - NonzeroDiagTotalOrder(unsigned int totalOrderIn) : totalOrder(totalOrderIn){}; + NonzeroDiag() {}; bool operator()(MultiIndex const& multi){ - unsigned int sum = multi.Sum(); - return (sum <= totalOrder) && (multi.HasNonzeroEnd()); + return (multi.HasNonzeroEnd()); }; - private: - const unsigned int totalOrder; - }; diff --git a/bindings/julia/src/MultiIndex.cpp b/bindings/julia/src/MultiIndex.cpp index bb432569..c9bc4f4d 100644 --- a/bindings/julia/src/MultiIndex.cpp +++ b/bindings/julia/src/MultiIndex.cpp @@ -77,8 +77,8 @@ void mpart::binding::MultiIndexWrapper(jlcxx::Module &mod) { }); mod.method("CreateTotalOrder", [](unsigned int length, unsigned int maxOrder){return MultiIndexSet::CreateTotalOrder(length, maxOrder, MultiIndexLimiter::None()); }); - mod.method("CreateSeparableTotalOrder", [](unsigned int length, unsigned int maxOrder){return MultiIndexSet::CreateTotalOrder(length, maxOrder, MultiIndexLimiter::SeparableTotalOrder(maxOrder)); }); - mod.method("CreateNonzeroDiagTotalOrder", [](unsigned int length, unsigned int maxOrder){return MultiIndexSet::CreateTotalOrder(length, maxOrder, MultiIndexLimiter::NonzeroDiagTotalOrder(maxOrder)); }); + mod.method("CreateSeparableTotalOrder", [](unsigned int length, unsigned int maxOrder){return MultiIndexSet::CreateTotalOrder(length, maxOrder, MultiIndexLimiter::Separable()); }); + mod.method("CreateNonzeroDiagTotalOrder", [](unsigned int length, unsigned int maxOrder){return MultiIndexSet::CreateTotalOrder(length, maxOrder, MultiIndexLimiter::NonzeroDiag()); }); mod.set_override_module(jl_base_module); mod.method("sum", [](MultiIndex const& idx){ return idx.Sum(); }); diff --git a/bindings/matlab/src/MultiIndexSet_mex.cpp b/bindings/matlab/src/MultiIndexSet_mex.cpp index 5d4787b5..02652352 100644 --- a/bindings/matlab/src/MultiIndexSet_mex.cpp +++ b/bindings/matlab/src/MultiIndexSet_mex.cpp @@ -37,9 +37,9 @@ MEX_DEFINE(MultiIndexSet_newTotalOrder) (int nlhs, mxArray* plhs[], const std::string limiter_type = input.get(2); MultiIndexSet::LimiterType limiter; if(limiter_type == "separable") - limiter = MultiIndexLimiter::SeparableTotalOrder(order); + limiter = MultiIndexLimiter::Separable(); else if(limiter_type == "nonzeroDiag") - limiter = MultiIndexLimiter::NonzeroDiagTotalOrder(order); + limiter = MultiIndexLimiter::NonzeroDiag(); else limiter = MultiIndexLimiter::None(); MultiIndexSet toCreate = MultiIndexSet::CreateTotalOrder(dim, order, limiter); diff --git a/bindings/python/src/MultiIndex.cpp b/bindings/python/src/MultiIndex.cpp index 0cb77fd0..ff34b6db 100644 --- a/bindings/python/src/MultiIndex.cpp +++ b/bindings/python/src/MultiIndex.cpp @@ -120,8 +120,8 @@ void mpart::binding::MultiIndexWrapper(py::module &m) .def("Size", &MultiIndexSet::Size, "Retrieves the number of elements in this MultiIndexSet") .def_static("CreateTotalOrder", &MultiIndexSet::CreateTotalOrder, py::arg("length"), py::arg("maxOrder"), py::arg("limiter")=MultiIndexLimiter::None()) - .def_static("CreateSeparableTotalOrder", [](unsigned int length, unsigned int maxOrder){return MultiIndexSet::CreateTotalOrder(length, maxOrder, MultiIndexLimiter::SeparableTotalOrder(maxOrder));}, py::arg("length"), py::arg("maxOrder")) - .def_static("CreateNonzeroDiagTotalOrder", [](unsigned int length, unsigned int maxOrder){return MultiIndexSet::CreateTotalOrder(length, maxOrder, MultiIndexLimiter::NonzeroDiagTotalOrder(maxOrder));}, py::arg("length"), py::arg("maxOrder")) + .def_static("CreateSeparableTotalOrder", [](unsigned int length, unsigned int maxOrder){return MultiIndexSet::CreateTotalOrder(length, maxOrder, MultiIndexLimiter::Separable());}, py::arg("length"), py::arg("maxOrder")) + .def_static("CreateNonzeroDiagTotalOrder", [](unsigned int length, unsigned int maxOrder){return MultiIndexSet::CreateTotalOrder(length, maxOrder, MultiIndexLimiter::NonzeroDiag());}, py::arg("length"), py::arg("maxOrder")) .def_static("CreateTensorProduct", &MultiIndexSet::CreateTensorProduct, py::arg("length"), py::arg("maxOrder"), py::arg("limiter")=MultiIndexLimiter::None()) diff --git a/src/MapFactory.cpp b/src/MapFactory.cpp index c4eca6f5..d810933f 100644 --- a/src/MapFactory.cpp +++ b/src/MapFactory.cpp @@ -269,7 +269,7 @@ std::shared_ptr> CreateSigmoidExpansionTemplate( return CreateSigmoidExpansionTemplate( mset_offdiag, mset, centers, edgeWidth); } - MultiIndexSet mset = MultiIndexSet::CreateTotalOrder(inputDim, totalOrder, MultiIndexLimiter::NonzeroDiagTotalOrder(totalOrder)); + MultiIndexSet mset = MultiIndexSet::CreateTotalOrder(inputDim, totalOrder, MultiIndexLimiter::NonzeroDiag()); FixedMultiIndexSet fmset_diag = mset.Fix(true).ToDevice(); FixedMultiIndexSet fmset_offdiag {inputDim-1, totalOrder}; return CreateSigmoidExpansionTemplate( diff --git a/tests/MultiIndices/Test_MultiIndexSet.cpp b/tests/MultiIndices/Test_MultiIndexSet.cpp index e06a7cac..d4a76627 100644 --- a/tests/MultiIndices/Test_MultiIndexSet.cpp +++ b/tests/MultiIndices/Test_MultiIndexSet.cpp @@ -161,7 +161,7 @@ TEST_CASE("Testing the MultiIndexSet class", "[MultiIndexSet]" ) { REQUIRE( mset.Size()==((maxOrder+1)*(maxOrder+2)/2)); - MultiIndexSet mset_sep = MultiIndexSet::CreateTotalOrder(dim+1, maxOrder, MultiIndexLimiter::SeparableTotalOrder(maxOrder)); + MultiIndexSet mset_sep = MultiIndexSet::CreateTotalOrder(dim+1, maxOrder, MultiIndexLimiter::Separable()); REQUIRE( mset_sep.Size()== ((maxOrder+1)*(maxOrder+2)/2) + maxOrder); REQUIRE( mset_sep.NonzeroDiagonalEntries().size() == maxOrder); @@ -487,7 +487,7 @@ TEST_CASE("Testing the MultiIndexSet class", "[MultiIndexSet]" ) { SECTION("NonzeroDiagonalEntries") { unsigned int dim = 2, maxOrder = 3; - MultiIndexSet set = MultiIndexSet::CreateTotalOrder(dim, maxOrder, MultiIndexLimiter::SeparableTotalOrder(maxOrder)); + MultiIndexSet set = MultiIndexSet::CreateTotalOrder(dim, maxOrder, MultiIndexLimiter::Separable()); std::vector inds = set.NonzeroDiagonalEntries(); for(unsigned int ind: inds) { MultiIndex multi = set.at(ind); @@ -499,7 +499,7 @@ TEST_CASE("Testing the MultiIndexSet class", "[MultiIndexSet]" ) { expected_size += multi.HasNonzeroEnd(); } REQUIRE( expected_size == inds.size() ); - MultiIndexSet full_set = MultiIndexSet::CreateTotalOrder(dim, maxOrder, MultiIndexLimiter::NonzeroDiagTotalOrder(maxOrder)); + MultiIndexSet full_set = MultiIndexSet::CreateTotalOrder(dim, maxOrder, MultiIndexLimiter::NonzeroDiag()); inds = full_set.NonzeroDiagonalEntries(); REQUIRE( inds.size() == full_set.Size() ); for(int i = 0; i < full_set.Size(); ++i) { diff --git a/tests/Test_RectifiedMultivariateExpansion.cpp b/tests/Test_RectifiedMultivariateExpansion.cpp index bada2382..ec1651d6 100644 --- a/tests/Test_RectifiedMultivariateExpansion.cpp +++ b/tests/Test_RectifiedMultivariateExpansion.cpp @@ -22,7 +22,7 @@ TEST_CASE("RectifiedMultivariateExpansion, Unrectified", "[RMVE_NoRect]") { BasisEvaluator basis_eval_offdiag; BasisEvaluator, Identity> basis_eval_diag{dim}; FixedMultiIndexSet fmset_offdiag(dim-1, maxOrder); - auto limiter = MultiIndexLimiter::NonzeroDiagTotalOrder(maxOrder); + auto limiter = MultiIndexLimiter::NonzeroDiag(); FixedMultiIndexSet fmset_diag = MultiIndexSet::CreateTotalOrder(dim, maxOrder, limiter).Fix(true); MultivariateExpansionWorker worker_off(fmset_offdiag, basis_eval_offdiag); MultivariateExpansionWorker worker_diag(fmset_diag, basis_eval_diag); @@ -179,7 +179,7 @@ TEMPLATE_TEST_CASE("Single Sigmoid RectifiedMultivariateExpansion","[single_sigm unsigned int maxOrder = 4; FixedMultiIndexSet fmset_offdiag (dim-1, maxOrder); unsigned int sigmoid_order = 4; // const, linear, left ET, right ET, sigmoid - auto limiter = MultiIndexLimiter::NonzeroDiagTotalOrder(sigmoid_order); + auto limiter = MultiIndexLimiter::NonzeroDiag(); MultiIndexSet mset_diag = MultiIndexSet::CreateTotalOrder(dim, sigmoid_order, limiter); FixedMultiIndexSet fmset_diag = mset_diag.Fix(true); // Setup expansion @@ -282,7 +282,7 @@ TEMPLATE_TEST_CASE("Multiple Sigmoid RectifiedMultivariateExpansion","[multi_sig unsigned int maxOrder = 4; unsigned int dim = 3; FixedMultiIndexSet fmset_offdiag(dim-1, maxOrder); - auto limiter = MultiIndexLimiter::NonzeroDiagTotalOrder(maxOrder); + auto limiter = MultiIndexLimiter::NonzeroDiag(); MultiIndexSet mset_diag = MultiIndexSet::CreateTotalOrder(dim, maxOrder, limiter); FixedMultiIndexSet fmset_diag = mset_diag.Fix(true);