diff --git a/MParT/Sigmoid.h b/MParT/Sigmoid.h index 5c2d0d79..0043ec11 100644 --- a/MParT/Sigmoid.h +++ b/MParT/Sigmoid.h @@ -29,10 +29,10 @@ struct Logistic { }; template -class Sigmoid: public ParameterizedFunctionBase +class Sigmoid1d: public ParameterizedFunctionBase { public: - Sigmoid(StridedVector centers, StridedVector widths): ParameterizedFunctionBase(1, 1, widths.extent(0)) + Sigmoid1d(StridedVector centers, StridedVector widths): ParameterizedFunctionBase(1, 1, widths.extent(0)) { if(centers.extent(0) != widths.extent(0)) { std::stringstream ss; @@ -44,16 +44,35 @@ class Sigmoid: public ParameterizedFunctionBase } void EvaluateImpl(StridedMatrix const& pts, StridedMatrix out) override { - auto policy = Kokkos::RangePolicy::Space>(0,pts.extent(1)); - Kokkos::parallel_for(policy, KOKKOS_CLASS_LAMBDA (unsigned int pointInd) { - out(0,pointInd) = SigmoidType::Evaluate(pts(0,pointInd)); + Kokkos::parallel_for(pts.extent(1), KOKKOS_CLASS_LAMBDA (unsigned int pointInd) { + double eval_pt = 0.; + for(int coeff_index = 0; coeff_index < this->numCoeffs; coeff_index++){ + eval_pt += this->savedCoeffs(coeff_index)*SigmoidType::Evaluate(pts(0,pointInd)); + } + out(0,pointInd) = eval_pt; }); } - void GradientImpl(StridedMatrix const& sens, StridedMatrix const& pts, StridedMatrix out) override {} - void CoeffGradImpl(StridedMatrix const& sens,StridedMatrix const& pts, StridedMatrix out) override {} - private: + void GradientImpl(StridedMatrix const& sens, StridedMatrix const& pts, StridedMatrix out) override { + Kokkos::MDRangePolicy, ExecutionSpace> policy ({0,0},{this->numCoeffs, (int) pts.extent(1)}); + Kokkos::parallel_for(pts.extent(1), KOKKOS_CLASS_LAMBDA(unsigned int sample_index) { + double grad_pt = 0.; + for(int coeff_index = 0; coeff_index < this->numCoeffs; coeff_index++){ + grad_pt += this->savedCoeffs(coeff_index)*SigmoidType::Derivative(pts(0,sample_index)); + } + out(0,sample_index) = sens(0,sample_index)*grad_pt; + }); + } + void CoeffGradImpl(StridedMatrix const& sens,StridedMatrix const& pts, StridedMatrix out) override { + Kokkos::MDRangePolicy, ExecutionSpace> policy ({0,0}, {this->numCoeffs, (int)sens.extent(1)}); + Kokkos::parallel_for(policy, KOKKOS_CLASS_LAMBDA(unsigned int coeff_index, unsigned int sample_index) { + out(coeff_index,sample_index) = sens(0,sample_index)*SigmoidType::Evaluate(pts(0,sample_index)); + }); + } + + private: + using ExecutionSpace = typename MemoryToExecution::Space; }; } diff --git a/tests/Test_Sigmoid.cpp b/tests/Test_Sigmoid.cpp index 54024a2a..8f8ad4f2 100644 --- a/tests/Test_Sigmoid.cpp +++ b/tests/Test_Sigmoid.cpp @@ -6,33 +6,64 @@ using namespace Catch; using MemorySpace = Kokkos::HostSpace; -TEMPLATE_TEST_CASE("Sigmoid","[sigmoid]", Logistic) { +TEMPLATE_TEST_CASE("Sigmoid1d","[sigmoid1d]", Logistic) { SECTION("Initialization") { Kokkos::View centers("Sigmoid Centers", 2); Kokkos::View widths("Sigmoid Centers", 1); - CHECK_THROWS_AS((Sigmoid(centers, widths)), std::invalid_argument); + CHECK_THROWS_AS((Sigmoid1d(centers, widths)), std::invalid_argument); } SECTION("Single Sigmoid") { Kokkos::View center("Sigmoid Centers", 1); Kokkos::View width("Sigmoid Centers", 1); Kokkos::View coeff("Sigmoid coeff", 1); - center(0) = 0; width(0) = 1; coeff(0) = 1.0; - Sigmoid Sigmoid (center, width); - Sigmoid.SetCoeffs(coeff); - Kokkos::View evalPts("Input point", 1, 3); - evalPts(0,0) = -100; evalPts(0,1) = 0.0; evalPts(0,2) = 100; - StridedMatrix out = Sigmoid.Evaluate(evalPts); - double approxTol = 1e-5; - REQUIRE_THAT(out(0,0), Matchers::WithinAbs(0.0, approxTol)); - REQUIRE_THAT(out(0,1), Matchers::WithinAbs(0.5, approxTol)); - REQUIRE_THAT(out(0,2), Matchers::WithinAbs(1.0, approxTol)); + center(0) = 0; width(0) = 1; + Sigmoid1d Sigmoid (center, width); + for(int coeff_int = 1; coeff_int <= 2; coeff_int++) { + coeff(0) = (double) coeff_int; + Sigmoid.WrapCoeffs(coeff); + Kokkos::View evalPts("Input point", 1, 3); + evalPts(0,0) = -100; evalPts(0,1) = 0.0; evalPts(0,2) = 100; + StridedMatrix out = Sigmoid.Evaluate(evalPts); + double approxTol = 1e-5; + REQUIRE_THAT(out(0,0), Matchers::WithinAbs(coeff_int*0.0, approxTol)); + REQUIRE_THAT(out(0,1), Matchers::WithinAbs(coeff_int*0.5, approxTol)); + REQUIRE_THAT(out(0,2), Matchers::WithinAbs(coeff_int*1.0, approxTol)); + } + + unsigned int N_grad_points = 100; + double fd_delta = 1e-5; + Kokkos::View gradPts("Gradient points", 1, N_grad_points); + Kokkos::View gradPts_plus_delta("Gradient points plus delta", 1, N_grad_points); + Kokkos::View sens("Sensitivities", 1, N_grad_points); + Kokkos::parallel_for(N_grad_points, KOKKOS_LAMBDA(unsigned int point_index) { + double gradPt = 3.0*(-1.0 + 2*((double) point_index)/((double) N_grad_points-1)); + double sensPt = 2.0 + ((double) point_index)/((double) N_grad_points-1); + gradPts(0,point_index) = gradPt; + gradPts_plus_delta(0,point_index) = gradPt + fd_delta; + sens(0,point_index) = sensPt; + }); + Kokkos::View input_grad = Sigmoid.Gradient(sens, gradPts); + Kokkos::View coeff_grad = Sigmoid.CoeffGrad(sens, gradPts); + Kokkos::View gradPts_eval = Sigmoid.Evaluate(gradPts); + Kokkos::View gradPts_plus_delta_eval = Sigmoid.Evaluate(gradPts_plus_delta); + coeff(0) += fd_delta; + Kokkos::View gradPts_plus_coeff_delta_eval = Sigmoid.Evaluate(gradPts); + double input_grad_error_accumulator = 0., coeff_grad_error_accumulator = 0.; + for(int i = 0; i < N_grad_points; i++) { + double input_grad_i_fd = (gradPts_plus_delta_eval(0,i) - gradPts_eval(0,i))/fd_delta; + double coeff_grad_i_fd = (gradPts_plus_coeff_delta_eval(0,i) - gradPts_eval(0,i))/fd_delta; + input_grad_error_accumulator += fabs(input_grad_i_fd*sens(0,i) - input_grad(0,i)); + coeff_grad_error_accumulator += fabs(coeff_grad_i_fd*sens(0,i) - coeff_grad(0,i)); + } + CHECK(input_grad_error_accumulator < N_grad_points*2*fd_delta); + CHECK(coeff_grad_error_accumulator < N_grad_points*2*fd_delta); } SECTION("Multiple Sigmoids") { int N_Sigmoid = 3; Kokkos::View centers("Sigmoid Centers", N_Sigmoid); Kokkos::View widths("Sigmoid Centers", N_Sigmoid); - Sigmoid Sigmoid (centers, widths); + Sigmoid1d Sigmoid (centers, widths); Kokkos::View coeffs("Sigmoid Coefficients", N_Sigmoid); for(int j = 0; j < N_Sigmoid; j++) coeffs(j) = 0.5*(j+1); Sigmoid.WrapCoeffs(coeffs);