Skip to content

Conversation

@Liu-xiandong
Copy link
Member

@Liu-xiandong Liu-xiandong commented Aug 5, 2021

PR types

New features

PR changes

OPs

Describe

描述:
1.完成flatten_contiguous_range在NPU上的移植,支持float32,float64,uint8, int8, int, int64数据类型

2.通过单测
image

3.单测在NPU上执行
image

4.关于int8数据类型,之前代码没有将int8和NPU中的INT8对应,因而NPU识别不了INT8类型。在npu_op_runner.cc中增加了映射关系。

###根据review意见进行了修改
1.增加了python API单测,用于测试相关接口函数。

2.关于flatten_函数。test_flatten_contiguout_range_op.py中在静态图和动态图中分别测试了flatten_函数。静态图中的flatten_函数转化成了flatten函数,NPU中的输入输出采用不同的地址空间,程序可以运行。而动态图中,NPU中的输入输出占用了同一块地址空间。NPU不支持这种inplace操作。特此说明。

@paddle-bot-old
Copy link

paddle-bot-old bot commented Aug 5, 2021

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@Liu-xiandong Liu-xiandong force-pushed the add_flatten_contiguous_range_npu_op branch from 728f3ce to 9ee888c Compare August 6, 2021 03:54

REGISTER_OP_NPU_KERNEL(
flatten_contiguous_range,
ops::FlattenContiguousRangeNPUKernel<paddle::platform::NPUDeviceContext,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要 int16 吗?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

现有的paddle没有支持int16,所以NPU上也没有支持这个数据类型。

pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base)
pass_library(lock_free_optimize_pass base DEPS string_helper)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个改动是做什么呢 ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个和当前的PR #34656 会冲突,不需要修改,等前面那个PR合入就好了

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

一些编译上的错误需要修改,这个应该是不需要提交的,我改一下

DTYPE_2_ACL_DTYPE = {
{framework::proto::VarType::BOOL, ACL_BOOL},
{framework::proto::VarType::UINT8, ACL_UINT8},
{framework::proto::VarType::INT8, ACL_INT8},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里只增加了uint8,其它的需要支持吗?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

其他的已经有映射了,所以是不需要的

}


class TestFlattenOp_uint8(TestFlattenOp):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

测试好像没有覆盖所有的类型?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

测试中覆盖了float64,float32,int,uint8,int8,int64。跟要求的6个数据类型是对应的。

"must be less than or equal to %d, but the value received is %d.",
MAX_RANK_SUPPORTED, rank));
switch (rank) { REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED) }
switch (rank) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个改动是所有op都要做的吗?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个和当前的PR #34656 会冲突,不需要修改,等前面那个PR合入就好了

pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base)
pass_library(lock_free_optimize_pass base DEPS string_helper)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个和当前的PR #34656 会冲突,不需要修改,等前面那个PR合入就好了

"must be less than or equal to %d, but the value received is %d.",
MAX_RANK_SUPPORTED, rank));
switch (rank) { REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED) }
switch (rank) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个和当前的PR #34656 会冲突,不需要修改,等前面那个PR合入就好了

@@ -0,0 +1,228 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

时间是2021


def test_check_grad(self):
pass
#self.check_grad_with_place(self.place, ['X'], 'Out')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删掉注释

"stop_axis": self.stop_axis
}


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考test_flatten_contiguout_range_op.py 增加对Python API的单测

self.assertRaises(ValueError, test_InputError)

def test_Negative():
paddle.disable_static()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改成 paddle.disable_static(paddle.NPUPlace(0)): 不然这个就是跑在CPU上的

x = x.astype('float32')

def test_Negative():
paddle.disable_static()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

qili93
qili93 previously approved these changes Aug 9, 2021
Copy link
Contributor

@qili93 qili93 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@qili93 qili93 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@qili93 qili93 merged commit 79be842 into PaddlePaddle:develop Aug 10, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants