Skip to content

Commit

Permalink
Merge branch 'mparno/inverse-bound-fix' of github.com:MeasureTranspor…
Browse files Browse the repository at this point in the history
…t/MParT into mparno/python-serialization
  • Loading branch information
Matthew Parno committed Oct 18, 2023
2 parents 3ca0338 + 2809871 commit 737592c
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 169 deletions.
144 changes: 12 additions & 132 deletions MParT/MonotoneComponent.h
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ class MonotoneComponent : public ConditionalMapBase<MemorySpace>
auto functor = KOKKOS_CLASS_LAMBDA (typename Kokkos::TeamPolicy<ExecutionSpace>::member_type team_member) {

unsigned int ptInd = team_member.league_rank () * team_member.team_size () + team_member.team_rank ();
int info;

if(ptInd<numPts){
unsigned int xInd = ptInd;
Expand All @@ -410,15 +411,23 @@ class MonotoneComponent : public ConditionalMapBase<MemorySpace>

// Create a subview containing x_{1:d-1}
auto pt = Kokkos::subview(xs, Kokkos::ALL(), xInd);

// Check for NaNs. If found, set output to nan and return
for(unsigned int ii=0; ii<pt.size(); ++ii){
if(std::isnan(pt(ii))){
output(ptInd) = std::numeric_limits<double>::quiet_NaN();
return;
}
}

// Fill in the cache with everything that doesn't depend on x_d
Kokkos::View<double*,MemorySpace> cache(team_member.thread_scratch(1), cacheSize);
expansion_.FillCache1(cache.data(), pt, DerivativeFlags::None);

// Compute the inverse
Kokkos::View<double*,MemorySpace> workspace(team_member.thread_scratch(1), workspaceSize);
auto eval = SingleEvaluator<decltype(pt), decltype(coeffs)>(workspace.data(), cache.data(), pt, coeffs, quad_, expansion_);
output(ptInd) = RootFinding::InverseSingleBracket<MemorySpace>(ys(ptInd), eval, pt(pt.extent(0)-1), xtol, ytol);
auto eval = SingleEvaluator<decltype(pt),decltype(coeffs)>(workspace.data(), cache.data(), pt, coeffs, quad_, expansion_);
output(ptInd) = RootFinding::InverseSingleBracket<MemorySpace>(ys(ptInd), eval, pt(pt.extent(0)-1), xtol, ytol, info);
}
};

Expand Down Expand Up @@ -995,141 +1004,12 @@ class MonotoneComponent : public ConditionalMapBase<MemorySpace>

quad.Integrate(workspace, integrand, 0, 1, &output);

// Finish filling in the cache for an evaluation of the expansion with x_d=0
expansion.FillCache2(cache, pt, 0.0, DerivativeFlags::None);
output += expansion.Evaluate(cache, coeffs);

return output;
}


