Trigamma kernel dev #10117

Merged

Changes from all commits (38 commits)
cb8ffea  digamma op dev (youxiudeshouyeren, Apr 1, 2023)
fec26e3  unittest (youxiudeshouyeren, Apr 1, 2023)
44e1340  Merge branch 'master' into digamma_op_dev (youxiudeshouyeren, Apr 1, 2023)
8aac0c7  refine (youxiudeshouyeren, Apr 2, 2023)
b1fe15b  tensor.digamma api (youxiudeshouyeren, Apr 2, 2023)
0e92082  flow.digamma api (youxiudeshouyeren, Apr 2, 2023)
a08eada  fix test (youxiudeshouyeren, Apr 2, 2023)
b60259f  unittest (youxiudeshouyeren, Apr 2, 2023)
e7a3e19  Merge branch 'digamma_op_dev' of github.com:youxiudeshouyeren/oneflow… (youxiudeshouyeren, Apr 2, 2023)
2e0248e  fmt (youxiudeshouyeren, Apr 2, 2023)
98ebc8b  auto fmt (youxiudeshouyeren, Apr 2, 2023)
d8e5b0f  add api psi (youxiudeshouyeren, Apr 2, 2023)
98cf0a7  docstr (youxiudeshouyeren, Apr 2, 2023)
b0702e2  fmt (youxiudeshouyeren, Apr 2, 2023)
c1759ec  fix docstr (youxiudeshouyeren, Apr 2, 2023)
4b41d24  refine (youxiudeshouyeren, Apr 3, 2023)
17fe31a  fmt (youxiudeshouyeren, Apr 3, 2023)
8b80a76  fmt (youxiudeshouyeren, Apr 3, 2023)
a86d041  Merge branch 'master' into digamma_op_dev (youxiudeshouyeren, Apr 3, 2023)
fe44442  Merge branch 'master' into digamma_op_dev (youxiudeshouyeren, Apr 6, 2023)
e126953  add references (youxiudeshouyeren, Apr 7, 2023)
e1105d9  Merge branch 'digamma_op_dev' of github.com:youxiudeshouyeren/oneflow… (youxiudeshouyeren, Apr 7, 2023)
809892d  fmt (youxiudeshouyeren, Apr 7, 2023)
7130878  fix build (youxiudeshouyeren, Apr 7, 2023)
d1f63ea  fmt (youxiudeshouyeren, Apr 7, 2023)
6af454e  Merge branch 'master' into digamma_op_dev (youxiudeshouyeren, Apr 7, 2023)
75ff719  Merge branch 'youxiudeshouyeren-digamma_op_dev' (youxiudeshouyeren, Apr 8, 2023)
6c0055f  Merge branch 'master' into digamma_op_dev (youxiudeshouyeren, Apr 8, 2023)
190aabb  fix (youxiudeshouyeren, Apr 8, 2023)
f146263  trigamma kernel (youxiudeshouyeren, Apr 12, 2023)
ae48027  Merge branch 'master' into trigamma_kernel_dev (youxiudeshouyeren, Apr 12, 2023)
4f183f6  Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into … (youxiudeshouyeren, Apr 16, 2023)
e8bcd14  fix (youxiudeshouyeren, Apr 16, 2023)
97d5bf0  unittest (youxiudeshouyeren, Apr 16, 2023)
e537fd7  fmt (youxiudeshouyeren, Apr 16, 2023)
b689e18  fmt (youxiudeshouyeren, Apr 16, 2023)
2472a7b  Merge branch 'master' into trigamma_kernel_dev (youxiudeshouyeren, Apr 20, 2023)
eb1929c  Merge branch 'master' into trigamma_kernel_dev (mergify[bot], Apr 20, 2023)
1 change: 1 addition & 0 deletions oneflow/core/ep/common/primitive/elementwise_unary.h
@@ -55,6 +55,7 @@ namespace primitive {
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCos) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCosh) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kDigamma) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kTrigamma) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kErf) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kErfc) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kExp) \
26 changes: 19 additions & 7 deletions oneflow/core/ep/cpu/primitive/binary_functor.h
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/common/primitive/binary_functor.h"

#include "oneflow/core/ep/cpu/primitive/unary_functor.h"
namespace oneflow {

namespace ep {
@@ -353,13 +353,25 @@ struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kErfcBackwardWithDyX, Src, Dst>
}
};

template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kDigammaBackwardWithDyX, Src, Dst> {
template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kDigammaBackwardWithDyX, float, float> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
// TODO:shijiaxing: This function is named trigamma, it will be implemented soon.
UNIMPLEMENTED();
return 0;
OF_DEVICE_FUNC float operator()(float dy, float x) const {
ep::primitive::UnaryFunctor<DeviceType::kCPU, UnaryOp::kTrigamma, float, float>
trigamma_functor(0, 0);
float trigamma_result = trigamma_functor(x);
return trigamma_result * dy;
}
};

template<>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kDigammaBackwardWithDyX, double, double> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC double operator()(double dy, double x) const {
ep::primitive::UnaryFunctor<DeviceType::kCPU, UnaryOp::kTrigamma, double, double>
trigamma_functor(0, 0);
double trigamma_result = trigamma_functor(x);
return trigamma_result * dy;
}
};

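Aside (not part of the diff): the backward implemented above is simply the chain rule applied to digamma, whose derivative is the trigamma function ψ₁. Both specializations compute

$$
\frac{\partial L}{\partial x} \;=\; \frac{\partial L}{\partial y}\cdot\psi'(x) \;=\; dy\cdot\psi_1(x).
$$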
51 changes: 51 additions & 0 deletions oneflow/core/ep/cpu/primitive/unary_functor.h
@@ -254,6 +254,56 @@ struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kDigamma, double, double> {
}
};

template<>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kTrigamma, double, double> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC double operator()(double x) const {
// references
// https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/Math.h#L336-L352
double sign = +1;
double result = 0;
if (x < 0.5) {
sign = -1;
const double sin_pi_x = sin(pi<double> * x);
result -= (pi<double> * pi<double>) / (sin_pi_x * sin_pi_x);
x = 1 - x;
}
for (int i = 0; i < 6; ++i) {
result += 1 / (x * x);
x += 1;
}
const double ixx = 1 / (x * x);
result += (1 + 1 / (2 * x) + ixx * (1. / 6 - ixx * (1. / 30 - ixx * (1. / 42)))) / x;
return sign * result;
}
};

template<>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kTrigamma, float, float> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC float operator()(float x) const {
// references
// https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/Math.h#L354-L370
float sign = +1;
float result = 0;
if (x < 0.5f) {
sign = -1;
const float sin_pi_x = sinf(pi<float> * x);
result -= (pi<float> * pi<float>) / (sin_pi_x * sin_pi_x);
x = 1 - x;
}
for (int i = 0; i < 6; ++i) {
result += 1 / (x * x);
x += 1;
}
const float ixx = 1 / (x * x);
result += (1 + 1 / (2 * x) + ixx * (1.f / 6 - ixx * (1.f / 30 - ixx * (1.f / 42)))) / x;
return sign * result;
}
};

template<>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kAbs, bfloat16, bfloat16> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
@@ -322,6 +372,7 @@ SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kNotEqualZero);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kFastGelu);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kQuickGelu);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kDigamma);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTrigamma);

template<>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsInf, bool, bfloat16> {
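For reference, the identities behind the functors above (following the PyTorch routine linked in the comments): a reflection step maps x < 0.5 to 1 − x, six recurrence steps push the argument into the asymptotic regime, and the remainder is a truncated asymptotic series:

$$
\psi_1(x) = \frac{\pi^2}{\sin^2(\pi x)} - \psi_1(1-x), \qquad \psi_1(x) = \psi_1(x+1) + \frac{1}{x^2},
$$

$$
\psi_1(x) \sim \frac{1}{x} + \frac{1}{2x^2} + \frac{1}{6x^3} - \frac{1}{30x^5} + \frac{1}{42x^7} \qquad (x \to \infty).
$$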
10 changes: 7 additions & 3 deletions oneflow/core/ep/cuda/primitive/binary_functor.cuh
@@ -15,6 +15,7 @@ limitations under the License.
*/

#include "oneflow/core/ep/common/primitive/binary_functor.h"
#include "oneflow/core/ep/cuda/primitive/unary_functor.cuh"

namespace oneflow {
namespace ep {
@@ -243,10 +244,13 @@ struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kIsClose, Src, Dst> {
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kDigammaBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
// TODO:shijiaxing: This function is named trigamma, it will be implemented soon.
assert(false);
return static_cast<Dst>(0.0);
ep::primitive::UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTrigamma, Src, Dst> trigamma_functor(
0, 0);
Src trigamma_result = trigamma_functor(x);
return trigamma_result * dy;
}
};

38 changes: 38 additions & 0 deletions oneflow/core/ep/cuda/primitive/unary_functor.cuh
@@ -13,6 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#ifndef ONEFLOW_CORE_EP_CUDA_PRIMITIVE_UNARY_FUNCTOR_CUH
#define ONEFLOW_CORE_EP_CUDA_PRIMITIVE_UNARY_FUNCTOR_CUH
#include "oneflow/core/ep/common/primitive/unary_functor.h"
#include "oneflow/core/ep/cuda/primitive/type_seq.h"
#include "oneflow/core/cuda/elementwise.cuh"
@@ -283,6 +286,38 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kDigamma, Dst, Src> {
}
};

template<typename Dst, typename Src>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTrigamma, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC Dst operator()(Src x) const {
// references
// https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/cuda/Math.cuh#L387-L410
const Src PI{3.14159265358979323846};
Src sign = 1;
Src result = 0;

if (x < Src{0.5}) {
sign = -1;
Src sin_pi_x = sin(PI * x);
result -= (PI * PI) / (sin_pi_x * sin_pi_x);
x = 1 - x;
}

for (int i = 0; i < 6; ++i) {
result += Src{1} / (x * x);
x += 1;
}

const Src one{1};
const Src ixx = one / (x * x);
result += (one + one / (Src{2} * x)
+ ixx * (one / Src{6} - ixx * (one / Src{30} - ixx * (one / Src{42}))))
/ x;
return sign * result;
}
};

template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kAbs, half, half> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
@@ -412,6 +447,7 @@ SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kCeil);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kCos);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kCosh);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kDigamma);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kTrigamma);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kErf);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kErfc);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kExp);
@@ -505,6 +541,7 @@ SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kNanAssign);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kFastGelu);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kQuickGelu);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kDigamma);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTrigamma);

template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, nv_bfloat16> {
@@ -543,3 +580,4 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTrunc, nv_bfloat16, nv_bfloat16
} // namespace primitive
} // namespace ep
} // namespace oneflow
#endif
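The CUDA functor above implements the same reflection/recurrence/asymptotic scheme, just written generically over Src/Dst. As an illustration only (none of this code is in the PR; the `trigamma` helper below is mine, and it assumes NumPy and SciPy are installed), the computation can be mirrored in Python and cross-checked against `scipy.special.polygamma`:

```python
import numpy as np
from scipy.special import polygamma  # polygamma(1, x) is trigamma

def trigamma(x: float) -> float:
    """Mirror of the kernel: reflection, 6 recurrence steps, asymptotic tail."""
    sign, result = 1.0, 0.0
    if x < 0.5:
        # Reflection: psi_1(x) = pi^2 / sin^2(pi x) - psi_1(1 - x)
        sign = -1.0
        sin_pi_x = np.sin(np.pi * x)
        result -= (np.pi * np.pi) / (sin_pi_x * sin_pi_x)
        x = 1.0 - x
    for _ in range(6):
        # Recurrence: psi_1(x) = psi_1(x + 1) + 1 / x^2
        result += 1.0 / (x * x)
        x += 1.0
    ixx = 1.0 / (x * x)
    # Truncated asymptotic series for large x
    result += (1.0 + 1.0 / (2.0 * x)
               + ixx * (1.0 / 6 - ixx * (1.0 / 30 - ixx * (1.0 / 42)))) / x
    return sign * result

for v in [0.25, 1.0, 3.7, 10.0]:
    assert abs(trigamma(v) - polygamma(1, v)) < 1e-8
```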
1 change: 1 addition & 0 deletions oneflow/core/ep/include/primitive/unary_op.h
@@ -55,6 +55,7 @@ enum class UnaryOp {
kCos,
kCosh,
kDigamma,
kTrigamma,
kErf,
kErfc,
kExp,
5 changes: 5 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -3346,3 +3346,8 @@
- name: "digamma_grad"
signature: "Tensor (Tensor x, Tensor dy) => DigammaGrad"
bind_python: False

- name: "trigamma"
signature: "Tensor (Tensor x) => Trigamma"
bind_python: False
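Since the new functional is registered with `bind_python: False`, `Trigamma` is not exposed as a user-facing Python API; it is consumed internally as the gradient of digamma. A hypothetical smoke test of that path (assuming a build containing this PR; `flow.digamma` itself comes from the earlier digamma commits in this branch):

```python
import oneflow as flow

x = flow.tensor([0.7, 1.5, 3.0], dtype=flow.float32, requires_grad=True)
y = flow.digamma(x)
y.sum().backward()
print(x.grad)  # elementwise trigamma(x), computed by the new kernels
```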

2 changes: 2 additions & 0 deletions oneflow/core/functional/impl/unary_functor.cpp
@@ -157,6 +157,7 @@ OF_PP_FOR_EACH_TUPLE(FLOAT_UNARY_BWD_WITH_FILL_FUNCTORS, FLOAT_UNARY_FUNC_BWD_WI

UNARY_ELEMENTWISE_FUNCTOR("negative", Negative, FloatUnaryFunctor)
UNARY_ELEMENTWISE_FUNCTOR("bitwise_not", BitwiseNot, UnaryFunctor)
UNARY_ELEMENTWISE_FUNCTOR("trigamma", Trigamma, FloatUnaryFunctor)

} // namespace impl

@@ -214,6 +215,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<InplaceFloorFunctor>("Floor_");
m.add_functor<InplaceCeilFunctor>("Ceil_");
m.add_functor<InplaceRoundFunctor>("Round_");
m.add_functor<TrigammaFunctor>("Trigamma");
};

#undef ADD_UNARY_FUNCTOR_WITH_DY_X
13 changes: 13 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -4637,6 +4637,19 @@ def OneFlow_DigammaGradOp : OneFlow_BaseOp<"digamma_grad", [NoMemoryEffect, Decl
let has_data_type_infer_fn = 1;
}

def OneFlow_TrigammaOp : OneFlow_BaseOp<"trigamma", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x
);
let output = (outs
OneFlow_Tensor:$y
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_LogOp : OneFlow_BaseOp<"log", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x
@@ -31,6 +31,7 @@ namespace oneflow {
OF_PP_MAKE_TUPLE_SEQ("cos", ep::primitive::UnaryOp::kCos) \
OF_PP_MAKE_TUPLE_SEQ("cosh", ep::primitive::UnaryOp::kCosh) \
OF_PP_MAKE_TUPLE_SEQ("digamma", ep::primitive::UnaryOp::kDigamma) \
OF_PP_MAKE_TUPLE_SEQ("trigamma", ep::primitive::UnaryOp::kTrigamma) \
OF_PP_MAKE_TUPLE_SEQ("erf", ep::primitive::UnaryOp::kErf) \
OF_PP_MAKE_TUPLE_SEQ("erfc", ep::primitive::UnaryOp::kErfc) \
OF_PP_MAKE_TUPLE_SEQ("exp", ep::primitive::UnaryOp::kExp) \
2 changes: 1 addition & 1 deletion oneflow/user/ops/math_unary_elementwise_op.cpp
@@ -59,5 +59,5 @@ OF_PP_FOR_EACH_TUPLE(REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD_WITH_FILL,
// Negative's grad function = negative(dy), so here register negative op separately.
MATH_ELEMENTWISE_DEFAULT_SET_FUNC(NegativeOp)
MATH_ELEMENTWISE_DEFAULT_SET_FUNC(BitwiseNotOp)

MATH_ELEMENTWISE_DEFAULT_SET_FUNC(TrigammaOp)
} // namespace oneflow
1 change: 1 addition & 0 deletions oneflow/user/ops/math_unary_elementwise_seq.h
@@ -32,6 +32,7 @@ namespace oneflow {
OF_PP_MAKE_TUPLE_SEQ("cos", Cos) \
OF_PP_MAKE_TUPLE_SEQ("cosh", Cosh) \
OF_PP_MAKE_TUPLE_SEQ("digamma", Digamma) \
OF_PP_MAKE_TUPLE_SEQ("trigamma", Trigamma) \
OF_PP_MAKE_TUPLE_SEQ("erf", Erf) \
OF_PP_MAKE_TUPLE_SEQ("erfc", Erfc) \
OF_PP_MAKE_TUPLE_SEQ("exp", Exp) \
3 changes: 1 addition & 2 deletions python/oneflow/test/modules/test_global_math_ops.py
@@ -172,8 +172,7 @@ def _test_atan2(test_case, placement, sbp, ndim):
return z


# TODO:shijiaxing When the grad function be implemented, rm "auto_backward=False"
@autotest(n=1, auto_backward=False)
@autotest(n=1)
def _test_digamma(test_case, placement, sbp, ndim):
dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]
x = random_tensor(ndim, *dim_list, low=0, high=10).to_global(placement, sbp)
3 changes: 1 addition & 2 deletions python/oneflow/test/modules/test_math_ops.py
@@ -611,8 +611,7 @@ def test_log10_with_random_data(test_case):

@flow.unittest.skip_unless_1n1d()
class TestDigammaModule(flow.unittest.TestCase):
# TODO:shijiaxing When the grad function be implemented, rm "auto_backward=False"
@autotest(n=5, auto_backward=False)
@autotest(n=5)
def test_digamma_with_random_data(test_case):
device = random_device()
x = random_tensor().to(device)
6 changes: 2 additions & 4 deletions python/oneflow/test/modules/test_special_ops.py
@@ -113,17 +113,15 @@ def test_flow_logsumexp_with_random_data(test_case):
y = torch.special.logsumexp(x, dim=np.random.randint(0, 3))
return y

# TODO:shijiaxing When the grad function be implemented, rm "auto_backward=False"
@autotest(n=5, auto_backward=False)
@autotest(n=5, auto_backward="auto")
def test_flow_digamma_with_random_data(test_case):
device = random_device()
x_dtype = random_dtype(["arithmetic", "half"])
x = random_tensor().to(device).to(x_dtype)
y = torch.special.digamma(x)
return y

# TODO:shijiaxing When the grad function be implemented, rm "auto_backward=False"
@autotest(n=5, auto_backward=False)
@autotest(n=5, auto_backward="auto")
def test_flow_psi_with_random_data(test_case):
device = random_device()
x_dtype = random_dtype(["arithmetic", "half"])
3 changes: 1 addition & 2 deletions python/oneflow/test/tensor/test_tensor_part_2.py
@@ -936,8 +936,7 @@ def test_construct_global_tensor_by_numpy(test_case):
)
test_case.assertTrue(y_default_dtype.dtype == flow.int32)

# TODO:shijiaxing When the grad function be implemented, rm "auto_backward=False"
@autotest(n=5, auto_backward=False)
@autotest(n=5)
def test_digamma_tensor_with_random_data(test_case):
device = random_device()
x = random_tensor().to(device)
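With the trigamma kernels now supplying digamma's gradient, the `auto_backward=False` escape hatches and their TODO comments are dropped across the test suite, so autotest exercises the backward pass of digamma/psi against PyTorch as well.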