scalar math kernel use primitive #8612

Merged
24 commits merged on Aug 2, 2022

Commits
753cb3d
scalar math use primitive
guo-ran Jul 11, 2022
a5f14a7
fix
guo-ran Jul 11, 2022
df3ca58
Merge branch 'master' into dev_scalar_math_primitive_kernel
guo-ran Jul 11, 2022
d24f825
rm useless code
guo-ran Jul 11, 2022
15fc093
Merge branch 'dev_scalar_math_primitive_kernel' of https://github.com…
guo-ran Jul 11, 2022
c8a3093
Merge branch 'master' into dev_scalar_math_primitive_kernel
guo-ran Jul 11, 2022
998d5bb
add div and fix bug
guo-ran Jul 11, 2022
8288ebf
Merge branch 'dev_scalar_math_primitive_kernel' of https://github.com…
guo-ran Jul 11, 2022
5d42214
Merge branch 'master' into dev_scalar_math_primitive_kernel
guo-ran Jul 12, 2022
fe90ccd
broadcast floormod and fmod
guo-ran Jul 12, 2022
b545556
Merge branch 'dev_scalar_math_primitive_kernel' of https://github.com…
guo-ran Jul 12, 2022
ace0100
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
guo-ran Jul 12, 2022
5edf039
add test
guo-ran Jul 12, 2022
5c2b669
Merge branch 'master' into dev_scalar_math_primitive_kernel
guo-ran Jul 13, 2022
2f71ea0
Merge branch 'master' into dev_scalar_math_primitive_kernel
guo-ran Aug 1, 2022
df842e4
address review
guo-ran Aug 1, 2022
4dfcde9
Merge branch 'master' into dev_scalar_math_primitive_kernel
mergify[bot] Aug 1, 2022
7b0559f
Merge branch 'master' into dev_scalar_math_primitive_kernel
mergify[bot] Aug 1, 2022
9b98dcc
Merge branch 'master' into dev_scalar_math_primitive_kernel
mergify[bot] Aug 1, 2022
83a4eff
Merge branch 'master' into dev_scalar_math_primitive_kernel
mergify[bot] Aug 1, 2022
1db570f
Merge branch 'master' into dev_scalar_math_primitive_kernel
mergify[bot] Aug 1, 2022
8e14528
Merge branch 'master' into dev_scalar_math_primitive_kernel
mergify[bot] Aug 1, 2022
bc07209
Merge branch 'master' into dev_scalar_math_primitive_kernel
mergify[bot] Aug 1, 2022
57d36b4
Merge branch 'master' into dev_scalar_math_primitive_kernel
mergify[bot] Aug 2, 2022
48 changes: 48 additions & 0 deletions oneflow/core/ep/common/primitive/binary_functor.h
@@ -147,6 +147,54 @@ struct BinaryFunctor<device, BinaryOp::kLogicalXor, Src, Dst> {
}
};

template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kFmod, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast<Dst>(src0 % src1); }
};

template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kFloorDiv, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return src0 / src1; }
};

template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kFloorMod, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const {
Src trunc_mod = src0 % src1;
return (trunc_mod != static_cast<Src>(0))
&& ((src1 < static_cast<Src>(0)) != (trunc_mod < static_cast<Src>(0)))
? trunc_mod + src1
: trunc_mod;
}
};

template<DeviceType device>
struct BinaryFunctor<device, BinaryOp::kFloorMod, uint8_t, uint8_t> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC uint8_t operator()(uint8_t src0, uint8_t src1) const { return src0 % src1; }
};

template<DeviceType device>
struct BinaryFunctor<device, BinaryOp::kFloorMod, uint32_t, uint32_t> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC uint32_t operator()(uint32_t src0, uint32_t src1) const { return src0 % src1; }
};

template<DeviceType device>
struct BinaryFunctor<device, BinaryOp::kFloorMod, uint64_t, uint64_t> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC uint64_t operator()(uint64_t src0, uint64_t src1) const { return src0 % src1; }
};

template<DeviceType device, typename Src, typename Dst>
struct BinaryFunctor<device, BinaryOp::kEluBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : alpha(attr0.Value<double>()) {}
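
A quick aside on the kFloorMod functor added above: C++'s % truncates toward zero, so the functor shifts a nonzero remainder by one divisor whenever its sign disagrees with the divisor's, which yields floor-style (Python-like) modulo. A minimal standalone sketch, not OneFlow code, using plain int in place of the templated Src/Dst:

// Standalone sketch only: reproduces the trunc-mod adjustment from the
// kFloorMod functor above and checks it on the four sign combinations.
#include <cassert>
#include <cstdio>

int FloorMod(int src0, int src1) {
  int trunc_mod = src0 % src1;  // C++ % truncates toward zero
  // Shift by one divisor when the remainder's sign differs from the divisor's.
  return (trunc_mod != 0) && ((src1 < 0) != (trunc_mod < 0)) ? trunc_mod + src1 : trunc_mod;
}

int main() {
  assert(FloorMod(7, 3) == 1);
  assert(FloorMod(-7, 3) == 2);   // trunc mod would be -1
  assert(FloorMod(7, -3) == -2);  // trunc mod would be 1
  assert(FloorMod(-7, -3) == -1);
  std::printf("FloorMod sign checks passed\n");
  return 0;
}

Presumably this is also why the diff adds plain src0 % src1 specializations for the unsigned types: the sign adjustment can never apply there, and spelling it out would compare unsigned values against zero.
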
19 changes: 11 additions & 8 deletions oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h
@@ -39,14 +39,17 @@ inline bool IsDimsEquals(size_t num_src0_dims, const int64_t* src0_dims, size_t
return true;
}

-#define BINARY_MATH_OP_SEQ \
-OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAdd) \
-OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSub) \
-OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMul) \
-OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kDiv) \
-OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMax) \
-OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMin) \
-OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kPow)
+#define BINARY_MATH_OP_SEQ \
+OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAdd) \
+OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSub) \
+OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMul) \
+OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kDiv) \
+OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMax) \
+OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMin) \
+OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kPow) \
+OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFmod) \
+OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFloorDiv) \
+OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFloorMod)

#define BINARY_COMPARISION_OP_SEQ \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kEqual) \
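
The three new entries appended to BINARY_MATH_OP_SEQ are what let the rest of the primitive machinery pick up kFmod, kFloorDiv, and kFloorMod wherever that sequence is expanded. As a hedged analogy only (OneFlow's OF_PP_MAKE_TUPLE_SEQ / OF_PP_* macros are its own Boost.PP-style sequence library, not shown here), the pattern behaves like an ordinary X-macro list: extend the list once and every expansion site gains the new ops.

// Standalone analogy, not OneFlow code: a plain X-macro list playing the role
// of BINARY_MATH_OP_SEQ. Each expansion site stamps out code per listed op.
#include <cstdio>

#define BINARY_MATH_OP_LIST(X) \
  X(Add)                       \
  X(Sub)                       \
  X(Fmod)                      \
  X(FloorDiv)                  \
  X(FloorMod)

#define PRINT_OP(name) std::printf("registered binary op: %s\n", #name);

int main() {
  BINARY_MATH_OP_LIST(PRINT_OP)  // expands to one printf per listed op
  return 0;
}
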
105 changes: 100 additions & 5 deletions oneflow/core/ep/cpu/primitive/binary_functor.h
@@ -29,20 +29,95 @@ struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kPow, Src, Dst> {
};

template<>
-struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kPow, bool, bool> {
+struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kPow, float16, float16> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

-OF_DEVICE_FUNC bool operator()(bool src0, bool src1) const {
-return static_cast<bool>(std::pow(static_cast<double>(src0), static_cast<double>(src1)));
+OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {
+return static_cast<float16>(std::pow(static_cast<float>(src0), static_cast<float>(src1)));
}
};

template<>
-struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kPow, float16, float16> {
+struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFmod, float, float> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC float operator()(float src0, float src1) const { return std::fmod(src0, src1); }
};

template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFmod, double, double> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC double operator()(double src0, double src1) const { return std::fmod(src0, src1); }
};

template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFmod, float16, float16> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {
-return static_cast<float16>(std::pow(static_cast<float>(src0), static_cast<float>(src1)));
+return static_cast<float16>(std::fmod(static_cast<float>(src0), static_cast<float>(src1)));
}
};

template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorDiv, float, float> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC float operator()(float src0, float src1) const { return std::floor(src0 / src1); }
};

template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorDiv, double, double> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC double operator()(double src0, double src1) const {
return std::floor(src0 / src1);
}
};

template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorDiv, float16, float16> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {
return static_cast<float16>(std::floor(static_cast<float>(src0) / static_cast<float>(src1)));
}
};

template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, float, float> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC float operator()(float src0, float src1) const {
float trunc_mod = fmod(src0, src1);
return (trunc_mod != static_cast<float>(0))
&& ((src1 < static_cast<float>(0)) != (trunc_mod < static_cast<float>(0)))
? trunc_mod + src1
: trunc_mod;
}
};

template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, double, double> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC double operator()(double src0, double src1) const {
double trunc_mod = fmod(src0, src1);
return (trunc_mod != static_cast<double>(0))
&& ((src1 < static_cast<double>(0)) != (trunc_mod < static_cast<double>(0)))
? trunc_mod + src1
: trunc_mod;
}
};

template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, float16, float16> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
BinaryFunctor<DeviceType::kCPU, BinaryOp::kFloorMod, float, float> float_functor;

OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const {
return static_cast<float16>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));
}
};

@@ -69,6 +144,26 @@ struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kTanhBackwardWithDyX, Src, Dst>
}
};

#define SPECIALIZATION_CPU_BINARY_FUNCTOR(op, type) \
template<> \
struct BinaryFunctor<DeviceType::kCPU, op, type, type> { \
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : int_functor(attr0, attr1) {} \
\
BinaryFunctor<DeviceType::kCPU, op, int, int> int_functor; \
OF_DEVICE_FUNC type operator()(type src0, type src1) const { \
return static_cast<type>(int_functor(static_cast<int>(src0), static_cast<int>(src1))); \
} \
};

SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kPow, bool);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFmod, bool);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorDiv, bool);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorMod, bool);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kPow, char);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFmod, char);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorDiv, char);
SPECIALIZATION_CPU_BINARY_FUNCTOR(BinaryOp::kFloorMod, char);

} // namespace broadcast_elementwise_binary
} // namespace primitive
} // namespace ep
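
Worth noting in the CPU changes above: only the int and float paths carry real arithmetic. float16 kFloorMod forwards to the float functor, and the new SPECIALIZATION_CPU_BINARY_FUNCTOR macro makes bool and char forward to the int functor. A minimal standalone sketch of that delegation pattern, not OneFlow code (FmodFunctor is a made-up name for illustration):

// Standalone sketch only: narrow types reuse a wider functor and cast the
// result back, so one implementation per "wide" type is enough.
#include <cstdio>

template <typename T>
struct FmodFunctor {
  T operator()(T src0, T src1) const { return src0 % src1; }
};

template <>
struct FmodFunctor<char> {
  FmodFunctor<int> int_functor;  // widen to int, compute, narrow back
  char operator()(char src0, char src1) const {
    return static_cast<char>(int_functor(static_cast<int>(src0), static_cast<int>(src1)));
  }
};

int main() {
  FmodFunctor<char> fmod_char;
  std::printf("7 %% 3 as char = %d\n", static_cast<int>(fmod_char(7, 3)));  // prints 1
  return 0;
}
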
85 changes: 70 additions & 15 deletions oneflow/core/ep/cuda/primitive/binary_functor.cuh
@@ -29,20 +29,56 @@ struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, Src, Dst> {
};

template<>
-struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, bool, bool> {
+struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFmod, float, float> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

-OF_DEVICE_FUNC bool operator()(bool src0, bool src1) const {
-return static_cast<bool>(pow(static_cast<double>(src0), static_cast<double>(src1)));
-}
+OF_DEVICE_FUNC float operator()(float src0, float src1) const { return fmod(src0, src1); }
};

template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFmod, double, double> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC double operator()(double src0, double src1) const { return fmod(src0, src1); }
};

template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFloorDiv, float, float> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC float operator()(float src0, float src1) const { return floor(src0 / src1); }
};

template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFloorDiv, double, double> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC double operator()(double src0, double src1) const { return floor(src0 / src1); }
};

template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFloorMod, float, float> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC float operator()(float src0, float src1) const {
float trunc_mod = fmod(src0, src1);
return (trunc_mod != static_cast<float>(0))
&& ((src1 < static_cast<float>(0)) != (trunc_mod < static_cast<float>(0)))
? trunc_mod + src1
: trunc_mod;
}
};

template<>
-struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, half, half> {
+struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFloorMod, double, double> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

-OF_DEVICE_FUNC half operator()(half src0, half src1) const {
-return static_cast<half>(pow(static_cast<float>(src0), static_cast<float>(src1)));
+OF_DEVICE_FUNC double operator()(double src0, double src1) const {
+double trunc_mod = fmod(src0, src1);
+return (trunc_mod != static_cast<double>(0))
+&& ((src1 < static_cast<double>(0)) != (trunc_mod < static_cast<double>(0)))
+? trunc_mod + src1
+: trunc_mod;
}
};

@@ -79,15 +115,6 @@ struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kTanhBackwardWithDyX, Src, Dst

#if CUDA_VERSION >= 11000

-template<>
-struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, nv_bfloat16, nv_bfloat16> {
-OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
-
-OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const {
-return static_cast<nv_bfloat16>(pow(static_cast<float>(src0), static_cast<float>(src1)));
-}
-};

#define SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(op) \
template<> \
struct BinaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> { \
@@ -99,6 +126,10 @@ struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, nv_bfloat16, nv_bfloat16
} \
};

SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kPow);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFmod);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFloorDiv);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFloorMod);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);
@@ -129,6 +160,10 @@ SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDy
} \
};

SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kPow);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFmod);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFloorDiv);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFloorMod);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);
@@ -143,6 +178,26 @@ SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX);

#define SPECIALIZATION_GPU_BINARY_FUNCTOR(op, type) \
template<> \
struct BinaryFunctor<DeviceType::kCUDA, op, type, type> { \
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : int_functor(attr0, attr1) {} \
\
BinaryFunctor<DeviceType::kCUDA, op, int, int> int_functor; \
OF_DEVICE_FUNC type operator()(type src0, type src1) const { \
return static_cast<type>(int_functor(static_cast<int>(src0), static_cast<int>(src1))); \
} \
};

SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kPow, bool);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFmod, bool);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFloorDiv, bool);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFloorMod, bool);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kPow, char);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFmod, char);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFloorDiv, char);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFloorMod, char);

} // namespace broadcast_elementwise_binary
} // namespace primitive
} // namespace ep
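
Two things stand out in the CUDA changes above: the hand-written kPow specializations for half and nv_bfloat16 are replaced by the existing SPECIALIZATION_PSEUDO_HALF/BFLOAT16 macros (widen to float, compute, narrow back), and the float/double kFloorMod functors use the same fmod-based adjustment as the CPU side. A standalone host-side sketch, not OneFlow or CUDA code, checking that adjustment against the textbook definition src0 - src1 * floor(src0 / src1):

// Standalone sketch only: same arithmetic as the kCUDA float/double kFloorMod
// functors, verified against floor-division-based modulo.
#include <cassert>
#include <cmath>
#include <cstdio>

double FloorMod(double src0, double src1) {
  double trunc_mod = std::fmod(src0, src1);
  return (trunc_mod != 0.0) && ((src1 < 0.0) != (trunc_mod < 0.0)) ? trunc_mod + src1 : trunc_mod;
}

int main() {
  const double cases[][2] = {{7.5, 3.0}, {-7.5, 3.0}, {7.5, -3.0}, {-7.5, -3.0}};
  for (const auto& c : cases) {
    double expected = c[0] - c[1] * std::floor(c[0] / c[1]);
    assert(std::fabs(FloorMod(c[0], c[1]) - expected) < 1e-12);
  }
  std::printf("float-path FloorMod checks passed\n");
  return 0;
}
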
3 changes: 3 additions & 0 deletions oneflow/core/ep/include/primitive/binary_op.h
@@ -32,6 +32,9 @@ enum class BinaryOp {
kMax,
kMin,
kPow,
kFmod,
kFloorDiv,
kFloorMod,
// Comparision
kEqual,
kNotEqual,