Skip to content

Commit

Permalink
Revert "torch.special.gamma (pytorch#78904)"
Browse files Browse the repository at this point in the history
This reverts commit f563f25.

Reverted pytorch#78904 on behalf of https://github.com/suo due to This PR appears to have broken mac tests on master https://hud.pytorch.org/pytorch/pytorch/commit/f563f25efd6226d1a4f21cd8340b2b0380abac04
  • Loading branch information
pytorchmergebot committed Jun 28, 2022
1 parent 5da776d commit 602c38f
Show file tree
Hide file tree
Showing 15 changed files with 0 additions and 490 deletions.
154 changes: 0 additions & 154 deletions aten/src/ATen/native/Math.h
Original file line number Diff line number Diff line change
Expand Up @@ -3013,160 +3013,6 @@ static inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, T n) {
return chebyshev_polynomial_w_forward(x, static_cast<int64_t>(n));
} // chebyshev_polynomial_w_forward(T x, T n)

template<typename T>
static inline C10_HOST_DEVICE
typename std::enable_if<std::is_floating_point<T>::value, T>::type
gamma_forward(T x) {
static const T P[] = {
+1.60119522476751861407e-4,
+1.19135147006586384913e-3,
+1.04213797561761569935e-2,
+4.76367800457137231464e-2,
+2.07448227648435975150e-1,
+4.94214826801497100753e-1,
+9.99999999999999996796e-1,
};

static const T Q[] = {
-2.31581873324120129819e-5,
+5.39605580493303397842e-4,
-4.45641913851797240494e-3,
+1.18139785222060435552e-2,
+3.58236398605498653373e-2,
-2.34591795718243348568e-1,
+7.14304917030273074085e-2,
+1.00000000000000000320e+0,
};

static const T R[] = {
+7.87311395793093628397e-4,
-2.29549961613378126380e-4,
-2.68132617805781232825e-3,
+3.47222221605458667310e-3,
+8.33333333333482257126e-2,
};

int sign_gamma = 1;

if (!std::isfinite(x)) {
return x;
}

if (std::abs(x) > T(33.0)) {
if (x < T(0.0)) {
T p = std::floor(std::abs(x));

if (p == std::abs(x)) {
return std::numeric_limits<T>::infinity();
}

int previous_p = p;

if ((previous_p & 1) == 0) {
sign_gamma = -1;
}

T z = std::abs(x) - p;

if (z > T(0.5)) {
z = std::abs(x) - (p + T(1.0));
}

z = std::abs(x) * std::sin(c10::pi<T> * z);

if (z == T(0.0)) {
return sign_gamma * std::numeric_limits<T>::infinity();
}

if (std::abs(x) >= T(171.624376956302725)) {
return std::numeric_limits<T>::infinity();
}

T r = 0.0;

for (uint8_t index = 0; index <= 4; index++) {
r = r * (T(1.0) / std::abs(x)) + R[index];
}

if (std::abs(x) > T(143.01608)) {
return sign_gamma * c10::pi<T> / (std::abs(z) * (T(2.50662827463100050242) * (std::pow(std::abs(x), T(0.5) * std::abs(x) - T(0.25)) * (std::pow(std::abs(x), T(0.5) * std::abs(x) - T(0.25)) / std::exp(std::abs(x)))) * (T(1.0) + T(1.0) / std::abs(x) * r)));
}

return sign_gamma * c10::pi<T> / (std::abs(z) * (T(2.50662827463100050242) * (std::pow(std::abs(x), std::abs(x) - T(0.5)) / std::exp(std::abs(x))) * (T(1.0) + T(1.0) / std::abs(x) * r)));
}

if (x >= T(171.624376956302725)) {
return std::numeric_limits<T>::infinity();
}

T r = 0.0;

for (uint8_t index = 0; index <= 4; index++) {
r = r * (T(1.0) / x) + R[index];
}

if (x > T(143.01608)) {
return sign_gamma * (T(2.50662827463100050242) * (std::pow(x, T(0.5) * x - T(0.25)) * (std::pow(x, T(0.5) * x - T(0.25)) / std::exp(x))) * (T(1.0) + T(1.0) / x * r));
}

return sign_gamma * (T(2.50662827463100050242) * (std::pow(x, x - T(0.5)) / std::exp(x)) * (T(1.0) + T(1.0) / x * r));
}

T z = 1.0;

while (x >= T(3.0)) {
x = x - T(1.0);

z = z * x;
}

while (x < T(0.0)) {
if (x > -0.000000001) {
if (x == T(0.0)) {
return std::numeric_limits<T>::infinity();
}

return z / ((T(1.0) + c10::euler<T> * x) * x);
}

z = z / x;

x = x + T(1.0);
}

while (x < T(2.0)) {
if (x < 0.000000001) {
if (x == T(0.0)) {
return std::numeric_limits<T>::infinity();
}

return z / ((T(1.0) + c10::euler<T> * x) * x);
}

z = z / x;

x = x + T(1.0);
}

if (x == T(2.0)) {
return z;
}

T p = 0.0;

for (uint8_t index = 0; index <= 6; index++) {
p = p * (x - T(2.0)) + P[index];
}

T q = 0.0;

for (uint8_t index = 0; index <= 7; index++) {
q = q * (x - T(2.0)) + Q[index];
}

return z * p / q;
} // T gamma_forward(T x)

