Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dev Zeta op #10189

Merged
merged 20 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/special.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ The oneflow.special module, modeled after SciPy's special module.
logsumexp
round
softmax
zeta
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ inline bool IsDimsEquals(size_t num_src0_dims, const int64_t* src0_dims, size_t
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kBitwiseOr) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kBitwiseXor)

// Binary math ops restricted to floating-point dtypes (currently only zeta);
// expanded over the CPU/CUDA floating-type sequences when registering
// broadcast elementwise binary kernels.
#define BINARY_MATH_FLOATING_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kZeta)

#define BINARY_ACTIVATION_BACKWARD_OP_SEQ_0 \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kIdentityBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kEluBackwardWithDyX) \
Expand Down
68 changes: 68 additions & 0 deletions oneflow/core/ep/cpu/primitive/binary_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,74 @@ struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kLgammaBackwardWithDyX, Src, Ds
}
};

template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kZeta, Src, Dst> {
  // attr0/attr1 are unused: zeta takes no op attributes.
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  // Hurwitz zeta zeta(x, q) = sum_{n>=0} (q + n)^(-x), evaluated by summing
  // the leading terms directly and closing the tail with an Euler-Maclaurin
  // correction (Cephes-style algorithm, ported from the PyTorch reference).
  OF_DEVICE_FUNC Dst operator()(Src x, Src q) const {
    // ref
    // https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/Math.h#L235-L309
    // Convergence threshold: a term this small relative to the running sum
    // no longer changes the result (double-precision machine epsilon / 2).
    const Src MACHEP = Src{1.11022302462515654042E-16};
    constexpr Src zero = Src{0.0};
    constexpr Src half = Src{0.5};
    constexpr Src one = Src{1.0};
    // Euler-Maclaurin expansion coefficients, taken verbatim from the
    // referenced PyTorch/Cephes implementation.
    static const Src A[] = {
        12.0,
        -720.0,
        30240.0,
        -1209600.0,
        47900160.0,
        -1.8924375803183791606e9, /*1.307674368e12/691*/
        7.47242496e10,
        -2.950130727918164224e12, /*1.067062284288e16/3617*/
        1.1646782814350067249e14, /*5.109094217170944e18/43867*/
        -4.5979787224074726105e15, /*8.028576626982912e20/174611*/
        1.8152105401943546773e17,  /*1.5511210043330985984e23/854513*/
        -7.1661652561756670113e18  /*1.6938241367317436694528e27/236364091*/
    };
    int i = 0;
    Src a, b, k, s, t, w;
    // x == 1 is the pole of the zeta function: return +inf.
    if (x == one) { return std::numeric_limits<Dst>::infinity(); }

    // Domain restriction: x < 1 yields NaN (same as the PyTorch reference).
    if (x < one) { return std::numeric_limits<Dst>::quiet_NaN(); }

    // Non-positive q: integral q hits a pole (+inf); with non-integral x the
    // reference implementation defines the result as NaN.
    if (q <= zero) {
      if (q == floor(q)) { return std::numeric_limits<Dst>::infinity(); }
      if (x != floor(x)) { return std::numeric_limits<Dst>::quiet_NaN(); }
    }

    // Direct summation of leading terms; bail out early once a term `b` is
    // negligible (|b| < MACHEP * s) relative to the accumulated sum `s`.
    s = pow(q, -x);
    a = q;
    i = 0;
    b = zero;
    while ((i < 9) || (a <= Src{9.0})) {
      i += 1;
      a += one;
      b = pow(a, -x);
      s += b;
      if ((-MACHEP * s < b) && (b < MACHEP * s)) { return static_cast<Dst>(s); }
    };

    // Euler-Maclaurin tail: integral term, half-term, then up to 12
    // coefficient-weighted correction terms.
    w = a;
    s += b * w / (x - one);
    s -= half * b;
    a = one;
    k = zero;
    // NOTE: the loop index intentionally shadows the outer `i`, exactly as in
    // the PyTorch original.
    for (int i = 0; i < 12; i++) {
      a *= x + k;
      b /= w;
      t = a * b / A[i];
      s = s + t;
      t = fabs(t / s);
      // Relative contribution below threshold: series has converged.
      if (t < MACHEP) { return static_cast<Dst>(s); }
      k += one;
      a *= x + k;
      b /= w;
      k += one;
    }
    return static_cast<Dst>(s);
  }
};

