From f22b9390b113f0d8160ff94171d1ac5a0bc94ce4 Mon Sep 17 00:00:00 2001 From: jjyaoao Date: Thu, 2 Nov 2023 13:45:19 +0000 Subject: [PATCH] update v1.1 --- paddle/phi/kernels/cpu/compare_kernel.cc | 22 ++------- paddle/phi/kernels/kps/compare_kernel.cu | 23 ++-------- .../phi/kernels/legacy/cpu/compare_kernel.cc | 3 +- .../phi/kernels/legacy/kps/compare_kernel.cu | 2 + python/paddle/tensor/logic.py | 46 ++++++++++++++----- test/legacy_test/test_compare_op.py | 25 ++++------ 6 files changed, 53 insertions(+), 68 deletions(-) diff --git a/paddle/phi/kernels/cpu/compare_kernel.cc b/paddle/phi/kernels/cpu/compare_kernel.cc index e51abe3ae11ff8..5218dd8e590e91 100644 --- a/paddle/phi/kernels/cpu/compare_kernel.cc +++ b/paddle/phi/kernels/cpu/compare_kernel.cc @@ -110,8 +110,9 @@ PD_REGISTER_KERNEL(equal_all, ALL_LAYOUT, \ phi::func##Kernel, \ bool, \ - int16_t, \ int, \ + int8_t, \ + int16_t, \ int64_t, \ float, \ double, \ @@ -120,26 +121,9 @@ PD_REGISTER_KERNEL(equal_all, kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \ } +PD_REGISTER_COMPARE_KERNEL(less_than, LessThan) PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual) PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan) PD_REGISTER_COMPARE_KERNEL(greater_equal, GreaterEqual) PD_REGISTER_COMPARE_KERNEL(equal, Equal) PD_REGISTER_COMPARE_KERNEL(not_equal, NotEqual) - -#define PD_REGISTER_LESS_THAN_KERNEL(name, func) \ - PD_REGISTER_KERNEL(name, \ - CPU, \ - ALL_LAYOUT, \ - phi::func##Kernel, \ - bool, \ - int8_t, \ - int16_t, \ - int, \ - int64_t, \ - float, \ - double, \ - phi::dtype::float16, \ - phi::dtype::bfloat16) { \ - kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \ - } -PD_REGISTER_LESS_THAN_KERNEL(less_than, LessThan) diff --git a/paddle/phi/kernels/kps/compare_kernel.cu b/paddle/phi/kernels/kps/compare_kernel.cu index 1997b31422ef9d..adf060d76f8fcd 100644 --- a/paddle/phi/kernels/kps/compare_kernel.cu +++ b/paddle/phi/kernels/kps/compare_kernel.cu @@ -150,8 +150,9 @@ PD_REGISTER_KERNEL(equal_all, ALL_LAYOUT, \ phi::func##Kernel, \ bool, \ - int16_t, \ int, \ + int8_t, \ + int16_t, \ int64_t, \ float, \ double, \ @@ -160,29 +161,11 @@ PD_REGISTER_KERNEL(equal_all, kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \ } +PD_REGISTER_COMPARE_KERNEL(less_than, LessThan) PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual) PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan) PD_REGISTER_COMPARE_KERNEL(greater_equal, GreaterEqual) PD_REGISTER_COMPARE_KERNEL(equal, Equal) PD_REGISTER_COMPARE_KERNEL(not_equal, NotEqual) -#define PD_REGISTER_LESS_THAN_KERNEL(func) \ - PD_REGISTER_KERNEL(less_than, \ - KPS, \ - ALL_LAYOUT, \ - phi::func##Kernel, \ - bool, \ - int8_t, \ - int16_t, \ - int, \ - int64_t, \ - float, \ - double, \ - phi::dtype::float16, \ - phi::dtype::bfloat16) { \ - kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \ - } - -PD_REGISTER_LESS_THAN_KERNEL(less_than, LessThan) - #endif diff --git a/paddle/phi/kernels/legacy/cpu/compare_kernel.cc b/paddle/phi/kernels/legacy/cpu/compare_kernel.cc index d9760398af7cc6..396d0b5473b88f 100644 --- a/paddle/phi/kernels/legacy/cpu/compare_kernel.cc +++ b/paddle/phi/kernels/legacy/cpu/compare_kernel.cc @@ -115,7 +115,7 @@ PD_REGISTER_KERNEL(less_than_raw, ALL_LAYOUT, phi::LessThanRawKernel, bool, - int16_t, + int8_t int16_t, int, int64_t, float, @@ -131,6 +131,7 @@ PD_REGISTER_KERNEL(less_than_raw, ALL_LAYOUT, \ phi::func##RawKernel, \ bool, \ + int8_t, \ int16_t, \ int, \ int64_t, \ diff --git a/paddle/phi/kernels/legacy/kps/compare_kernel.cu b/paddle/phi/kernels/legacy/kps/compare_kernel.cu index 67bd491738346e..2d16851ae4ab1f 100644 --- a/paddle/phi/kernels/legacy/kps/compare_kernel.cu +++ b/paddle/phi/kernels/legacy/kps/compare_kernel.cu @@ -139,6 +139,7 @@ PD_REGISTER_KERNEL(less_than_raw, ALL_LAYOUT, phi::LessThanRawKernel, bool, + int8_t, int16_t, int, int64_t, @@ -157,6 +158,7 @@ PD_REGISTER_KERNEL(less_than_raw, bool, \ int16_t, \ int, \ + int8_t, \ int64_t, \ float, \ double, \ diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index 220c373d781365..911a462f5859ec 100755 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -512,8 +512,8 @@ def equal(x, y, name=None): The output has no gradient. Args: - x (Tensor): Tensor, data type is bool, float16, float32, float64, int32, int64. - y (Tensor): Tensor, data type is bool, float16, float32, float64, int32, int64. + x (Tensor): Tensor, data type is bool, float16, float32, float64, int8, int16, int32, int64. + y (Tensor): Tensor, data type is bool, float16, float32, float64, int8, int16, int32, int64. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -553,6 +553,8 @@ def equal(x, y, name=None): "float16", "float32", "float64", + "int8", + "int16", "int32", "int64", "uint16", @@ -567,6 +569,8 @@ def equal(x, y, name=None): "float16", "float32", "float64", + "int8", + "int16", "int32", "int64", "uint16", @@ -611,8 +615,8 @@ def greater_equal(x, y, name=None): The output has no gradient. Args: - x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64. - y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64. + x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int8, int16, int32, int64. + y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int8, int16, int32, int64. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -641,6 +645,8 @@ def greater_equal(x, y, name=None): "float16", "float32", "float64", + "int8", + "int16", "int32", "int64", "uint16", @@ -655,6 +661,8 @@ def greater_equal(x, y, name=None): "float16", "float32", "float64", + "int8", + "int16", "int32", "int64", "uint16", @@ -699,8 +707,8 @@ def greater_than(x, y, name=None): The output has no gradient. Args: - x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64. - y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64. + x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int8, int16, int32, int64. + y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int8, int16, int32, int64. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -729,6 +737,8 @@ def greater_than(x, y, name=None): "float16", "float32", "float64", + "int8", + "int16", "int32", "int64", "uint16", @@ -743,6 +753,8 @@ def greater_than(x, y, name=None): "float16", "float32", "float64", + "int8", + "int16", "int32", "int64", "uint16", @@ -787,8 +799,8 @@ def less_equal(x, y, name=None): The output has no gradient. Args: - x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64. - y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64. + x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int8, int16, int32, int64. + y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int8, int16, int32, int64. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -818,6 +830,8 @@ def less_equal(x, y, name=None): "float16", "float32", "float64", + "int8", + "int16", "int32", "int64", "uint16", @@ -832,6 +846,8 @@ def less_equal(x, y, name=None): "float16", "float32", "float64", + "int8", + "int16", "int32", "int64", "uint16", @@ -876,8 +892,8 @@ def less_than(x, y, name=None): The output has no gradient. Args: - x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64. - y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64. + x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int8, int16, int32, int64. + y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int8, int16, int32, int64. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -908,6 +924,7 @@ def less_than(x, y, name=None): "float32", "float64", "int8", + "int16", "int32", "int64", "uint16", @@ -923,6 +940,7 @@ def less_than(x, y, name=None): "float32", "float64", "int8", + "int16", "int32", "int64", "uint16", @@ -967,8 +985,8 @@ def not_equal(x, y, name=None): The output has no gradient. Args: - x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64. - y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64. + x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int8, int16, int32, int64. + y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int8, int16, int32, int64. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -998,6 +1016,8 @@ def not_equal(x, y, name=None): "float16", "float32", "float64", + "int8", + "int16", "int32", "int64", "uint16", @@ -1012,6 +1032,8 @@ def not_equal(x, y, name=None): "float16", "float32", "float64", + "int8", + "int16", "int32", "int64", "uint16", diff --git a/test/legacy_test/test_compare_op.py b/test/legacy_test/test_compare_op.py index 9afcc825457a36..9b4691dfbb3156 100755 --- a/test/legacy_test/test_compare_op.py +++ b/test/legacy_test/test_compare_op.py @@ -55,7 +55,15 @@ def test_errors(self): globals()[cls_name] = Cls -for _type_name in {'float32', 'float64', 'int32', 'int64', 'float16'}: +for _type_name in { + 'float32', + 'float64', + 'int8', + 'int16', + 'int32', + 'int64', + 'float16', +}: if _type_name == 'float64' and core.is_compiled_with_rocm(): _type_name = 'float32' if _type_name == 'float16' and (not core.is_compiled_with_cuda()): @@ -615,21 +623,6 @@ def test_place_2(self): self.assertEqual((result.numpy() == np.array([False])).all(), True) -class TestLessThanInt8(unittest.TestCase): - def test_less_than_int8(self): - # Create a tensor of type int8 - x = paddle.to_tensor([1, 2, 3], dtype='int8') - y = paddle.to_tensor([1, 3, 2], dtype='int8') - - result = paddle.less_than(x, y) - - # desired output - expected = np.array([False, True, False]) - - # Verify output - self.assertTrue((result.numpy() == expected).all()) - - if __name__ == '__main__': paddle.enable_static() unittest.main()