/**
@brief Solves \f$y_D = T(x_1,\ldots,x_D)\f$ for \f$x_D\f$ at a single point using the ITP bracketing method.
@details Uses the [Interpolate Truncate and Project (ITP)](https://en.wikipedia.org/wiki/ITP_method) method to solve
\f$y_D = T(x_1,\ldots,x_D)\f$ for \f$x_D\f$. This method is a bracketing method similar to bisection, but
with faster convergence. The worst case performance of ITP, in terms of required evaluations of \f$T\f$,
is the same as bisection.
@param cache Memory set up by Kokkos for this evaluation. The `expansion_.FillCache1` function must be
called before calling this function.
@param pt Array of length \f$D\f$ containing the fixed values of \f$x_{1:D-1}\f$ and an initial guess for \f$x_D\f$ in the last component.
@param coeffs An array of coefficients for the expansion.
@param options
@return The value of \f$x_D\f$ solving \f$y_D = T(x_1,\ldots,x_D)\f$
@tparam PointType The type of the point. Typically either Kokkos::View<double*> or some subview with a similar 1d signature.
@tparam CoeffsType The type of the coefficients. Typically Kokkos::View<double*> or a similarly structured subview.
*/
template<typename PointType, typename CoeffsType>
KOKKOS_FUNCTION static double InverseSingleBracket(double* cache,
double* workspace,
PointType const& pt,
double yd,
CoeffsType const& coeffs,
const double xtol,
const double ftol,
QuadratureType const& quad,
ExpansionType const& expansion)
{
double stepSize=1.0;
const unsigned int maxIts = 10000;

// First, we need to find two points that bound the solution.
double xlb, xub;
double ylb, yub;
double xb, xf; // Bisection point and regula falsi point
double xc, yc;

xlb = pt(pt.extent(0)-1);
ylb = EvaluateSingle(workspace, cache, pt, xlb, coeffs, quad, expansion);

// We actually found an upper bound...
if(ylb>yd){

mpart::simple_swap(ylb,yub);
mpart::simple_swap(xlb,xub);

// Now find a lower bound...
unsigned int i;
for(i=0; i<maxIts; ++i){ // Could just be while(true), but want to avoid infinite loop
xlb = xub-stepSize;
ylb = EvaluateSingle(workspace, cache, pt, xlb, coeffs, quad, expansion);
if(ylb>yd){
mpart::simple_swap(ylb,yub);
mpart::simple_swap(xlb,xub);
stepSize *= 2.0;
}else{
break;
}
}
if(i>maxIts)
ProcAgnosticError<MemorySpace, std::runtime_error>::error("InverseSingleBracket: lower bound iterations exceed maxIts");

// We have a lower bound...
}else{
// Now find an upper bound...
unsigned int i;
for(i=0; i<maxIts; ++i){ // Could just be while(true), but want to avoid infinite loop
xub = xlb+stepSize;
yub = EvaluateSingle(workspace, cache, pt, xub, coeffs, quad, expansion);
if(yub<yd){
mpart::simple_swap(ylb,yub);
mpart::simple_swap(xlb,xub);
stepSize *= 2.0;
}else{
break;
}
}
if(i>maxIts)
ProcAgnosticError<MemorySpace,std::runtime_error>::error("InverseSingleBracket: upper bound calculation exceeds maxIts");
}

assert(ylb<yub);
assert(xlb<xub);

// Bracketed search
const double k1 = 0.1;
const double k2 = 2.0;
const double nhalf = ceil(log2(0.5*(xub-xlb)/xtol));
const double n0 = 1.0;

double sigma, delta, rho;
unsigned int it;
for(it=0; it<maxIts; ++it){

xb = 0.5*(xub+xlb); // bisection point
xf = xlb - (yd-ylb)*(xub-xlb) / (yub-ylb); // regula-falsi point

sigma = ((xb-xf)>0)?1.0:-1.0; // sign(xb-xf)
delta = fmin(k1*pow((xub-xlb), k2), fabs(xb-xf));

xf += delta*sigma;

rho = fmin(xtol*pow(2.0, nhalf + n0 - it) - 0.5*(xub-xlb), fabs(xf - xb));
xc = xb - sigma*rho;

yc = EvaluateSingle(workspace, cache, pt, xc, coeffs, quad, expansion);

if(abs(yc-yd)<ftol){
return xc;
}else if(yc>yd){
mpart::simple_swap(yc,yub);
mpart::simple_swap(xc,xub);
}else{
mpart::simple_swap(yc,ylb);
mpart::simple_swap(xc,xlb);
}

// Check for convergence
if(((xub-xlb)<xtol)||((yub-ylb)<ftol))
break;
};

if(it>maxIts)
ProcAgnosticError<MemorySpace,std::runtime_error>::error("InverseSingleBracket: Bracket search iterations exceeds maxIts");
return 0.5*(xub+xlb);
}

/** Give access to the underlying FixedMultiIndexSet
* @return The FixedMultiIndexSet
*/
Expand Down
5 changes: 4 additions & 1 deletion MParT/Utilities/Miscellaneous.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@ namespace mpart{
std::string const& key,
std::string const& defaultValue);

/** Provides a mechanism for raising exceptions in CPU code where recovery is possible
and assertions in GPU code where exceptions aren't alllowed.
*/
template<typename MemorySpace, typename ErrorType>
struct ProcAgnosticError {
static void error(const char*) {
assert(0);
assert(false);
}
};

Expand Down
81 changes: 54 additions & 27 deletions MParT/Utilities/RootFinding.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,29 @@ KOKKOS_INLINE_FUNCTION void swapPair(T& x1, T& x2, T& y1, T& y2) {
simple_swap(y1, y2);
}