template<typename T>
static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) {
if (n < 0) {
Expand Down
3 changes: 0 additions & 3 deletions aten/src/ATen/native/UnaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ CREATE_UNARY_FLOAT_META_FUNC(special_bessel_j0)
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_j1)
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_y0)
CREATE_UNARY_FLOAT_META_FUNC(special_bessel_y1)
CREATE_UNARY_FLOAT_META_FUNC(special_gamma)
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_i0)
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_i1)
CREATE_UNARY_FLOAT_META_FUNC(special_modified_bessel_k0)
Expand Down Expand Up @@ -207,7 +206,6 @@ CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_j0_out, special_bessel_j0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_j1_out, special_bessel_j1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_y0_out, special_bessel_y0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_y1_out, special_bessel_y1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_gamma_out, special_gamma_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_i0_out, special_modified_bessel_i0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_i1_out, special_modified_bessel_i1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_modified_bessel_k0_out, special_modified_bessel_k0_stub)
Expand Down Expand Up @@ -892,7 +890,6 @@ DEFINE_DISPATCH(special_bessel_j0_stub); // NOLINT(cppcoreguidelines-avoid-non-c
DEFINE_DISPATCH(special_bessel_j1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_bessel_y0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_bessel_y1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_gamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_modified_bessel_i0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_modified_bessel_i1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_modified_bessel_k0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
Expand Down
1 change: 0 additions & 1 deletion aten/src/ATen/native/UnaryOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ DECLARE_DISPATCH(unary_fn, special_bessel_j0_stub);
DECLARE_DISPATCH(unary_fn, special_bessel_j1_stub);
DECLARE_DISPATCH(unary_fn, special_bessel_y0_stub);
DECLARE_DISPATCH(unary_fn, special_bessel_y1_stub);
DECLARE_DISPATCH(unary_fn, special_gamma_stub);
DECLARE_DISPATCH(unary_fn, special_modified_bessel_i0_stub);
DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub);
DECLARE_DISPATCH(unary_fn, special_modified_bessel_k0_stub);
Expand Down
40 changes: 0 additions & 40 deletions aten/src/ATen/native/cpu/gamma.cpp

This file was deleted.

156 changes: 0 additions & 156 deletions aten/src/ATen/native/cuda/Math.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2119,162 +2119,6 @@ const auto chebyshev_polynomial_w_string = jiterator_stringify(
} // chebyshev_polynomial_w_forward(T x, T n)
); // chebyshev_polynomial_w_string

