Skip to content

Commit

Permalink
【Hackathon 5th No.13】【关联 PR】Added int8 support for less_than
Browse files Browse the repository at this point in the history
  • Loading branch information
jjyaoao committed Oct 20, 2023
1 parent 6f48047 commit c8d4a5b
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 1 deletion.
1 change: 1 addition & 0 deletions paddle/phi/kernels/cpu/compare_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ PD_REGISTER_KERNEL(equal_all,
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int8_t, \
int16_t, \
int, \
int64_t, \
Expand Down
20 changes: 19 additions & 1 deletion paddle/phi/kernels/kps/compare_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,29 @@ PD_REGISTER_KERNEL(equal_all,
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}

PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan)
PD_REGISTER_COMPARE_KERNEL(greater_equal, GreaterEqual)
PD_REGISTER_COMPARE_KERNEL(equal, Equal)
PD_REGISTER_COMPARE_KERNEL(not_equal, NotEqual)

#define PD_REGISTER_LESS_THAN_KERNEL(func) \
PD_REGISTER_KERNEL(less_than, \
KPS, \
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int8_t, \
int16_t, \
int, \
int64_t, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::bfloat16) { \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}

PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)

#endif
2 changes: 2 additions & 0 deletions python/paddle/tensor/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,7 @@ def less_than(x, y, name=None):
"float16",
"float32",
"float64",
"int8",
"int32",
"int64",
"uint16",
Expand All @@ -921,6 +922,7 @@ def less_than(x, y, name=None):
"float16",
"float32",
"float64",
"int8",
"int32",
"int64",
"uint16",
Expand Down
15 changes: 15 additions & 0 deletions test/legacy_test/test_compare_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,21 @@ def test_place_2(self):
self.assertEqual((result.numpy() == np.array([False])).all(), True)


class TestLessThanInt8(unittest.TestCase):
def test_less_than_int8(self):
# Create a tensor of type int8
x = paddle.to_tensor([1, 2, 3], dtype='int8')
y = paddle.to_tensor([1, 3, 2], dtype='int8')

result = paddle.less_than(x, y)

# desired output
expected = np.array([False, True, False])

# Verify output
self.assertTrue((result.numpy() == expected).all())


if __name__ == '__main__':
paddle.enable_static()
unittest.main()

0 comments on commit c8d4a5b

Please sign in to comment.