Migrate fmod and fmod_ from TH to ATen (CPU) (pytorch#33592)
Summary:
Closes pytorch#24701
Pull Request resolved: pytorch#33592

Differential Revision: D20043875

Pulled By: ezyang

fbshipit-source-id: b8c0a4e73a3cef6e55e91bbd35f8aadca8114c56
peterbell10 authored and facebook-github-bot committed Feb 26, 2020
1 parent f87b0b2 commit 2eb95d8
Showing 11 changed files with 114 additions and 77 deletions.
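
For context on the semantics being migrated: fmod is C-style truncated modulo, so the result takes the sign of the dividend, unlike remainder, which follows the divisor. A minimal plain-C++ sketch (not PyTorch code) of the behavior the new CPU kernels must reproduce:

#include <cmath>
#include <cstdio>

int main() {
  // Floating point: the result has the sign of the dividend (x).
  std::printf("%g\n", std::fmod(-5.0, 3.0));  // -2, not 1
  std::printf("%g\n", std::fmod(5.0, -3.0));  //  2, not -1
  // Integral types use the C++ % operator, which truncates the same way.
  std::printf("%d\n", -5 % 3);                // -2
  return 0;
}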
4 changes: 4 additions & 0 deletions aten/src/ATen/Declarations.cwrap
@@ -621,6 +621,8 @@
return: argument 0
variants:
- function
backends:
- CUDA
options:
- cname: fmod
arguments:
@@ -640,6 +642,8 @@
name: _th_fmod_
return: argument 0
variants: function
backends:
- CUDA
options:
- cname: fmod
arguments:
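With backends: limited to CUDA, the cwrap generator now emits the legacy _th_fmod/_th_fmod_ bindings for the CUDA backend only; CPU calls are routed to the new ATen kernels through the native_functions.yaml changes below.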
12 changes: 12 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_base.h
@@ -303,6 +303,18 @@ struct Vec256 {
Vec256<T> frac() const {
return *this - this->trunc();
}
template <
typename U = T,
typename std::enable_if_t<std::is_floating_point<U>::value, int> = 0>
Vec256<T> fmod(const Vec256<T>& q) const {
// U is for SFINAE purposes only. Make sure it is not changed.
static_assert(std::is_same<U, T>::value, "U must be T");
Vec256<T> ret;
for (int64_t i = 0; i < size(); i++) {
ret[i] = std::fmod(values[i], q[i]);
}
return ret;
}
Vec256<T> log() const {
return map(std::log);
}
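The enable_if_t guard is needed because the generic Vec256<T> is also instantiated for integral T, where std::fmod does not apply; SFINAE simply removes fmod() from those instantiations rather than breaking them. A toy sketch of the same pattern (Vec is a hypothetical stand-in, not the real Vec256):

#include <cmath>
#include <type_traits>

template <typename T>
struct Vec {
  T val;
  // fmod() only participates in overload resolution for floating-point T;
  // Vec<int> simply lacks the member instead of failing to compile.
  template <typename U = T,
            typename std::enable_if_t<std::is_floating_point<U>::value, int> = 0>
  Vec<T> fmod(const Vec<T>& q) const {
    static_assert(std::is_same<U, T>::value, "U must be T");
    return {std::fmod(val, q.val)};
  }
};

int main() {
  Vec<double> a{5.5}, b{2.0};
  Vec<double> r = a.fmod(b);        // r.val == 1.5
  // Vec<int>{5}.fmod(Vec<int>{2}); // ill-formed: no such member
  return r.val == 1.5 ? 0 : 1;
}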
3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_double.h
@@ -141,6 +141,9 @@ template <> class Vec256<double> {
Vec256<double> expm1() const {
return Vec256<double>(Sleef_expm1d4_u10(values));
}
Vec256<double> fmod(const Vec256<double>& q) const {
return Vec256<double>(Sleef_fmodd4(values, q));
}
Vec256<double> log() const {
return Vec256<double>(Sleef_logd4_u10(values));
}
3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_float.h
@@ -148,6 +148,9 @@ template <> class Vec256<float> {
Vec256<float> expm1() const {
return Vec256<float>(Sleef_expm1f8_u10(values));
}
Vec256<float> fmod(const Vec256<float>& q) const {
return Vec256<float>(Sleef_fmodf8(values, q));
}
Vec256<float> log() const {
return Vec256<float>(Sleef_logf8_u10(values));
}
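Both specializations defer to SLEEF: Sleef_fmodd4 handles 4 packed doubles and Sleef_fmodf8 handles 8 packed floats, one 256-bit register each. Unlike the neighboring expm1/log calls, the names carry no _u10 ULP suffix; fmod is exact, so SLEEF provides a single correctly rounded variant.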
36 changes: 36 additions & 0 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -33,6 +33,8 @@ DEFINE_DISPATCH(sigmoid_backward_stub);
DEFINE_DISPATCH(tanh_backward_stub);
DEFINE_DISPATCH(max_elementwise_stub);
DEFINE_DISPATCH(min_elementwise_stub);
DEFINE_DISPATCH(fmod_stub);
DEFINE_DISPATCH(fmod_scalar_stub);

Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
auto iter = TensorIterator::binary_op(result, self, other,
@@ -586,5 +588,39 @@ Tensor min(const Tensor& self, const Tensor& other) {

Tensor& min_(Tensor& self, const Tensor& other) { return at::min_out(self, self, other); }

Tensor& fmod_out(Tensor & result, const Tensor& self, const Tensor& other) {
auto iter = TensorIterator::binary_op(result, self, other,
/*check_mem_overlap=*/true);
TORCH_CHECK(iter.device_type() == at::kCPU, "Native fmod only supports CPU");
fmod_stub(iter.device_type(), iter);
return result;
}

Tensor& fmod_out(Tensor & result, const Tensor& self, Scalar other) {
auto iter = TensorIterator::unary_op(result, self,
/*check_mem_overlap=*/true);
TORCH_CHECK(iter.device_type() == at::kCPU, "Native fmod only supports CPU");
fmod_scalar_stub(iter.device_type(), iter, other);
return result;
}

Tensor fmod(const Tensor& self, const Tensor & other) {
Tensor result = at::empty({0}, self.options());
return at::fmod_out(result, self, other);
}

Tensor fmod(const Tensor& self, Scalar other) {
Tensor result = at::empty({0}, self.options());
return at::fmod_out(result, self, other);
}

Tensor& fmod_(Tensor& self, const Tensor& other) {
return at::fmod_out(self, self, other);
}

Tensor& fmod_(Tensor& self, Scalar other) {
return at::fmod_out(self, self, other);
}

} // namespace native
} // namespace at
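
A short usage sketch of the new entry points through the public ATen API, assuming a CPU build (at::full is just a convenient factory here):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor a = at::full({4}, 7.5);
  at::Tensor b = at::full({4}, 2.0);
  at::Tensor r = at::fmod(a, b);   // tensor-tensor overload -> fmod_stub
  at::Tensor s = at::fmod(a, 2.0); // scalar overload -> fmod_scalar_stub
  a.fmod_(b);                      // in-place variant writes into a
  std::cout << r << "\n" << s << "\n" << a << "\n";  // all elements 1.5
  return 0;
}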
2 changes: 2 additions & 0 deletions aten/src/ATen/native/BinaryOps.h
@@ -52,5 +52,7 @@ DECLARE_DISPATCH(binary_fn, smooth_l1_stub);
DECLARE_DISPATCH(binary_fn, sigmoid_backward_stub);
DECLARE_DISPATCH(binary_fn, tanh_backward_stub);
DECLARE_DISPATCH(binary_fn, mse_stub);
DECLARE_DISPATCH(binary_fn, fmod_stub);
DECLARE_DISPATCH(binary_fn_alpha, fmod_scalar_stub);

}} // namespace at::native
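
DECLARE_DISPATCH, DEFINE_DISPATCH, and REGISTER_DISPATCH together implement ATen's per-device function-pointer table: the header declares the stub, BinaryOps.cpp defines it, and each backend's kernel file registers its implementation, selected by device type at call time. A deliberately simplified, hypothetical sketch of the idea (SimpleStub is not the real DispatchStub):

#include <cmath>
#include <stdexcept>
#include <utility>

enum class DeviceType { CPU, CUDA };

// Stripped-down stand-in for ATen's DispatchStub<fn_type>.
template <typename Fn>
struct SimpleStub {
  Fn cpu_fn = nullptr;
  Fn cuda_fn = nullptr;
  template <typename... Args>
  void operator()(DeviceType dev, Args&&... args) const {
    Fn fn = (dev == DeviceType::CPU) ? cpu_fn : cuda_fn;
    if (!fn) throw std::runtime_error("no kernel registered for device");
    fn(std::forward<Args>(args)...);
  }
};

using binary_fn = void (*)(float*, const float*, const float*, int);
SimpleStub<binary_fn> fmod_stub;  // DEFINE_DISPATCH analogue

// The kernel file provides the implementation...
void fmod_kernel(float* out, const float* x, const float* d, int n) {
  for (int i = 0; i < n; ++i) out[i] = std::fmod(x[i], d[i]);
}

int main() {
  fmod_stub.cpu_fn = &fmod_kernel;  // ...and REGISTER_DISPATCH wires it up.
  float x[] = {7.5f}, d[] = {2.0f}, out[1];
  fmod_stub(DeviceType::CPU, out, x, d, 1);  // out[0] == 1.5f
  return 0;
}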
48 changes: 48 additions & 0 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -523,6 +523,52 @@ void mse_kernel(TensorIterator& iter) {
});
}

void fmod_kernel(TensorIterator& iter) {
if (isIntegralType(iter.dtype(), /*includeBool=*/ false)) {
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "fmod_cpu", [&]() {
cpu_kernel(iter, [=](scalar_t x, scalar_t d) -> scalar_t {
return x % d;
});
});
} else {
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "fmod_cpu", [&]() {
cpu_kernel_vec(
iter,
[](scalar_t x, scalar_t d) -> scalar_t {
return std::fmod(x, d);
},
[](Vec256<scalar_t> x, Vec256<scalar_t> d) {
return x.fmod(d);
});
});
}
}

void fmod_scalar_kernel(TensorIterator& iter, Scalar divisor) {
if (isIntegralType(iter.dtype(), /*includeBool=*/ false)) {
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "fmod_scalar_cpu", [&]() {
const auto div = divisor.to<scalar_t>();
cpu_kernel(iter, [=](scalar_t x) -> scalar_t {
return x % div;
});
});
} else {
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "fmod_scalar_cpu", [&]() {
const auto div = divisor.to<scalar_t>();
const auto div_vec = Vec256<scalar_t>(div);
cpu_kernel_vec(
iter,
[=](scalar_t x) -> scalar_t {
return std::fmod(x, div);
},
[=](Vec256<scalar_t> x) {
return x.fmod(div_vec);
});
});
}

}

} // anonymous namespace


@@ -551,5 +597,7 @@ REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel);
REGISTER_DISPATCH(sigmoid_backward_stub, &sigmoid_backward_kernel);
REGISTER_DISPATCH(tanh_backward_stub, &tanh_backward_kernel);
REGISTER_DISPATCH(mse_stub, &mse_kernel);
REGISTER_DISPATCH(fmod_stub, &fmod_kernel);
REGISTER_DISPATCH(fmod_scalar_stub, &fmod_scalar_kernel);

}} // namespace at::native
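
cpu_kernel_vec takes two lambdas: the Vec256 one covers full SIMD-width chunks, and the scalar one covers the tail (and layouts the iterator cannot vectorize). A hypothetical, simplified model of that main-loop-plus-tail structure over one contiguous array (kernel_vec is not the real implementation):

#include <cmath>

// VecWidth stands in for Vec256<float>::size() (8 lanes under AVX).
constexpr int VecWidth = 8;

template <typename ScalarOp, typename VecOp>
void kernel_vec(float* out, const float* x, const float* d, int n,
                ScalarOp scalar_op, VecOp vec_op) {
  int i = 0;
  for (; i + VecWidth <= n; i += VecWidth)
    vec_op(out + i, x + i, d + i);   // full SIMD-width chunk
  for (; i < n; ++i)
    out[i] = scalar_op(x[i], d[i]);  // scalar tail
}

int main() {
  float x[10], d[10], out[10];
  for (int i = 0; i < 10; ++i) { x[i] = 7.5f; d[i] = 2.0f; }
  kernel_vec(out, x, d, 10,
             [](float a, float b) { return std::fmod(a, b); },
             [](float* o, const float* a, const float* b) {
               for (int j = 0; j < VecWidth; ++j) o[j] = std::fmod(a[j], b[j]);
             });
  return out[9] == 1.5f ? 0 : 1;  // elements 8..9 hit the scalar path
}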
12 changes: 6 additions & 6 deletions aten/src/ATen/native/native_functions.yaml
@@ -4267,13 +4267,13 @@
- func: fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
variants: method
dispatch:
-    CPU: legacy::cpu::_th_fmod_
+    CPU: fmod_
CUDA: legacy::cuda::_th_fmod_

- func: fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
variants: method
dispatch:
-    CPU: legacy::cpu::_th_fmod_
+    CPU: fmod_
CUDA: legacy::cuda::_th_fmod_

- func: remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
@@ -5005,26 +5005,26 @@

- func: fmod.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
-    CPU: legacy::cpu::_th_fmod_out
+    CPU: fmod_out
CUDA: legacy::cuda::_th_fmod_out

- func: fmod.Scalar(Tensor self, Scalar other) -> Tensor
use_c10_dispatcher: full
variants: method, function
dispatch:
-    CPU: legacy::cpu::_th_fmod
+    CPU: fmod
CUDA: legacy::cuda::_th_fmod

- func: fmod.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
-    CPU: legacy::cpu::_th_fmod_out
+    CPU: fmod_out
CUDA: legacy::cuda::_th_fmod_out

- func: fmod.Tensor(Tensor self, Tensor other) -> Tensor
use_c10_dispatcher: full
variants: method, function
dispatch:
-    CPU: legacy::cpu::_th_fmod
+    CPU: fmod
CUDA: legacy::cuda::_th_fmod

- func: remainder.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
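With the CPU dispatch entries now naming plain fmod, fmod_, and fmod_out, codegen binds them to the at::native:: functions added in BinaryOps.cpp; CUDA continues to route through the legacy::cuda::_th_* bindings until a CUDA port lands.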
29 changes: 0 additions & 29 deletions aten/src/TH/generic/THTensorEvenMoreMath.cpp
@@ -503,35 +503,6 @@ accreal THTensor_(dot)(THTensor *tensor, THTensor *src)
return sum;
}

void THTensor_(fmod)(THTensor *r_, THTensor *t, scalar_t value)
{
THTensor_(resizeAs)(r_, t);
int64_t r_Size = THTensor_(nElement)(r_);
int r_Contig = THTensor_(isContiguous)(r_);
int tContig = THTensor_(isContiguous)(t);
if (r_Contig && tContig) {
scalar_t *tp = t->data<scalar_t>();
scalar_t *rp = r_->data<scalar_t>();
at::parallel_for(0, r_Size, TH_OMP_OVERHEAD_THRESHOLD,
[&](int64_t start, int64_t end) {
for (auto i = start; i < end; i++) {
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
rp[i] = fmod(tp[i], value);
#else
rp[i] = tp[i] % value;
#endif
}
});
} else {

#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
TH_TENSOR_APPLY2_PARALLEL(r_Size, r_Contig, tContig, scalar_t, r_, scalar_t, t, *r__data = fmod(*t_data, value);, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
#else
TH_TENSOR_APPLY2_PARALLEL(r_Size, r_Contig, tContig, scalar_t, r_, scalar_t, t, *r__data = (*t_data % value);, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
#endif
}
}

// Should wrap if the value (a) has a different sign than the divisor (b), but is not 0.
static inline bool modulo_wrap(scalar_t a, scalar_t b) {
return (a != 0) && (a < 0) != (b < 0);
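For contrast with the remainder code kept above: per the modulo_wrap comment, remainder follows the divisor's sign (Python %), while fmod follows the dividend's. A small sketch of the relationship, assuming the usual definitions:

#include <cassert>
#include <cmath>

// Python-style remainder expressed in terms of C fmod: wrap when the
// result is nonzero and its sign disagrees with the divisor's.
double py_remainder(double a, double b) {
  double r = std::fmod(a, b);
  if (r != 0 && ((r < 0) != (b < 0))) r += b;
  return r;
}

int main() {
  assert(std::fmod(-5.0, 3.0) == -2.0);    // sign of dividend
  assert(py_remainder(-5.0, 3.0) == 1.0);  // sign of divisor
  return 0;
}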
40 changes: 0 additions & 40 deletions aten/src/TH/generic/THTensorMath.cpp
@@ -402,46 +402,6 @@ void THTensor_(cdiv)(THTensor *r_, THTensor *t, THTensor *src)
}
}

void THTensor_(cfmod)(THTensor *r_, THTensor *t, THTensor *src)
{
THTensor_(resizeAs)(r_, t);
int64_t r_Size = THTensor_(nElement)(r_);
int64_t srcSize = THTensor_(nElement)(src);
int r_Contig = THTensor_(isContiguous)(r_);
int tContig = THTensor_(isContiguous)(t);
int srcContig = THTensor_(isContiguous)(src);
if (srcSize == r_Size){
if (r_Contig && tContig && srcContig) {
scalar_t *tp = t->data<scalar_t>();
scalar_t *sp = src->data<scalar_t>();
scalar_t *rp = r_->data<scalar_t>();
at::parallel_for(0, r_Size, TH_OMP_OVERHEAD_THRESHOLD,
[&](int64_t start, int64_t end) {
for (auto i = start; i < end; i++) {
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
rp[i] = fmod(tp[i], sp[i]);
#else
rp[i] = tp[i] % sp[i];
#endif
}
});
} else {

#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
TH_TENSOR_APPLY3_PARALLEL(r_Size, r_Contig, tContig, srcContig,scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = fmod(*t_data, *src_data);, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
#else
TH_TENSOR_APPLY3_PARALLEL(r_Size, r_Contig, tContig, srcContig, scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = (*t_data % *src_data);, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
#endif
}
} else {
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = fmod(*t_data, *src_data););
#else
TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, t, scalar_t, src, *r__data = (*t_data % *src_data););
#endif
}
}

void THTensor_(cremainder)(THTensor *r_, THTensor *t, THTensor *src)
{
THTensor_(resizeAs)(r_, t);
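Retiring cfmod also illustrates what the TensorIterator rewrite buys: the TH version parallelized only the srcSize == r_Size case and fell back to a serial TH_TENSOR_APPLY3 otherwise, while TensorIterator::binary_op in the new fmod_out handles broadcasting, overlap checking, and parallelization in one code path.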
2 changes: 0 additions & 2 deletions aten/src/TH/generic/THTensorMath.h
@@ -95,15 +95,13 @@ TH_API void THTensor_(indexFill)(THTensor *tensor, int dim, THLongTensor *index,

TH_API accreal THTensor_(dot)(THTensor *t, THTensor *src);

TH_API void THTensor_(fmod)(THTensor *r_, THTensor *t, scalar_t value);
TH_API void THTensor_(remainder)(THTensor *r_, THTensor *t, scalar_t value);
TH_API void THTensor_(clamp)(THTensor *r_, THTensor *t, scalar_t min_value, scalar_t max_value);

TH_API void THTensor_(cadd)(THTensor *r_, THTensor *t, scalar_t value, THTensor *src);
TH_API void THTensor_(csub)(THTensor *self, THTensor *src1, scalar_t value, THTensor *src2);
TH_API void THTensor_(cmul)(THTensor *r_, THTensor *t, THTensor *src);
TH_API void THTensor_(cdiv)(THTensor *r_, THTensor *t, THTensor *src);
TH_API void THTensor_(cfmod)(THTensor *r_, THTensor *t, THTensor *src);
TH_API void THTensor_(cremainder)(THTensor *r_, THTensor *t, THTensor *src);

TH_API void THTensor_(addbmm)(THTensor *r_, THTensor *t, THTensor *batch1, THTensor *batch2, scalar_t beta, scalar_t alpha);
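The remainder and cremainder declarations stay behind with the rest of this header; pytorch#24701 covers only fmod, so remainder is presumably left for a separate migration.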
