Skip to content

Commit

Permalink
Now throwing nan if inverse blows up.
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Parno committed Oct 14, 2023
1 parent 6b6adc1 commit e4eeb5d
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 124 deletions.
270 changes: 160 additions & 110 deletions MParT/MonotoneComponent.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,16 @@ class MonotoneComponent : public ConditionalMapBase<MemorySpace>
StridedVector<double, MemorySpace> output,
std::map<std::string, std::string> options=std::map<std::string,std::string>())
{
// std::cout << "xs.shape = " << xs.extent(0) << "," << xs.extent(1) << std::endl;
// std::cout << "xs = \n" << std::endl;
// for(unsigned int row=0; row<xs.extent(0); ++row){
// for(unsigned int col=0; col<xs.extent(1); ++col){
// std::cout << " " << xs(row,col);
// }
// std::cout << std::endl;
// }
// std::cout << std::endl;

// Extract the method from the options map
std::string method;
if(options.count("Method")){
Expand Down Expand Up @@ -393,15 +403,39 @@ class MonotoneComponent : public ConditionalMapBase<MemorySpace>

// Create a subview containing x_{1:d-1}
auto pt = Kokkos::subview(xs, Kokkos::ALL(), xInd);

//std::cout << "\n\nPt0 = \n" << std::endl;
for(unsigned int ii=0; ii<pt.size(); ++ii){
//std::cout << " " << pt(ii);
if(std::isnan(pt(ii))){
std::cout << "Warning: Found nan in point, not computing inverse!" << std::endl;
//throw std::runtime_error("Found nan in point!");
return;
}
}
//std::cout << std::endl << std::endl;

// Fill in the cache with everything that doesn't depend on x_d
Kokkos::View<double*,MemorySpace> cache(team_member.thread_scratch(1), cacheSize);
expansion_.FillCache1(cache.data(), pt, DerivativeFlags::None);

// std::cout << "\n\nPt1 = \n" << std::endl;
// double max_val = 0.0;
// for(unsigned int ii=0; ii<pt.size(); ++ii){
// std::cout << " " << pt(ii);
// max_val = (abs(pt(ii))>max_val) ? abs(pt(ii)) : max_val;
// }
// std::cout << std::endl << std::endl;
// if (max_val > 1e-4){
// std::cout << "Max val = " << max_val;
// assert(max_val < 1e4);
// }
// Compute the inverse
Kokkos::View<double*,MemorySpace> workspace(team_member.thread_scratch(1), workspaceSize);
auto eval = SingleEvaluator<decltype(pt),decltype(coeffs)>(workspace.data(), cache.data(), pt, coeffs, quad_, expansion_);
//std::cout << "Here 0" << std::endl;
output(ptInd) = RootFinding::InverseSingleBracket<MemorySpace>(ys(ptInd), eval, pt(pt.extent(0)-1), xtol, ytol);
//std::cout << "Here 1" << std::endl;
}
};

Expand Down Expand Up @@ -967,6 +1001,16 @@ class MonotoneComponent : public ConditionalMapBase<MemorySpace>
QuadratureType const& quad,
ExpansionType const& expansion)
{
// std::cout << "pt = " << std::endl;
// for(unsigned int i=0; i<pt.size(); ++i)
// std::cout << " " << pt(i);
// std::cout << std::endl;
// std::cout << "xd = " << xd << std::endl;
// std::cout << "Coefficients = " << std::endl;
// for(unsigned int i=0; i<coeffs.size(); ++i)
// std::cout << " " << coeffs(i);
// std::cout << std::endl;

double output = 0.0;
// Compute the integral \int_0^1 g( \partial_D f(x_1,...,x_{D-1},t*x_d)) dt
MonotoneIntegrand<ExpansionType, PosFuncType, PointType, CoeffsType, MemorySpace> integrand(cache,
Expand All @@ -977,11 +1021,17 @@ class MonotoneComponent : public ConditionalMapBase<MemorySpace>
DerivativeFlags::None);

quad.Integrate(workspace, integrand, 0, 1, &output);
// std::cout << "output 1 = " << output << std::endl;

// Finish filling in the cache for an evaluation of the expansion with x_d=0
// std::cout << "pt = " << std::endl;
// for(unsigned int i=0; i<pt.size(); ++i)
// std::cout << " " << pt(i);
// std::cout << std::endl;

expansion.FillCache2(cache, pt, 0.0, DerivativeFlags::None);
output += expansion.Evaluate(cache, coeffs);

// std::cout << "output 2 = " << output << std::endl;
return output;
}

