Skip to content

Commit

Permalink
Improved robustness of error handling in bracketed inverse.
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Parno committed Oct 17, 2023
1 parent e4eeb5d commit 23a7604
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 56 deletions.
23 changes: 4 additions & 19 deletions MParT/MonotoneComponent.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ class MonotoneComponent : public ConditionalMapBase<MemorySpace>
auto functor = KOKKOS_CLASS_LAMBDA (typename Kokkos::TeamPolicy<ExecutionSpace>::member_type team_member) {

unsigned int ptInd = team_member.league_rank () * team_member.team_size () + team_member.team_rank ();
int info;

if(ptInd<numPts){
unsigned int xInd = ptInd;
Expand All @@ -404,38 +405,22 @@ class MonotoneComponent : public ConditionalMapBase<MemorySpace>
// Create a subview containing x_{1:d-1}
auto pt = Kokkos::subview(xs, Kokkos::ALL(), xInd);

//std::cout << "\n\nPt0 = \n" << std::endl;
// Check for NaNs. If found, set output to nan and return
for(unsigned int ii=0; ii<pt.size(); ++ii){
//std::cout << " " << pt(ii);
if(std::isnan(pt(ii))){
std::cout << "Warning: Found nan in point, not computing inverse!" << std::endl;
//throw std::runtime_error("Found nan in point!");
output(ptInd) = std::numeric_limits<double>::quiet_NaN();
return;
}
}
//std::cout << std::endl << std::endl;

// Fill in the cache with everything that doesn't depend on x_d
Kokkos::View<double*,MemorySpace> cache(team_member.thread_scratch(1), cacheSize);
expansion_.FillCache1(cache.data(), pt, DerivativeFlags::None);

// std::cout << "\n\nPt1 = \n" << std::endl;
// double max_val = 0.0;
// for(unsigned int ii=0; ii<pt.size(); ++ii){
// std::cout << " " << pt(ii);
// max_val = (abs(pt(ii))>max_val) ? abs(pt(ii)) : max_val;
// }
// std::cout << std::endl << std::endl;
// if (max_val > 1e-4){
// std::cout << "Max val = " << max_val;
// assert(max_val < 1e4);
// }
// Compute the inverse
Kokkos::View<double*,MemorySpace> workspace(team_member.thread_scratch(1), workspaceSize);
auto eval = SingleEvaluator<decltype(pt),decltype(coeffs)>(workspace.data(), cache.data(), pt, coeffs, quad_, expansion_);
//std::cout << "Here 0" << std::endl;
output(ptInd) = RootFinding::InverseSingleBracket<MemorySpace>(ys(ptInd), eval, pt(pt.extent(0)-1), xtol, ytol);
//std::cout << "Here 1" << std::endl;
output(ptInd) = RootFinding::InverseSingleBracket<MemorySpace>(ys(ptInd), eval, pt(pt.extent(0)-1), xtol, ytol, info);
}
};

Expand Down
3 changes: 3 additions & 0 deletions MParT/Utilities/Miscellaneous.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ namespace mpart{
std::string const& key,
std::string const& defaultValue);