#define SPECIALIZATION_CPU_BINARY_FUNCTOR(op, type) \
template<> \
struct BinaryFunctor<DeviceType::kCPU, op, type, type> { \
Expand Down
26 changes: 15 additions & 11 deletions oneflow/core/ep/cpu/primitive/broadcast_elementwise_binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,24 +575,28 @@ class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryF
BINARY_COMPLEX_MATH_OP_SEQ, CPU_PRIMITIVE_COMPLEX_TYPE_SEQ)

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,
BINARY_LOGICAL_OP_SEQ BINARY_COMPARISION_OP_SEQ,
NDARRAY_BINARY_TYPE_SEQ, CPU_PRIMITIVE_BOOL_TYPE_SEQ)
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
BINARY_MATH_FLOATING_OP_SEQ, CPU_PRIMITIVE_FLOATING_TYPE_SEQ)

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,
BINARY_COMPLEX_COMPARISION_OP_SEQ, CPU_PRIMITIVE_COMPLEX_TYPE_SEQ,
CPU_PRIMITIVE_BOOL_TYPE_SEQ)
BINARY_LOGICAL_OP_SEQ BINARY_COMPARISION_OP_SEQ,
NDARRAY_BINARY_TYPE_SEQ, CPU_PRIMITIVE_BOOL_TYPE_SEQ)

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,
BINARY_ACTIVATION_BACKWARD_OP_SEQ,
CPU_PRIMITIVE_FLOATING_TYPE_SEQ)
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,
BINARY_COMPLEX_COMPARISION_OP_SEQ,
CPU_PRIMITIVE_COMPLEX_TYPE_SEQ, CPU_PRIMITIVE_BOOL_TYPE_SEQ)

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
BINARY_MATH_BACKWARD_OP_SEQ,
CPU_PRIMITIVE_FLOATING_TYPE_SEQ)};
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,
BINARY_ACTIVATION_BACKWARD_OP_SEQ,
CPU_PRIMITIVE_FLOATING_TYPE_SEQ)

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
BINARY_MATH_BACKWARD_OP_SEQ,
CPU_PRIMITIVE_FLOATING_TYPE_SEQ)};

#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
Expand Down
79 changes: 79 additions & 0 deletions oneflow/core/ep/cuda/primitive/binary_functor.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,83 @@ struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kLgammaBackwardWithDyX, Src, D
}
};

template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kZeta, Src, Dst> {
  // attr0/attr1 are unused: zeta takes no op attributes.
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  // Hurwitz zeta zeta(x, q) = sum_{n>=0} (q + n)^(-x), evaluated by summing
  // the leading terms directly and closing the tail with an Euler-Maclaurin
  // correction (Cephes-style algorithm, ported from the PyTorch reference).
  OF_DEVICE_FUNC Dst operator()(Src x, Src q) const {
    // ref
    // https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/cuda/Math.cuh#L302-L384
    // Convergence threshold (double-precision machine epsilon / 2).
    const Src MACHEP{1.11022302462515654042E-16};
    constexpr Src zero{0};
    constexpr Src half{0.5};
    constexpr Src one{1};
    // Euler-Maclaurin expansion coefficients, taken verbatim from the
    // referenced PyTorch/Cephes implementation.
    // Fix: declared as a plain (non-static) local.  A function-scope `static`
    // is not portable inside __device__ code and would place the table in
    // global memory; PyTorch's CUDA implementation also uses a plain const
    // local array here.
    const Src A[] = {
        12.0,
        -720.0,
        30240.0,
        -1209600.0,
        47900160.0,
        -1.8924375803183791606e9, /*1.307674368e12/691*/
        7.47242496e10,
        -2.950130727918164224e12, /*1.067062284288e16/3617*/
        1.1646782814350067249e14, /*5.109094217170944e18/43867*/
        -4.5979787224074726105e15, /*8.028576626982912e20/174611*/
        1.8152105401943546773e17,  /*1.5511210043330985984e23/854513*/
        -7.1661652561756670113e18  /*1.6938241367317436694528e27/236364091*/
    };

    int i = 0;
    Src a, b, k, s, t, w;

    // Short-circuits x == 1 -> +infty (pole of zeta)
    if (x == one) { return INFINITY; }

    // Short-circuits x < 1 -> NaN
    if (x < one) { return NAN; }

    // Short-circuits negative q integers map to +infty,
    // negative q non-integers map to NaN
    if (q <= zero) {
      if (q == floor(q)) { return INFINITY; }
      if (x != floor(x)) { return NAN; }
    }

    // Direct summation of leading terms; bail out early once a term `b` is
    // negligible (|b| < MACHEP * s) relative to the accumulated sum `s`.
    s = pow(q, -x);
    a = q;
    b = zero;
    while ((i < 9) || (a <= Src{9.0})) {
      i += 1;
      a += one;
      b = pow(a, -x);
      s += b;
      if ((-MACHEP * s < b) && (b < MACHEP * s)) { return s; }
    }

    // Euler-Maclaurin tail: integral term, half-term, then up to 12
    // coefficient-weighted correction terms.  Loop index renamed to `j` so it
    // no longer shadows the outer `i`.
    w = a;
    s += b * w / (x - one);
    s -= half * b;
    a = one;
    k = zero;
    for (int j = 0; j < 12; j++) {
      a *= x + k;
      b /= w;
      t = a * b / A[j];
      s = s + t;
      t = fabs(t / s);

      // Relative contribution below threshold: series has converged.
      if (t < MACHEP) { return s; }

      k += one;
      a *= x + k;
      b /= w;
      k += one;
    }

    return s;
  }
};