const auto gamma_string = jiterator_stringify(
template<typename T>
T gamma_forward(T x) {
static const T P[] = {
+1.60119522476751861407e-4,
+1.19135147006586384913e-3,
+1.04213797561761569935e-2,
+4.76367800457137231464e-2,
+2.07448227648435975150e-1,
+4.94214826801497100753e-1,
+9.99999999999999996796e-1,
};

static const T Q[] = {
-2.31581873324120129819e-5,
+5.39605580493303397842e-4,
-4.45641913851797240494e-3,
+1.18139785222060435552e-2,
+3.58236398605498653373e-2,
-2.34591795718243348568e-1,
+7.14304917030273074085e-2,
+1.00000000000000000320e+0,
};

static const T R[] = {
+7.87311395793093628397e-4,
-2.29549961613378126380e-4,
-2.68132617805781232825e-3,
+3.47222221605458667310e-3,
+8.33333333333482257126e-2,
};

constexpr T PI = 3.14159265358979323846;

int sign_gamma = 1;

if (!isfinite(x)) {
return x;
}

if (abs(x) > T(33.0)) {
if (x < T(0.0)) {
T p = floor(abs(x));

if (p == abs(x)) {
return INFINITY;
}

int previous_p = p;

if ((previous_p & 1) == 0) {
sign_gamma = -1;
}

T z = abs(x) - p;

if (z > T(0.5)) {
z = abs(x) - (p + T(1.0));
}

z = abs(x) * sin(PI * z);

if (z == T(0.0)) {
return sign_gamma * INFINITY;
}

if (abs(x) >= T(171.624376956302725)) {
return INFINITY;
}

T r = 0.0;

for (uint8_t index = 0; index <= 4; index++) {
r = r * (T(1.0) / abs(x)) + R[index];
}

if (abs(x) > T(143.01608)) {
return sign_gamma * PI / (abs(z) * (T(2.50662827463100050242) * (pow(abs(x), T(0.5) * abs(x) - T(0.25)) * (pow(abs(x), T(0.5) * abs(x) - T(0.25)) / exp(abs(x)))) * (T(1.0) + T(1.0) / abs(x) * r)));
}

return sign_gamma * PI / (abs(z) * (T(2.50662827463100050242) * (pow(abs(x), abs(x) - T(0.5)) / exp(abs(x))) * (T(1.0) + T(1.0) / abs(x) * r)));
}

if (x >= T(171.624376956302725)) {
return INFINITY;
}

T r = 0.0;

for (uint8_t index = 0; index <= 4; index++) {
r = r * (T(1.0) / x) + R[index];
}

if (x > T(143.01608)) {
return sign_gamma * (T(2.50662827463100050242) * (pow(x, T(0.5) * x - T(0.25)) * (pow(x, T(0.5) * x - T(0.25)) / exp(x))) * (T(1.0) + T(1.0) / x * r));
}

return sign_gamma * (T(2.50662827463100050242) * (pow(x, x - T(0.5)) / exp(x)) * (T(1.0) + T(1.0) / x * r));
}

T z = 1.0;

while (x >= T(3.0)) {
x = x - T(1.0);

z = z * x;
}

while (x < T(0.0)) {
if (x > -0.000000001) {
if (x == T(0.0)) {
return INFINITY;
}

return z / ((T(1.0) + T(0.5772156649015329) * x) * x);
}

z = z / x;

x = x + T(1.0);
}

while (x < T(2.0)) {
if (x < 0.000000001) {
if (x == T(0.0)) {
return INFINITY;
}

return z / ((T(1.0) + T(0.5772156649015329) * x) * x);
}

z = z / x;

x = x + T(1.0);
}

if (x == T(2.0)) {
return z;
}

T p = 0.0;

for (uint8_t index = 0; index <= 6; index++) {
p = p * (x - T(2.0)) + P[index];
}

T q = 0.0;

for (uint8_t index = 0; index <= 7; index++) {
q = q * (x - T(2.0)) + Q[index];
}

return z * p / q;
} // T gamma_forward(T x)
); // gamma_string

const auto hermite_polynomial_h_string = jiterator_stringify(
template<typename T>
T hermite_polynomial_h_forward(T x, int64_t n) {
Expand Down
Loading

0 comments on commit 602c38f

Please sign in to comment.