Skip to content

Commit 276be49

Browse files
skyelvesfacebook-github-bot
authored andcommitted
feat: Add inverse_binomial_cdf (facebookincubator#12983)
Summary: Pull Request resolved: facebookincubator#12983 feat: Add inverse_binomial_cdf Reviewed By: amitkdutta Differential Revision: D72749280 fbshipit-source-id: ef11432ee80f2034a45411512c4d025ef55a90a0
1 parent f7cacf3 commit 276be49

File tree

4 files changed

+86
-0
lines changed

4 files changed

+86
-0
lines changed

velox/expression/fuzzer/ExpressionFuzzerTest.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ int main(int argc, char** argv) {
221221
"map_keys_by_top_n_values", // requires
222222
// https://github.com/prestodb/presto/pull/24570
223223
"inverse_gamma_cdf", // https://github.com/facebookincubator/velox/issues/12918
224+
"inverse_binomial_cdf", // https://github.com/facebookincubator/velox/issues/12981
224225
});
225226

226227
referenceQueryRunner = std::make_shared<PrestoQueryRunner>(

velox/functions/prestosql/Probability.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,5 +355,33 @@ struct InverseGammaCDFFunction {
355355
}
356356
};
357357

358+
template <typename T>
359+
struct InverseBinomialCDFFunction {
360+
VELOX_DEFINE_FUNCTION_TYPES(T);
361+
362+
FOLLY_ALWAYS_INLINE void call(
363+
int32_t& result,
364+
int32_t numberOfTrials,
365+
double successProbability,
366+
double p) {
367+
static constexpr double kInf = std::numeric_limits<double>::infinity();
368+
369+
VELOX_USER_CHECK(
370+
(p >= 0) && (p <= 1) && (p != kInf),
371+
"inverseBinomialCdf Function: p must be in the interval [0, 1]");
372+
VELOX_USER_CHECK(
373+
(successProbability >= 0) && (successProbability <= 1) &&
374+
(successProbability != kInf),
375+
"inverseBinomialCdf Function: successProbability must be in the interval [0, 1]");
376+
VELOX_USER_CHECK(
377+
numberOfTrials > 0,
378+
"inverseBinomialCdf Function: numberOfTrials must be greater than 0");
379+
380+
boost::math::binomial_distribution<> dist(
381+
numberOfTrials, successProbability);
382+
result = static_cast<int32_t>(boost::math::quantile(dist, p));
383+
}
384+
};
385+
358386
} // namespace
359387
} // namespace facebook::velox::functions

velox/functions/prestosql/registration/ProbabilityTrigonometricFunctionsRegistration.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ void registerProbTrigFunctions(const std::string& prefix) {
8181
{prefix + "inverse_laplace_cdf"});
8282
registerFunction<InverseGammaCDFFunction, double, double, double, double>(
8383
{prefix + "inverse_gamma_cdf"});
84+
registerFunction<
85+
InverseBinomialCDFFunction,
86+
int32_t,
87+
int32_t,
88+
double,
89+
double>({prefix + "inverse_binomial_cdf"});
8490
}
8591

8692
} // namespace

velox/functions/prestosql/tests/ProbabilityTest.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,5 +674,56 @@ TEST_F(ProbabilityTest, invGammaCDF) {
674674
"inverseGammaCdf Function: p must be in the interval [0, 1]");
675675
}
676676

677+
TEST_F(ProbabilityTest, invBinomialCDF) {
678+
const auto invBinomialCDF = [&](std::optional<int32_t> numberOfTrials,
679+
std::optional<double> successProbability,
680+
std::optional<double> p) {
681+
return evaluateOnce<int32_t>(
682+
"inverse_binomial_cdf(c0, c1, c2)",
683+
numberOfTrials,
684+
successProbability,
685+
p);
686+
};
687+
688+
EXPECT_EQ(0, invBinomialCDF(20, 0.5, 0.0));
689+
EXPECT_EQ(10, invBinomialCDF(20, 0.5, 0.5));
690+
EXPECT_EQ(20, invBinomialCDF(20, 0.5, 1.0));
691+
EXPECT_EQ(INT32_MAX, invBinomialCDF(INT32_MAX, 0.5, 1));
692+
EXPECT_EQ(611204, invBinomialCDF(1223340, 0.5, 0.2));
693+
694+
EXPECT_EQ(std::nullopt, invBinomialCDF(std::nullopt, 1, 1));
695+
EXPECT_EQ(std::nullopt, invBinomialCDF(1, std::nullopt, 1));
696+
EXPECT_EQ(std::nullopt, invBinomialCDF(1, 1, std::nullopt));
697+
698+
VELOX_ASSERT_THROW(
699+
invBinomialCDF(5, -0.5, 0.3),
700+
"inverseBinomialCdf Function: successProbability must be in the interval [0, 1]");
701+
VELOX_ASSERT_THROW(
702+
invBinomialCDF(5, 1.5, 0.3),
703+
"inverseBinomialCdf Function: successProbability must be in the interval [0, 1]");
704+
VELOX_ASSERT_THROW(
705+
invBinomialCDF(5, 0.5, -3.0),
706+
"inverseBinomialCdf Function: p must be in the interval [0, 1]");
707+
VELOX_ASSERT_THROW(
708+
invBinomialCDF(5, 0.5, 3.0),
709+
"inverseBinomialCdf Function: p must be in the interval [0, 1]");
710+
VELOX_ASSERT_THROW(
711+
invBinomialCDF(-5, 0.5, 0.3),
712+
"inverseBinomialCdf Function: numberOfTrials must be greater than 0");
713+
714+
VELOX_ASSERT_THROW(
715+
invBinomialCDF(1, kInf, 0.5),
716+
"inverseBinomialCdf Function: successProbability must be in the interval [0, 1]");
717+
VELOX_ASSERT_THROW(
718+
invBinomialCDF(1, kNan, 0.5),
719+
"inverseBinomialCdf Function: successProbability must be in the interval [0, 1]");
720+
VELOX_ASSERT_THROW(
721+
invBinomialCDF(1, 0.5, kInf),
722+
"inverseBinomialCdf Function: p must be in the interval [0, 1]");
723+
VELOX_ASSERT_THROW(
724+
invBinomialCDF(1, 0.5, kNan),
725+
"inverseBinomialCdf Function: p must be in the interval [0, 1]");
726+
}
727+
677728
} // namespace
678729
} // namespace facebook::velox

0 commit comments

Comments
 (0)