/** Provides a mechanism for raising exceptions in CPU code where recovery is possible
and assertions in GPU code where exceptions aren't alllowed.
*/
template<typename MemorySpace, typename ErrorType>
struct ProcAgnosticError {
static void error(const char*) {
Expand Down
67 changes: 38 additions & 29 deletions MParT/Utilities/RootFinding.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,25 @@ KOKKOS_INLINE_FUNCTION void swapPair(T& x1, T& x2, T& y1, T& y2) {
simple_swap(y1, y2);
}

/** Finds a bracket [xlb, xub] such that f(xlb)<yd and f(xub)>yd. */
/** Finds a bracket [xlb, xub] such that f(xlb)<yd and f(xub)>yd.
* The info argument can be used to detect when a bracket cannot be found. Upon exit, a value of info=0
* indicates success while a negative value indicates failure. info=-1 indicates that the function
* seems to be perfectly flat and a root might not exist. info=-2 indicates that the maximum number of
* iterations (128) was exceeded.
*/
template<typename MemorySpace, typename FunctorType>
KOKKOS_INLINE_FUNCTION void FindBracket(FunctorType f,
double& xlb, double& ylb,
double& xub, double& yub,
const double yd)
const double yd,
int& info)
{
const unsigned int maxIts = 40;
const unsigned int maxIts = 128;
double stepSize = 1.0;

info = 0;

ylb = f(xlb);
yub = f(xub);
//std::cout << "-1" << ": " << xlb << ", " << ylb << ", " << xub << ", " << yub << " vs " << yd << std::endl;


// We actually found an upper bound...
if(ylb>yd){
Expand All @@ -37,12 +42,10 @@ KOKKOS_INLINE_FUNCTION void FindBracket(FunctorType f,
for(i=0; i<maxIts; ++i){ // Could just be while(true), but want to avoid infinite loop
xlb = xub-stepSize;
ylb = f(xlb);
//std::cout << "\n\n1 " << i << ": " << xlb << ", " << ylb << ", " << xub << ", " << yub << " vs " << yd << std::endl;

if(abs((yub-ylb)/(xub-xlb))<1e-12){
info = -1;
break;
//std::cout << "slope = " << (yub-ylb)/(xub-xlb) << std::endl;
//assert(abs(yub-ylb)>1e-10);
}

if(ylb>yd){
Expand All @@ -54,9 +57,8 @@ KOKKOS_INLINE_FUNCTION void FindBracket(FunctorType f,
}
}


// if(i>=maxIts)
// ProcAgnosticError<MemorySpace,std::runtime_error>::error("FindBracket: Could not find initial bracket such that f(xlb)<yd and f(xub)>yd.");
if(i>=maxIts)
info = -2;

// We have a lower bound...
}else{
Expand All @@ -65,13 +67,13 @@ KOKKOS_INLINE_FUNCTION void FindBracket(FunctorType f,
for(i=0; i<maxIts; ++i){ // Could just be while(true), but want to avoid infinite loop
xub = xlb+stepSize;
yub = f(xub);
//std::cout << "\n\n2 " << i << ": " << xlb << ", " << ylb << ", " << xub << ", " << yub << " vs " << yd << std::endl;

// Check to see if function is perfectly flat
if(abs((yub-ylb)/(xub-xlb))<1e-12){
info = -1;
break;
}

// if(abs((yub-ylb)/(xub-xlb))<1e-10){
// std::cout << "Function seems to be flat!" << std::endl;
// std::cout << "slope = " << (yub-ylb)/(xub-xlb) << std::endl;
// assert(abs(yub-ylb)>1e-10);
// }
if(yub<yd){
ylb = yub;
xlb = xub;
Expand All @@ -81,13 +83,14 @@ KOKKOS_INLINE_FUNCTION void FindBracket(FunctorType f,
}
}

// if(i>=maxIts)
// ProcAgnosticError<MemorySpace,std::runtime_error>::error("FindBracket: Could not find initial bracket such that f(xlb)<yd and f(xub)>yd.");
if(i>=maxIts){
info = -2;
}
}
}

KOKKOS_INLINE_FUNCTION double Find_x_ITP(double xlb, double xub, double yd, double ylb, double yub,
double k1, double k2, double nhalf, double n0, int it, double xtol) {
double k1, double k2, double nhalf, double n0, int it, double xtol) {

double xb = 0.5*(xub+xlb); // bisection point
double xf = (xub*ylb - xlb*yub)/(ylb-yub); // regula-falsi point
Expand All @@ -102,8 +105,13 @@ KOKKOS_INLINE_FUNCTION double Find_x_ITP(double xlb, double xub, double yd, doub
return xc;
}

/** Computes the inverse of a function using the ITP method.
* The info argument will be 0 upon successful completion and negative for failed inversions.
* A value of info=-2 indicates a failure to find a bracket that contains the root. In this case, a nan will be returned.
* A value of info=-1 indicates that the maximum number of iterations was exceeded.
*/
template<typename MemorySpace, typename FunctorType>
KOKKOS_INLINE_FUNCTION double InverseSingleBracket(double yd, FunctorType f, double x0, const double xtol, const double ftol)
KOKKOS_INLINE_FUNCTION double InverseSingleBracket(double yd, FunctorType f, double x0, const double xtol, const double ftol, int& info)
{
double stepSize=1.0;
const unsigned int maxIts = 10000;
Expand All @@ -112,16 +120,18 @@ KOKKOS_INLINE_FUNCTION double InverseSingleBracket(double yd, FunctorType f, dou
double xlb, xub;
double ylb, yub;
double xc, yc;
info = 0;

//std::cout << "xlb = " << xlb << std::endl;
xlb = xub = x0;
ylb = yub = f(xlb);

// Compute bounds
FindBracket<MemorySpace>(f, xlb, ylb, xub, yub, yd);

if ((ylb>yd)||(yub<yd))
// Compute initial bracket containing the root
int bracket_info = 0;
FindBracket<MemorySpace>(f, xlb, ylb, xub, yub, yd, bracket_info);
if((bracket_info<0)||(((ylb>yd)||(yub<yd)))){
info = -2;
return std::numeric_limits<double>::quiet_NaN();
}

// Bracketed search
const double k1 = 0.1;
Expand Down Expand Up @@ -151,8 +161,7 @@ KOKKOS_INLINE_FUNCTION double InverseSingleBracket(double yd, FunctorType f, dou


if(it>maxIts)
return std::numeric_limits<double>::quiet_NaN();
//ProcAgnosticError<MemorySpace,std::runtime_error>::error("InverseSingleBracket: Bracket search iterations exceeds maxIts");
info = -1;

return 0.5*(xub+xlb);
}
Expand Down
51 changes: 43 additions & 8 deletions tests/Test_RootFinding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,52 +30,87 @@ TEST_CASE( "RootFindingUtils", "[RootFindingUtils]") {
double yd = identity(xd);
double xub = 2., yub = 2.;
double xlb = 0., ylb = 0.;
FindBracket<HostSpace>(identity, xlb, ylb, xub, yub, yd);
int info = 0;
FindBracket<HostSpace>(identity, xlb, ylb, xub, yub, yd, info);
CheckFoundBounds(identity, xlb, xd, xub, ylb, yd, yub);
CHECK(info==0);
}
SECTION("FindBracker sigmoid") {
double xd = -0.5;
double yd = sigmoid(xd);
double xub = 2., yub = sigmoid(xub);
double xlb = 0., ylb = sigmoid(xlb);
FindBracket<HostSpace>(sigmoid, xlb, ylb, xub, yub, yd);
int info = 0;
FindBracket<HostSpace>(sigmoid, xlb, ylb, xub, yub, yd, info);
CheckFoundBounds(sigmoid, xlb, xd, xub, ylb, yd, yub);
CHECK(info==0);
}
SECTION("FindBracket flat") {
double xd = -1.1;
double yd = 0.0;
double xub = 2., yub = -1.0;
double xlb = 0., ylb = -1.0;
int info = 0;
auto f = [](double x){return -1.0;};
FindBracket<HostSpace>(f, xlb, ylb, xub, yub, yd, info);
CHECK(info==-1);
}
SECTION("Test Inverse Linear, low x0") {
double xd = 0.5, yd = identity(xd);
double x0 = 0.0, xtol = 1e-5, ftol = 1e-5;
double xd_found = InverseSingleBracket<HostSpace>(yd, identity, x0, xtol, ftol);
int info = 0;
double xd_found = InverseSingleBracket<HostSpace>(yd, identity, x0, xtol, ftol, info);
CHECK( xd_found == Approx(xd).epsilon(2*xtol));
CHECK(info==0);
}
SECTION("Test Inverse Linear, high x0") {
double xd = 0.5, yd = identity(xd);
double x0 = 1.0, xtol = 1e-5, ftol = 1e-5;
double xd_found = InverseSingleBracket<HostSpace>(yd, identity, x0, xtol, ftol);
int info = 0;
double xd_found = InverseSingleBracket<HostSpace>(yd, identity, x0, xtol, ftol, info);
CHECK( xd_found == Approx(xd).epsilon(2*xtol));
CHECK(info==0);
}
SECTION("Test Inverse Flat") {
double xd = 0.5, yd = -1.0;
double x0 = 0.0, xtol = 1e-5, ftol = 1e-5;
auto f = [](double x){return -1.0;};
int info = 0;
double xd_found = InverseSingleBracket<HostSpace>(yd, f, x0, xtol, ftol, info);
CHECK( std::isnan(xd_found));
CHECK(info==-2);
}
SECTION("Test Inverse Sigmoid, low x0") {
double xd = 0.5, yd = sigmoid(xd);
double x0 = 0.0, xtol = 1e-5, ftol = 1e-5;
double xd_found = InverseSingleBracket<HostSpace>(yd, sigmoid, x0, xtol, ftol);
int info = 0;
double xd_found = InverseSingleBracket<HostSpace>(yd, sigmoid, x0, xtol, ftol, info);
CHECK( xd_found == Approx(xd).epsilon(2*xtol));
CHECK(info==0);
}
SECTION("Test Inverse Sigmoid, high x0") {
double xd = 0.5, yd = sigmoid(xd);
double x0 = 1.0, xtol = 1e-5, ftol = 1e-5;
double xd_found = InverseSingleBracket<HostSpace>(yd, sigmoid, x0, xtol, ftol);
int info;
double xd_found = InverseSingleBracket<HostSpace>(yd, sigmoid, x0, xtol, ftol, info);
CHECK( xd_found == Approx(xd).epsilon(2*xtol));
CHECK(info==0);
}
SECTION("Test Inverse Sigmoid Combo, low x0") {
double xd = 0.5, yd = sigmoid_combo(xd);
double x0 = -5.0, xtol = 1e-5, ftol = 1e-5;
double xd_found = InverseSingleBracket<HostSpace>(yd, sigmoid_combo, x0, xtol, ftol);
int info;
double xd_found = InverseSingleBracket<HostSpace>(yd, sigmoid_combo, x0, xtol, ftol, info);
CHECK( xd_found == Approx(xd).epsilon(2*xtol));
CHECK(info==0);
}
SECTION("Test Inverse Sigmoid Combo, high x0") {
double xd = 0.5, yd = sigmoid_combo(xd);
double x0 = 5.0, xtol = 1e-5, ftol = 1e-5;
double xd_found = InverseSingleBracket<HostSpace>(yd, sigmoid_combo, x0, xtol, ftol);
int info;
double xd_found = InverseSingleBracket<HostSpace>(yd, sigmoid_combo, x0, xtol, ftol, info);
CHECK( xd_found == Approx(xd).epsilon(2*xtol));
CHECK(info==0);
}

}

0 comments on commit 23a7604

Please sign in to comment.