-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.13】【关联 PR】Added uint8&int8&int16 support for compare_kernel -part #58209
Changes from all commits
c8d4a5b
3171348
30b6d70
fe55923
f68d1fb
22cb913
dddb082
8fac339
cec3af6
ec55f18
ba17ff1
5f2d860
1d0d78a
2663b85
aee9b38
596cb74
691ec6c
b3a215a
7b07471
0101631
a47a57e
9bf6328
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,24 +38,35 @@ def setUp(self): | |
def test_output(self): | ||
self.check_output(check_cinn=True, check_pir=check_pir) | ||
|
||
def test_errors(self): | ||
def test_int16_support(self): | ||
paddle.enable_static() | ||
with paddle.static.program_guard( | ||
paddle.static.Program(), paddle.static.Program() | ||
): | ||
x = paddle.static.data(name='x', shape=[-1, 2], dtype='int32') | ||
y = paddle.static.data(name='y', shape=[-1, 2], dtype='int32') | ||
a = paddle.static.data(name='a', shape=[-1, 2], dtype='int16') | ||
b = paddle.static.data(name='b', shape=[-1, 2], dtype='int16') | ||
op = eval("paddle.%s" % self.op_type) | ||
self.assertRaises(TypeError, op, x=x, y=a) | ||
self.assertRaises(TypeError, op, x=a, y=y) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 46-51为什么要修改呢,用原来的方式不行么? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
是的,如果用原来的就会报类型错误,因为这个单侧本来的意义在于,防止compare的kernel能够使用int16,一旦检测到int16类型后会自动抛异常,对于50和51行而言: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 那前面单测的名字就不应该改回来了 test_int16_support 比较合适 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
好的好的,最后那个文档风格检测可能是我修改了一部分源码里的注释导致的,可能需要看一看怎么检查一下 |
||
|
||
try: | ||
result = op(x=a, y=b) | ||
except TypeError: | ||
self.fail("TypeError should not be raised for int16 inputs") | ||
|
||
cls_name = f"{op_type}_{typename}" | ||
Cls.__name__ = cls_name | ||
globals()[cls_name] = Cls | ||
|
||
|
||
for _type_name in {'float32', 'float64', 'int32', 'int64', 'float16'}: | ||
for _type_name in { | ||
'float32', | ||
'float64', | ||
'uint8', | ||
'int8', | ||
'int16', | ||
'int32', | ||
'int64', | ||
'float16', | ||
}: | ||
if _type_name == 'float64' and core.is_compiled_with_rocm(): | ||
_type_name = 'float32' | ||
if _type_name == 'float16' and (not core.is_compiled_with_cuda()): | ||
|
@@ -513,7 +524,7 @@ def test_check_output(self): | |
|
||
|
||
class TestCompareOpError(unittest.TestCase): | ||
def test_errors(self): | ||
def test_int16_support(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里为什么改名呢? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
当时想的是功能变化了,最开始这里描述的是一旦遇到int16的输出就报错,我把他的内容改为检测是否支持int16的test,我现在调整回来~ |
||
paddle.enable_static() | ||
with paddle.static.program_guard( | ||
paddle.static.Program(), paddle.static.Program() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
你这里对所有的compare kernel 都添加了int8_t 类型的注册?但单测里只加了less_than的测试