#define SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(op, type) \
template<typename Dst> \
struct BinaryFunctor<DeviceType::kCUDA, op, type, Dst> { \
Expand Down Expand Up @@ -305,6 +382,7 @@ SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFmod);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFloorDiv);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTruncDiv);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFloorMod);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kZeta);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kIdentityBackwardWithDyX);
Expand Down Expand Up @@ -382,6 +460,7 @@ SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFmod);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFloorDiv);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTruncDiv);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFloorMod);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kZeta);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
Expand Down
25 changes: 16 additions & 9 deletions oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cu
Original file line number Diff line number Diff line change
Expand Up @@ -88,22 +88,29 @@ class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryF
CUDA_PRIMITIVE_REAL_TYPE_SEQ, CUDA_PRIMITIVE_BOOL_TYPE_SEQ)

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,
BINARY_COMPLEX_COMPARISION_OP_SEQ, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ,
CUDA_PRIMITIVE_BOOL_TYPE_SEQ)
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
BINARY_MATH_FLOATING_OP_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,
BINARY_ACTIVATION_BACKWARD_OP_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,
BINARY_COMPLEX_COMPARISION_OP_SEQ, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ,
CUDA_PRIMITIVE_BOOL_TYPE_SEQ)

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,
BINARY_MATH_BACKWARD_OP_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)
BINARY_ACTIVATION_BACKWARD_OP_SEQ,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
BINARY_BITWISE_OP_SEQ,
CUDA_PRIMITIVE_INT_TYPE_SEQ CUDA_PRIMITIVE_BOOL_TYPE_SEQ)};
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,
BINARY_MATH_BACKWARD_OP_SEQ,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
BINARY_BITWISE_OP_SEQ,
CUDA_PRIMITIVE_INT_TYPE_SEQ
CUDA_PRIMITIVE_BOOL_TYPE_SEQ)};

#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,14 @@ std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary(Scalar
new BroadcastElementwiseBinaryImpl<binary_op, Src, Dst>(attr0, attr1));
}

// Explicitly instantiate NewBroadcastElementwiseBinary for the
// floating-point-only math ops (BINARY_MATH_FLOATING_OP_SEQ, i.e. zeta) over
// the CUDA floating dtypes, so the kernels are emitted in this translation
// unit.
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<      \
      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(      \
      Scalar attr0, Scalar attr1);

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
                                 BINARY_MATH_FLOATING_OP_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ);

// Keep the helper macro file-local, consistent with the other entry macros in
// this file, which are all #undef'd after use.
#undef INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY

} // namespace broadcast_elementwise_binary
} // namespace primitive
} // namespace ep
Expand Down
1 change: 1 addition & 0 deletions oneflow/core/ep/include/primitive/binary_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ enum class BinaryOp {
kFloorMod,
kScalarBasePowerGrad,
kScalarExpPowerGrad,
kZeta,
// Comparision
kEqual,
kNotEqual,
Expand Down
7 changes: 7 additions & 0 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3362,3 +3362,10 @@
signature: "Tensor (Tensor x) => Trigamma"
bind_python: False

# zeta(x, other): three overloads — tensor/tensor (broadcasting), plus
# scalar/tensor and tensor/scalar variants (the scalar side is wrapped into a
# tensor by the corresponding functor before dispatching to BroadcastZeta).
- name: "zeta"
  signature: [
    "Tensor (Tensor x, Tensor other) => BroadcastZeta",
    "Tensor (Scalar x, Tensor other) => ZetaScalarTensor",
    "Tensor (Tensor x, Scalar other) => ZetaTensorScalar",
  ]
  bind_python: True
26 changes: 26 additions & 0 deletions oneflow/core/functional/impl/binary_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,29 @@ class ScalarDivByTensorFunctor : public BinaryFunctor {
}
};

