- 
                Notifications
    You must be signed in to change notification settings 
- Fork 5.9k
[NPU] Support npu kernel for flatten_contiguous_range op, test=develop #34642
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NPU] Support npu kernel for flatten_contiguous_range op, test=develop #34642
Conversation
| Thanks for your contribution! | 
728f3ce    to
    9ee888c      
    Compare
  
    |  | ||
| REGISTER_OP_NPU_KERNEL( | ||
| flatten_contiguous_range, | ||
| ops::FlattenContiguousRangeNPUKernel<paddle::platform::NPUDeviceContext, | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要 int16 吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
现有的paddle没有支持int16,所以NPU上也没有支持这个数据类型。
| pass_library(graph_to_program_pass base) | ||
| pass_library(graph_viz_pass base) | ||
| pass_library(lock_free_optimize_pass base) | ||
| pass_library(lock_free_optimize_pass base DEPS string_helper) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个改动是做什么呢 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个和当前的PR #34656 会冲突,不需要修改,等前面那个PR合入就好了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
一些编译上的错误需要修改,这个应该是不需要提交的,我改一下
| DTYPE_2_ACL_DTYPE = { | ||
| {framework::proto::VarType::BOOL, ACL_BOOL}, | ||
| {framework::proto::VarType::UINT8, ACL_UINT8}, | ||
| {framework::proto::VarType::INT8, ACL_INT8}, | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里只增加了uint8,其它的需要支持吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
其他的已经有映射了,所以是不需要的
| } | ||
|  | ||
|  | ||
| class TestFlattenOp_uint8(TestFlattenOp): | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
测试好像没有覆盖所有的类型?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
测试中覆盖了float64,float32,int,uint8,int8,int64。跟要求的6个数据类型是对应的。
| "must be less than or equal to %d, but the value received is %d.", | ||
| MAX_RANK_SUPPORTED, rank)); | ||
| switch (rank) { REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED) } | ||
| switch (rank) { | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个改动是所有op都要做的吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个和当前的PR #34656 会冲突,不需要修改,等前面那个PR合入就好了
| pass_library(graph_to_program_pass base) | ||
| pass_library(graph_viz_pass base) | ||
| pass_library(lock_free_optimize_pass base) | ||
| pass_library(lock_free_optimize_pass base DEPS string_helper) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个和当前的PR #34656 会冲突,不需要修改,等前面那个PR合入就好了
| "must be less than or equal to %d, but the value received is %d.", | ||
| MAX_RANK_SUPPORTED, rank)); | ||
| switch (rank) { REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED) } | ||
| switch (rank) { | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个和当前的PR #34656 会冲突,不需要修改,等前面那个PR合入就好了
| @@ -0,0 +1,228 @@ | |||
| # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
时间是2021
|  | ||
| def test_check_grad(self): | ||
| pass | ||
| #self.check_grad_with_place(self.place, ['X'], 'Out') | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删掉注释
| "stop_axis": self.stop_axis | ||
| } | ||
|  | ||
|  | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参考test_flatten_contiguout_range_op.py 增加对Python API的单测
… add_flatten_contiguous_range_npu_op
… add_flatten_contiguous_range_npu_op
| self.assertRaises(ValueError, test_InputError) | ||
|  | ||
| def test_Negative(): | ||
| paddle.disable_static() | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改成 paddle.disable_static(paddle.NPUPlace(0)): 不然这个就是跑在CPU上的
| x = x.astype('float32') | ||
|  | ||
| def test_Negative(): | ||
| paddle.disable_static() | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
OPs
Describe
描述:
1.完成flatten_contiguous_range在NPU上的移植,支持float32,float64,uint8, int8, int, int64数据类型
2.通过单测

3.单测在NPU上执行

4.关于int8数据类型,之前代码没有将int8和NPU中的INT8对应,因而NPU识别不了INT8类型。在npu_op_runner.cc中增加了映射关系。
###根据review意见进行了修改
1.增加了python API单测,用于测试相关接口函数。
2.关于flatten_函数。test_flatten_contiguout_range_op.py中在静态图和动态图中分别测试了flatten_函数。静态图中的flatten_函数转化成了flatten函数,NPU中的输入输出采用不同的地址空间,程序可以运行。而动态图中,NPU中的输入输出占用了同一块地址空间。NPU不支持这种inplace操作。特此说明。