Expand All @@ -1003,115 +1053,115 @@ class MonotoneComponent : public ConditionalMapBase<MemorySpace>
@tparam PointType The type of the point. Typically either Kokkos::View<double*> or some subview with a similar 1d signature.
@tparam CoeffsType The type of the coefficients. Typically Kokkos::View<double*> or a similarly structured subview.
*/
template<typename PointType, typename CoeffsType>
KOKKOS_FUNCTION static double InverseSingleBracket(double* cache,
double* workspace,
PointType const& pt,
double yd,
CoeffsType const& coeffs,
const double xtol,
const double ftol,
QuadratureType const& quad,
ExpansionType const& expansion)
{
double stepSize=1.0;
const unsigned int maxIts = 10000;

// First, we need to find two points that bound the solution.
double xlb, xub;
double ylb, yub;
double xb, xf; // Bisection point and regula falsi point
double xc, yc;

xlb = pt(pt.extent(0)-1);
ylb = EvaluateSingle(workspace, cache, pt, xlb, coeffs, quad, expansion);

// We actually found an upper bound...
if(ylb>yd){

mpart::simple_swap(ylb,yub);
mpart::simple_swap(xlb,xub);

// Now find a lower bound...
unsigned int i;
for(i=0; i<maxIts; ++i){ // Could just be while(true), but want to avoid infinite loop
xlb = xub-stepSize;
ylb = EvaluateSingle(workspace, cache, pt, xlb, coeffs, quad, expansion);
if(ylb>yd){
mpart::simple_swap(ylb,yub);
mpart::simple_swap(xlb,xub);
stepSize *= 2.0;
}else{
break;
}
}
if(i>maxIts)
ProcAgnosticError<MemorySpace, std::runtime_error>::error("InverseSingleBracket: lower bound iterations exceed maxIts");

// We have a lower bound...
}else{
// Now find an upper bound...
unsigned int i;
for(i=0; i<maxIts; ++i){ // Could just be while(true), but want to avoid infinite loop
xub = xlb+stepSize;
yub = EvaluateSingle(workspace, cache, pt, xub, coeffs, quad, expansion);
if(yub<yd){
mpart::simple_swap(ylb,yub);
mpart::simple_swap(xlb,xub);
stepSize *= 2.0;
}else{
break;
}
}
if(i>maxIts)
ProcAgnosticError<MemorySpace,std::runtime_error>::error("InverseSingleBracket: upper bound calculation exceeds maxIts");
}

assert(ylb<yub);
assert(xlb<xub);

// Bracketed search
const double k1 = 0.1;
const double k2 = 2.0;
const double nhalf = ceil(log2(0.5*(xub-xlb)/xtol));
const double n0 = 1.0;

double sigma, delta, rho;
unsigned int it;
for(it=0; it<maxIts; ++it){

xb = 0.5*(xub+xlb); // bisection point
xf = xlb - (yd-ylb)*(xub-xlb) / (yub-ylb); // regula-falsi point

sigma = ((xb-xf)>0)?1.0:-1.0; // sign(xb-xf)
delta = fmin(k1*pow((xub-xlb), k2), fabs(xb-xf));

xf += delta*sigma;

rho = fmin(xtol*pow(2.0, nhalf + n0 - it) - 0.5*(xub-xlb), fabs(xf - xb));
xc = xb - sigma*rho;

yc = EvaluateSingle(workspace, cache, pt, xc, coeffs, quad, expansion);

if(abs(yc-yd)<ftol){
return xc;
}else if(yc>yd){
mpart::simple_swap(yc,yub);
mpart::simple_swap(xc,xub);
}else{
mpart::simple_swap(yc,ylb);
mpart::simple_swap(xc,xlb);
}

// Check for convergence
if(((xub-xlb)<xtol)||((yub-ylb)<ftol))
break;
};

if(it>maxIts)
ProcAgnosticError<MemorySpace,std::runtime_error>::error("InverseSingleBracket: Bracket search iterations exceeds maxIts");
return 0.5*(xub+xlb);
}
// template<typename PointType, typename CoeffsType>
// KOKKOS_FUNCTION static double InverseSingleBracket(double* cache,
// double* workspace,
// PointType const& pt,
// double yd,
// CoeffsType const& coeffs,
// const double xtol,
// const double ftol,
// QuadratureType const& quad,
// ExpansionType const& expansion)
// {
// double stepSize=1.0;
// const unsigned int maxIts = 10000;

