Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.13】【关联 PR】Added uint8&int8&int16 support for compare_kernel -part #58209

Merged
merged 22 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion paddle/phi/kernels/cpu/compare_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,18 @@ PD_REGISTER_KERNEL(equal_all,
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int16_t, \
int, \
uint8_t, \
int8_t, \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

你这里对所有的compare kernel 都添加了int8_t 类型的注册?但单测里只加了less_than的测试

int16_t, \
int64_t, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::bfloat16) { \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}

PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan)
Expand Down
4 changes: 3 additions & 1 deletion paddle/phi/kernels/kps/compare_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,10 @@ PD_REGISTER_KERNEL(equal_all,
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int16_t, \
int, \
uint8_t, \
int8_t, \
int16_t, \
int64_t, \
float, \
double, \
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/kernels/legacy/cpu/compare_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ PD_REGISTER_KERNEL(less_than_raw,
ALL_LAYOUT,
phi::LessThanRawKernel,
bool,
uint8_t,
int8_t,
int16_t,
int,
int64_t,
Expand All @@ -131,6 +133,8 @@ PD_REGISTER_KERNEL(less_than_raw,
ALL_LAYOUT, \
phi::func##RawKernel, \
bool, \
uint8_t, \
int8_t, \
int16_t, \
int, \
int64_t, \
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/kernels/legacy/kps/compare_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ PD_REGISTER_KERNEL(less_than_raw,
ALL_LAYOUT,
phi::LessThanRawKernel,
bool,
uint8_t,
int8_t,
int16_t,
int,
int64_t,
Expand All @@ -155,8 +157,10 @@ PD_REGISTER_KERNEL(less_than_raw,
ALL_LAYOUT, \
phi::func##RawKernel, \
bool, \
uint8_t, \
int16_t, \
int, \
int8_t, \
int64_t, \
float, \
double, \
Expand Down
60 changes: 48 additions & 12 deletions python/paddle/tensor/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,8 +512,8 @@ def equal(x, y, name=None):
The output has no gradient.

Args:
x (Tensor): Tensor, data type is bool, float16, float32, float64, int32, int64.
y (Tensor): Tensor, data type is bool, float16, float32, float64, int32, int64.
x (Tensor): Tensor, data type is bool, float16, float32, float64, uint8, int8, int16, int32, int64.
y (Tensor): Tensor, data type is bool, float16, float32, float64, uint8, int8, int16, int32, int64.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.

Expand Down Expand Up @@ -553,6 +553,9 @@ def equal(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand All @@ -567,6 +570,9 @@ def equal(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand Down Expand Up @@ -611,8 +617,8 @@ def greater_equal(x, y, name=None):
The output has no gradient.

Args:
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Expand Down Expand Up @@ -641,6 +647,9 @@ def greater_equal(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand All @@ -655,6 +664,9 @@ def greater_equal(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand Down Expand Up @@ -699,8 +711,8 @@ def greater_than(x, y, name=None):
The output has no gradient.

Args:
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Expand Down Expand Up @@ -729,6 +741,9 @@ def greater_than(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand All @@ -743,6 +758,9 @@ def greater_than(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand Down Expand Up @@ -787,8 +805,8 @@ def less_equal(x, y, name=None):
The output has no gradient.

Args:
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.

Expand Down Expand Up @@ -818,6 +836,9 @@ def less_equal(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand All @@ -832,6 +853,9 @@ def less_equal(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand Down Expand Up @@ -876,8 +900,8 @@ def less_than(x, y, name=None):
The output has no gradient.

Args:
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, int32, int64.
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float16, float32, float64, uint8, int8, int16, int32, int64.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.

Expand Down Expand Up @@ -907,6 +931,9 @@ def less_than(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand All @@ -921,6 +948,9 @@ def less_than(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand Down Expand Up @@ -965,8 +995,8 @@ def not_equal(x, y, name=None):
The output has no gradient.

Args:
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
x (Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, uint8, int8, int16, int32, int64.
y (Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, uint8, int8, int16, int32, int64.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.

Expand Down Expand Up @@ -996,6 +1026,9 @@ def not_equal(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand All @@ -1010,6 +1043,9 @@ def not_equal(x, y, name=None):
"float16",
"float32",
"float64",
"uint8",
"int8",
"int16",
"int32",
"int64",
"uint16",
Expand Down
25 changes: 18 additions & 7 deletions test/legacy_test/test_compare_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,35 @@ def setUp(self):
def test_output(self):
self.check_output(check_cinn=True, check_pir=check_pir)

def test_errors(self):
def test_int16_support(self):
paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(name='x', shape=[-1, 2], dtype='int32')
y = paddle.static.data(name='y', shape=[-1, 2], dtype='int32')
a = paddle.static.data(name='a', shape=[-1, 2], dtype='int16')
b = paddle.static.data(name='b', shape=[-1, 2], dtype='int16')
op = eval("paddle.%s" % self.op_type)
self.assertRaises(TypeError, op, x=x, y=a)
self.assertRaises(TypeError, op, x=a, y=y)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

46-51为什么要修改呢,用原来的方式不行么?

Copy link
Contributor Author

@jjyaoao jjyaoao Nov 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

46-51为什么要修改呢,用原来的方式不行么?

是的,如果用原来的就会报类型错误,因为这个单侧本来的意义在于,防止compare的kernel能够使用int16,一旦检测到int16类型后会自动抛异常,对于50和51行而言:
50行检查当 x 是 int32 类型而 y 是 int16 类型时,调用 op 是否会引发错误。
51行检查当 x 是 int16 类型而 y 是 int32 类型时,同样的情况。
然而这个pr的功能本身就是支持int16,所以必须得调整这个地方

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那前面单测的名字就不应该改回来了 test_int16_support 比较合适

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那前面单测的名字就不应该改回来了 test_int16_support 比较合适

好的好的,最后那个文档风格检测可能是我修改了一部分源码里的注释导致的,可能需要看一看怎么检查一下


try:
result = op(x=a, y=b)
except TypeError:
self.fail("TypeError should not be raised for int16 inputs")

cls_name = f"{op_type}_{typename}"
Cls.__name__ = cls_name
globals()[cls_name] = Cls


for _type_name in {'float32', 'float64', 'int32', 'int64', 'float16'}:
for _type_name in {
'float32',
'float64',
'uint8',
'int8',
'int16',
'int32',
'int64',
'float16',
}:
if _type_name == 'float64' and core.is_compiled_with_rocm():
_type_name = 'float32'
if _type_name == 'float16' and (not core.is_compiled_with_cuda()):
Expand Down Expand Up @@ -513,7 +524,7 @@ def test_check_output(self):


class TestCompareOpError(unittest.TestCase):
def test_errors(self):
def test_int16_support(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里为什么改名呢?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里为什么改名呢?

当时想的是功能变化了,最开始这里描述的是一旦遇到int16的输出就报错,我把他的内容改为检测是否支持int16的test,我现在调整回来~

paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
Expand Down