// zeta(x, y) for two tensors: thin wrapper that binds the "broadcast_zeta"
// user op (inputs "x", "y"; output "z").  Broadcasting/dtype handling comes
// from the BinaryFloatFunctor base.
class BroadcastZetaFunctor : public BinaryFloatFunctor {
 public:
  BroadcastZetaFunctor() {
    op_ = CHECK_JUST(one::OpBuilder("broadcast_zeta").Input("x").Input("y").Output("z").Build());
  }
};

// zeta(scalar, tensor): materializes the scalar `x` as a tensor via FullLike
// (presumably matching `y`'s shape/dtype — see FullLike), then dispatches to
// the tensor/tensor BroadcastZeta path.
class ZetaScalarTensorFunctor {
 public:
  Maybe<Tensor> operator()(const Scalar x, const std::shared_ptr<one::Tensor>& y) const {
    auto scalar_tensor = JUST(functional::FullLike(y, x));  // wrap scalar to tensor
    return functional::BroadcastZeta(scalar_tensor, y);
  }
};

// zeta(tensor, scalar): mirror of ZetaScalarTensorFunctor — wraps the scalar
// `y` into a tensor via FullLike and dispatches to BroadcastZeta.
class ZetaTensorScalarFunctor {
 public:
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Scalar y) const {
    auto scalar_tensor = JUST(functional::FullLike(x, y));  // wrap scalar to tensor
    return functional::BroadcastZeta(x, scalar_tensor);
  }
};

} // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) {
Expand Down Expand Up @@ -824,6 +847,9 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::LerpFunctor>("Lerp");
m.add_functor<impl::InplaceLerpFunctor>("InplaceLerp");
m.add_functor<impl::LerpGradFunctor>("LerpGrad");
m.add_functor<impl::BroadcastZetaFunctor>("BroadcastZeta");
m.add_functor<impl::ZetaScalarTensorFunctor>("ZetaScalarTensor");
m.add_functor<impl::ZetaTensorScalarFunctor>("ZetaTensorScalar");
};

} // namespace functional
Expand Down
22 changes: 12 additions & 10 deletions oneflow/core/ndarray/binary_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,16 +274,6 @@ struct BinaryFuncLE final {
};
SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncLE);

template<typename T>
struct BinaryFuncIEN final {
// placeholder, no definition required, the type is only used to generate Op
};

template<typename T>
struct BinaryFuncINN final {
// placeholder, no definition required, the type is only used to generate Op
};

template<typename T>
struct BinaryFuncAND final {
static OF_DEVICE_FUNC bool Invoke(const T x, const T y) { return x && y; }
Expand Down Expand Up @@ -623,6 +613,18 @@ SPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncAny, GetZeroVal);
SPECIALIZE_UNIT_OF_BINARY_FUNC(BinaryFuncAll, GetOneVal);
#undef SPECIALIZE_UNIT_OF_BINARY_FUNC

/*
These placeholder specializations are used for `GetBinaryBroadcastSbpSignature` in
oneflow/user/ops/math_binary_broadcast_ops.cpp
*/
#define SPECIALIZE_FOR_SBP(binary_func) \
template<typename T> \
struct binary_func final {};

SPECIALIZE_FOR_SBP(BinaryFuncIEN);
SPECIALIZE_FOR_SBP(BinaryFuncINN);
SPECIALIZE_FOR_SBP(BinaryFuncZeta);
#undef SPECIALIZE_FOR_SBP
} // namespace oneflow

#endif // ONEFLOW_CORE_NDARRAY_BINARY_FUNC_H_
14 changes: 14 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,20 @@ def OneFlow_XlogyOp : OneFlow_BaseOp<"xlogy", [NoMemoryEffect, DeclareOpInterfac
let has_data_type_infer_fn = 1;
}

// Broadcasting Hurwitz zeta user op: z = zeta(x, y) over two tensor inputs.
// Declared NoGrad (no gradient is defined for this op).
// Fix: space added after the comma in the trait list ("NoGrad, NoMemoryEffect")
// for consistency with the other op definitions in this file.
def OneFlow_BroadcastZetaOp : OneFlow_BaseOp<"broadcast_zeta", [NoGrad, NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
  let input = (ins
    OneFlow_Tensor:$x,
    OneFlow_Tensor:$y
  );
  let output = (outs
    OneFlow_Tensor:$z
  );
  let has_logical_tensor_desc_infer_fn = 1;
  let has_physical_tensor_desc_infer_fn = 1;
  let has_get_sbp_fn = 1;
  let has_data_type_infer_fn = 1;
}

#endif // GET_ONEFLOW_BINARY_OP_DEFINITIONS


Expand Down
Loading