// // First, we need to find two points that bound the solution.
// double xlb, xub;
// double ylb, yub;
// double xb, xf; // Bisection point and regula falsi point
// double xc, yc;

// xlb = pt(pt.extent(0)-1);
// ylb = EvaluateSingle(workspace, cache, pt, xlb, coeffs, quad, expansion);

// // We actually found an upper bound...
// if(ylb>yd){

// mpart::simple_swap(ylb,yub);
// mpart::simple_swap(xlb,xub);

// // Now find a lower bound...
// unsigned int i;
// for(i=0; i<maxIts; ++i){ // Could just be while(true), but want to avoid infinite loop
// xlb = xub-stepSize;
// ylb = EvaluateSingle(workspace, cache, pt, xlb, coeffs, quad, expansion);
// if(ylb>yd){
// mpart::simple_swap(ylb,yub);
// mpart::simple_swap(xlb,xub);
// stepSize *= 2.0;
// }else{
// break;
// }
// }
// if(i>maxIts)
// ProcAgnosticError<MemorySpace, std::runtime_error>::error("InverseSingleBracket: lower bound iterations exceed maxIts");

// // We have a lower bound...
// }else{
// // Now find an upper bound...
// unsigned int i;
// for(i=0; i<maxIts; ++i){ // Could just be while(true), but want to avoid infinite loop
// xub = xlb+stepSize;
// yub = EvaluateSingle(workspace, cache, pt, xub, coeffs, quad, expansion);
// if(yub<yd){
// mpart::simple_swap(ylb,yub);
// mpart::simple_swap(xlb,xub);
// stepSize *= 2.0;
// }else{
// break;
// }
// }
// if(i>maxIts)
// ProcAgnosticError<MemorySpace,std::runtime_error>::error("InverseSingleBracket: upper bound calculation exceeds maxIts");
// }

// assert(ylb<yub);
// assert(xlb<xub);

// // Bracketed search
// const double k1 = 0.1;
// const double k2 = 2.0;
// const double nhalf = ceil(log2(0.5*(xub-xlb)/xtol));
// const double n0 = 1.0;

// double sigma, delta, rho;
// unsigned int it;
// for(it=0; it<maxIts; ++it){

// xb = 0.5*(xub+xlb); // bisection point
// xf = xlb - (yd-ylb)*(xub-xlb) / (yub-ylb); // regula-falsi point

// sigma = ((xb-xf)>0)?1.0:-1.0; // sign(xb-xf)
// delta = fmin(k1*pow((xub-xlb), k2), fabs(xb-xf));

// xf += delta*sigma;

// rho = fmin(xtol*pow(2.0, nhalf + n0 - it) - 0.5*(xub-xlb), fabs(xf - xb));
// xc = xb - sigma*rho;

// yc = EvaluateSingle(workspace, cache, pt, xc, coeffs, quad, expansion);

// if(abs(yc-yd)<ftol){
// return xc;
// }else if(yc>yd){
// mpart::simple_swap(yc,yub);
// mpart::simple_swap(xc,xub);
// }else{
// mpart::simple_swap(yc,ylb);
// mpart::simple_swap(xc,xlb);
// }

// // Check for convergence
// if(((xub-xlb)<xtol)||((yub-ylb)<ftol))
// break;
// };

// if(it>maxIts)
// ProcAgnosticError<MemorySpace,std::runtime_error>::error("InverseSingleBracket: Bracket search iterations exceeds maxIts");
// return 0.5*(xub+xlb);
// }

/** Give access to the underlying FixedMultiIndexSet
* @return The FixedMultiIndexSet
Expand Down
2 changes: 1 addition & 1 deletion MParT/Utilities/Miscellaneous.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace mpart{
template<typename MemorySpace, typename ErrorType>
struct ProcAgnosticError {
static void error(const char*) {
assert(0);
assert(false);
}
};

Expand Down
Loading

0 comments on commit e4eeb5d

Please sign in to comment.