Skip to content

Commit

Permalink
Merge branch 'main' into sort-multis
Browse files Browse the repository at this point in the history
  • Loading branch information
mparno authored Jan 31, 2024
2 parents 07966b0 + 6a343af commit 898034e
Show file tree
Hide file tree
Showing 7 changed files with 256 additions and 110 deletions.
55 changes: 45 additions & 10 deletions MParT/BasisEvaluator.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <Kokkos_Core.hpp>
#include "MParT/PositiveBijectors.h"

namespace mpart {
/**
Expand All @@ -7,27 +8,52 @@ namespace mpart {
*
* Assuming you want create a function
* \f$f_{\vec{\alpha}}:\mathbb{R}^{d+1}\to\mathbb{R}\f$ which has multi-index
* \f$\vec{\alpha}\f$, there are three possibilities:
* \f$\vec{\alpha}\f$, there are a few possibilities:
*
* - All diagonal and offdiagonal univariate functions are identical, i.e.
* \f$f(x_1,\ldots,x_d,y)=\psi_{\alpha_{d+1}}(y)\prod_{j=1}^d\psi_{\alpha_j}(x_j)\f$,
* which is `Homogeneous`
* - All offdiagonal univariate functions are identical, i.e.
* \f$f(x_1,\ldots,x_d,y)=\psi^{diag}_{\alpha_{d+1}}(y)\prod_{j=1}^d\psi^{offdiag}_{\alpha_j}(x_j)\f$,
* which is `OffdiagHomogeneous`
* which is `OffdiagHomogeneous`.
* - All univariate basis functions may be different, i.e.
* \f$f(x_1,\ldots,x_d,y)=\psi^{d+1}_{\alpha_{d+1}}(y)\prod_{j=1}^d\psi^{j}_{\alpha_j}(x_j)\f$
* which is `Heterogeneous`
*/
enum BasisHomogeneity { Homogeneous, OffdiagHomogeneous, Heterogeneous };

/**
* @brief Defines the identity function \f$g(x) = x\f$.
*/
class Identity{
public:

KOKKOS_INLINE_FUNCTION static double Evaluate(double x){
//stable implementation of std::log(1.0 + std::exp(x)) for large values
return x;
}

KOKKOS_INLINE_FUNCTION static double Derivative(double x){
return 1.0;
}

KOKKOS_INLINE_FUNCTION static double SecondDerivative(double x){
return 0.;
}

KOKKOS_INLINE_FUNCTION static double Inverse(double x){
return x;
}

};

/**
* @brief Class to represent all elements of a multivariate function basis
*
* See BasisHomogeneity for information on options for \c HowHomogeneous .
* The form of template parameter \c BasisEvaluatorType will depend on \c
* HowHomogeneous See the documentation of each implementation for details on
* what's necessary.
* what's necessary. We give the option to "rectify" the function if it depends on y (i.e. make the part dependent on x positive)
*
* Any univariate basis function used here must have the following functions:
*
Expand All @@ -42,8 +68,9 @@ enum BasisHomogeneity { Homogeneous, OffdiagHomogeneous, Heterogeneous };
* BasisHomogeneity for more info)
* @tparam BasisEvaluatorType The type we need to evaluate when evaluating the
* basis
* @tparam RectifierType The rectification operator for diag/offdiag cross-terms
*/
template <BasisHomogeneity HowHomogeneous, typename BasisEvaluatorType>
template <BasisHomogeneity HowHomogeneous, typename BasisEvaluatorType, typename RectifierType=Identity>
class BasisEvaluator {
public:
/**
Expand All @@ -68,7 +95,6 @@ class BasisEvaluator {
int maxOrder, double point) const {
assert(false);
}
// EvaluateDerivatives(dim, output_eval, output_deriv, max_order, input)

/**
* @brief Evaluate the functions for the multivariate basis
Expand Down Expand Up @@ -116,6 +142,14 @@ class BasisEvaluator {
#endif
};

template<typename T>
struct GetRectifier{};

template<BasisHomogeneity HowHomogeneous, typename BasisEvaluatorType, typename RectifierType>
struct GetRectifier<BasisEvaluator<HowHomogeneous, BasisEvaluatorType, RectifierType>>{
using type = RectifierType;
};

/**
* @brief Basis evaluator when all univariate basis fcns are identical
*
Expand Down Expand Up @@ -179,10 +213,12 @@ class BasisEvaluator<BasisHomogeneity::Homogeneous, BasisEvaluatorType> {
*
* @tparam OffdiagEvaluatorType Type to eval offdiagonal univariate basis
* @tparam DiagEvaluatorType Type to eval diagonal univariate basis
* @tparam RectifyType Whether the basis is rectified or not (i.e. positivize the non-diagonal part of cross terms)
*/
template <typename OffdiagEvaluatorType, typename DiagEvaluatorType>
template <typename OffdiagEvaluatorType, typename DiagEvaluatorType, typename Rectifier>
class BasisEvaluator<BasisHomogeneity::OffdiagHomogeneous,
Kokkos::pair<OffdiagEvaluatorType, DiagEvaluatorType>> {
Kokkos::pair<OffdiagEvaluatorType, DiagEvaluatorType>,
Rectifier> {
public:
BasisEvaluator(
int dim,
Expand Down Expand Up @@ -242,7 +278,6 @@ class BasisEvaluator<BasisHomogeneity::OffdiagHomogeneous,
#endif
/// @brief Number of input dimensions for multivariate basis
int dim_;

OffdiagEvaluatorType offdiag_;
DiagEvaluatorType diag_;
};
Expand All @@ -252,9 +287,9 @@ class BasisEvaluator<BasisHomogeneity::OffdiagHomogeneous,
*
* @tparam CommonBasisEvaluatorType Supertype to univariate basis eval types
*/
template <typename CommonBasisEvaluatorType>
template <typename CommonBasisEvaluatorType, typename Rectifier>
class BasisEvaluator<BasisHomogeneity::Heterogeneous,
std::vector<std::shared_ptr<CommonBasisEvaluatorType>>> {
std::vector<std::shared_ptr<CommonBasisEvaluatorType>>, Rectifier> {
public:
BasisEvaluator(
int dim,
Expand Down
Loading

0 comments on commit 898034e

Please sign in to comment.