/** Finds a bracket [xlb, xub] such that f(xlb)<yd and f(xub)>yd. */
/** Finds a bracket [xlb, xub] such that f(xlb)<yd and f(xub)>yd.
* The info argument can be used to detect when a bracket cannot be found. Upon exit, a value of info=0
* indicates success while a negative value indicates failure. info=-1 indicates that the function
* seems to be perfectly flat and a root might not exist. info=-2 indicates that the maximum number of
* iterations (128) was exceeded.
*/
template<typename MemorySpace, typename FunctorType>
KOKKOS_INLINE_FUNCTION void FindBracket(FunctorType f,
double& xlb, double& ylb,
double& xub, double& yub,
const double yd)
const double yd,
int& info)
{
double xb, xf; // Bisection point and regula falsi point
double xc, yc;
const unsigned int maxIts = 1000;
const unsigned int maxIts = 128;
double stepSize = 1.0;
info = 0;

ylb = f(xb);
ylb = f(xlb);
yub = f(xub);

// We actually found an upper bound...
if(ylb>yd){

mpart::simple_swap(ylb,yub);
mpart::simple_swap(xlb,xub);

Expand All @@ -36,17 +42,23 @@ KOKKOS_INLINE_FUNCTION void FindBracket(FunctorType f,
for(i=0; i<maxIts; ++i){ // Could just be while(true), but want to avoid infinite loop
xlb = xub-stepSize;
ylb = f(xlb);

if((fabs((yub-ylb)/(xub-xlb))<1e-12)&&((xub-xlb)>10)){
info = -1;
return;
}

if(ylb>yd){
mpart::simple_swap(ylb,yub);
mpart::simple_swap(xlb,xub);
yub = ylb;
xub = xlb;
stepSize *= 2.0;
}else{
break;
}
}

if(i>=maxIts)
ProcAgnosticError<MemorySpace,std::runtime_error>::error("FindBracket: Could not find initial bracket such that f(xlb)<yd and f(xub)>yd.");
info = -2;

// We have a lower bound...
}else{
Expand All @@ -55,22 +67,30 @@ KOKKOS_INLINE_FUNCTION void FindBracket(FunctorType f,
for(i=0; i<maxIts; ++i){ // Could just be while(true), but want to avoid infinite loop
xub = xlb+stepSize;
yub = f(xub);

// Check to see if function is perfectly flat over a wide region
if((fabs((yub-ylb)/(xub-xlb))<1e-12)&&((xub-xlb)>10)){
info = -1;
return;
}

if(yub<yd){
mpart::simple_swap(ylb,yub);
mpart::simple_swap(xlb,xub);
ylb = yub;
xlb = xub;
stepSize *= 2.0;
}else{
break;
}
}

if(i>=maxIts)
ProcAgnosticError<MemorySpace,std::runtime_error>::error("FindBracket: Could not find initial bracket such that f(xlb)<yd and f(xub)>yd.");
if(i>=maxIts){
info = -2;
}
}
}
}

KOKKOS_INLINE_FUNCTION double Find_x_ITP(double xlb, double xub, double yd, double ylb, double yub,
double k1, double k2, double nhalf, double n0, int it, double xtol) {
double k1, double k2, double nhalf, double n0, int it, double xtol) {

double xb = 0.5*(xub+xlb); // bisection point
double xf = (xub*ylb - xlb*yub)/(ylb-yub); // regula-falsi point
Expand All @@ -85,27 +105,34 @@ KOKKOS_INLINE_FUNCTION double Find_x_ITP(double xlb, double xub, double yd, doub
return xc;
}

/** Computes the inverse of a function using the ITP method.
* The info argument will be 0 upon successful completion and negative for failed inversions.
* A value of info=-2 indicates a failure to find a bracket that contains the root. In this case, a nan will be returned.
* A value of info=-1 indicates that the maximum number of iterations was exceeded.
*/
template<typename MemorySpace, typename FunctorType>
KOKKOS_INLINE_FUNCTION double InverseSingleBracket(double yd, FunctorType f, double x0, const double xtol, const double ftol)
KOKKOS_INLINE_FUNCTION double InverseSingleBracket(double yd, FunctorType f, double x0, const double xtol, const double ftol, int& info)
{
std::cout << "Hereherehere" << std::endl;
double stepSize=1.0;
const unsigned int maxIts = 10000;

// First, we need to find two points that bound the solution.
double xlb, xub;
double ylb, yub;
double xc, yc;
info = 0;

xlb = xub = x0;
ylb = yub = f(xlb);

// Compute bounds
std::cout << "About to call find bracket..." << std::endl;
FindBracket<MemorySpace>(f, xlb, ylb, xub, yub, yd);
std::cout << "done " << xlb << ", " << xub << ", " << ylb << ", " << yub << std::endl;
assert(ylb<yub);
assert(xlb<xub);
// Compute initial bracket containing the root
int bracket_info = 0;
FindBracket<MemorySpace>(f, xlb, ylb, xub, yub, yd, bracket_info);

if((bracket_info<0)||(((ylb>yd)||(yub<yd)))){
info = -2;
return std::numeric_limits<double>::quiet_NaN();
}

// Bracketed search
const double k1 = 0.1;
Expand All @@ -115,9 +142,8 @@ KOKKOS_INLINE_FUNCTION double InverseSingleBracket(double yd, FunctorType f, dou

unsigned int it;
for(it=0; it<maxIts; ++it){
std::cout << " about to call Find_x_ITP" << std::endl;

xc = Find_x_ITP(xlb, xub, yd, ylb, yub, k1, k2, nhalf, n0, it, xtol);
std::cout << " done" << std::endl;

yc = f(xc);

Expand All @@ -133,8 +159,9 @@ KOKKOS_INLINE_FUNCTION double InverseSingleBracket(double yd, FunctorType f, dou
if(((xub-xlb)<xtol)||((yub-ylb)<ftol)) break;
};


if(it>maxIts)
ProcAgnosticError<MemorySpace,std::runtime_error>::error("InverseSingleBracket: Bracket search iterations exceeds maxIts");
info = -1;

return 0.5*(xub+xlb);
}
Expand Down
Loading

0 comments on commit 737592c

Please sign in to comment.