From c8d4a5b87cba0272d1be2ccc642dc8feac9e56da Mon Sep 17 00:00:00 2001 From: jjyaoao Date: Wed, 18 Oct 2023 10:16:10 +0000 Subject: [PATCH] =?UTF-8?q?=E3=80=90Hackathon=205th=20No.13=E3=80=91?= =?UTF-8?q?=E3=80=90=E5=85=B3=E8=81=94=20PR=E3=80=91Added=20int8=20support?= =?UTF-8?q?=20for=20less=5Fthan?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/kernels/cpu/compare_kernel.cc | 1 + paddle/phi/kernels/kps/compare_kernel.cu | 20 +++++++++++++++++++- python/paddle/tensor/logic.py | 2 ++ test/legacy_test/test_compare_op.py | 15 +++++++++++++++ 4 files changed, 37 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/cpu/compare_kernel.cc b/paddle/phi/kernels/cpu/compare_kernel.cc index 24b4615daa58c..d290e348941aa 100644 --- a/paddle/phi/kernels/cpu/compare_kernel.cc +++ b/paddle/phi/kernels/cpu/compare_kernel.cc @@ -110,6 +110,7 @@ PD_REGISTER_KERNEL(equal_all, ALL_LAYOUT, \ phi::func##Kernel, \ bool, \ + int8_t, \ int16_t, \ int, \ int64_t, \ diff --git a/paddle/phi/kernels/kps/compare_kernel.cu b/paddle/phi/kernels/kps/compare_kernel.cu index 14bb86b475320..d1211954e6d40 100644 --- a/paddle/phi/kernels/kps/compare_kernel.cu +++ b/paddle/phi/kernels/kps/compare_kernel.cu @@ -160,11 +160,29 @@ PD_REGISTER_KERNEL(equal_all, kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \ } -PD_REGISTER_COMPARE_KERNEL(less_than, LessThan) PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual) PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan) PD_REGISTER_COMPARE_KERNEL(greater_equal, GreaterEqual) PD_REGISTER_COMPARE_KERNEL(equal, Equal) PD_REGISTER_COMPARE_KERNEL(not_equal, NotEqual) +#define PD_REGISTER_LESS_THAN_KERNEL(func) \ + PD_REGISTER_KERNEL(less_than, \ + KPS, \ + ALL_LAYOUT, \ + phi::func##Kernel, \ + bool, \ + int8_t, \ + int16_t, \ + int, \ + int64_t, \ + float, \ + double, \ + phi::dtype::float16, \ + phi::dtype::bfloat16) { \ + kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \ + } + +PD_REGISTER_COMPARE_KERNEL(less_than, LessThan) + #endif diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index 9b50993b89166..220c373d78136 100755 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -907,6 +907,7 @@ def less_than(x, y, name=None): "float16", "float32", "float64", + "int8", "int32", "int64", "uint16", @@ -921,6 +922,7 @@ def less_than(x, y, name=None): "float16", "float32", "float64", + "int8", "int32", "int64", "uint16", diff --git a/test/legacy_test/test_compare_op.py b/test/legacy_test/test_compare_op.py index 91dce088ef88e..9afcc825457a3 100755 --- a/test/legacy_test/test_compare_op.py +++ b/test/legacy_test/test_compare_op.py @@ -615,6 +615,21 @@ def test_place_2(self): self.assertEqual((result.numpy() == np.array([False])).all(), True) +class TestLessThanInt8(unittest.TestCase): + def test_less_than_int8(self): + # Create a tensor of type int8 + x = paddle.to_tensor([1, 2, 3], dtype='int8') + y = paddle.to_tensor([1, 3, 2], dtype='int8') + + result = paddle.less_than(x, y) + + # desired output + expected = np.array([False, True, False]) + + # Verify output + self.assertTrue((result.numpy() == expected).all()) + + if __name__ == '__main__': paddle.enable_static() unittest.main()