
[Paddle TensorRT] Support isnan, group_norm and take_along_axis #70817


Merged

merged 3 commits into PaddlePaddle:develop on Jan 16, 2025

Conversation

ckl117
Contributor

@ckl117 ckl117 commented Jan 13, 2025

PR Category

Inference

PR Types

Others

Description

card-71500
add isnan, group_norm and take_along_axis


paddle-bot bot commented Jan 13, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
#if !IS_TRT_VERSION_GE(8200)
Contributor

pir-trt only supports TensorRT 8.6 and above, so this check is unnecessary.

return false;
#else
pir::Value index_var_name = op.operand_source(1);
auto index_var_name_type =
Contributor

Use pir::GetDataTypeFromValue here.

pir::Value index_var_name = op.operand_source(1);
auto index_var_name_type =
index_var_name.type().dyn_cast<paddle::dialect::DenseTensorType>();
auto index_shape = index_var_name_type.dims();
Contributor

Use pir::GetShapeFromValue to get the input shape.

def isnan_converter(network, paddle_op, inputs):
    input_tensor = inputs[0]
    version_list = get_trt_version_list()
    if version_list >= [10, 1, 0]:
Contributor

No need to write the 10.1 version branch for now.
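The version gate above relies on Python's lexicographic comparison of lists, so a `[major, minor, patch]` list compares correctly as a version. A small sketch of that behavior; the `get_trt_version_list` stand-in below is hypothetical (the real helper takes no argument and reads the installed TensorRT version):

```python
# Sketch of how a gate like `version_list >= [10, 1, 0]` behaves.
# Lists compare element by element, left to right, so [major, minor, patch]
# works as a version comparison. This parser is an illustrative stand-in.
def get_trt_version_list(version_string="8.6.1"):
    return [int(part) for part in version_string.split(".")]

print(get_trt_version_list("10.1.0") >= [10, 1, 0])  # True
print(get_trt_version_list("8.6.1") >= [10, 1, 0])   # False: 8 < 10 at major
print(get_trt_version_list("10.0.1") >= [10, 1, 0])  # False: 0 < 1 at minor
```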


equal_tensor = trt_equal(network, input_tensor, input_tensor)
layer = network.add_unary(equal_tensor, trt.UnaryOperation.NOT)
cast_layer = network.add_identity(layer.get_output(0))
Contributor

The old IR-TRT implementation didn't have this cast_layer; take a look at why it's added here.
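The converter builds isnan from equal plus NOT, which works because, per IEEE 754, NaN is the only value that does not compare equal to itself. The same identity in plain Python terms (a sketch, not the TRT code):

```python
import math

# equal(x, x) is False exactly at NaN positions; NOT flips it into an
# isnan mask. Infinities still compare equal to themselves, so they are
# correctly left out of the mask.
xs = [1.0, float("nan"), float("-inf"), 0.0]
equal_mask = [v == v for v in xs]         # [True, False, True, True]
isnan_mask = [not e for e in equal_mask]  # [False, True, False, False]
print(isnan_mask == [math.isnan(v) for v in xs])  # True
```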

@@ -152,3 +157,39 @@ def instance_norm_converter(network, paddle_op, inputs):
)
instance_norm_layer = network.add_plugin_v2(instance_norm_inputs, plugin)
return instance_norm_layer.get_output(0)


@converter_registry.register(
Contributor

The old IR-TRT used the generic plugin approach; where is this implemented?

Contributor Author

TRT 8.6 added a new normalization layer that can implement NCHW group_norm; the layer_norm converter is also implemented with this interface, which the old IR-TRT did not adapt.
torch and onnx also use this approach.
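For reference on what that normalization layer must compute for group_norm: channels are split into groups and each (sample, group) slice is normalized by its own mean and variance, then scaled and shifted per channel. A minimal stdlib-only sketch of the math on an NCHW tensor (illustrative only, not the PR's TRT converter):

```python
import math

def group_norm(x, num_groups, scale, shift, eps=1e-5):
    # x: nested lists with shape [N][C][H][W]; len(scale) == len(shift) == C.
    n, c = len(x), len(x[0])
    h, w = len(x[0][0]), len(x[0][0][0])
    group_size = c // num_groups
    out = [[[[0.0] * w for _ in range(h)] for _ in range(c)] for _ in range(n)]
    for ni in range(n):
        for g in range(num_groups):
            chans = range(g * group_size, (g + 1) * group_size)
            # Mean/variance are computed over the whole group slice.
            vals = [x[ni][ci][hi][wi]
                    for ci in chans for hi in range(h) for wi in range(w)]
            mean = sum(vals) / len(vals)
            var = sum((v - mean) ** 2 for v in vals) / len(vals)
            inv_std = 1.0 / math.sqrt(var + eps)
            for ci in chans:
                for hi in range(h):
                    for wi in range(w):
                        normed = (x[ni][ci][hi][wi] - mean) * inv_std
                        out[ni][ci][hi][wi] = normed * scale[ci] + shift[ci]
    return out

# With scale=1 and shift=0, each normalized group slice has near-zero mean.
x = [[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]  # N=1, C=2, H=2, W=2
y = group_norm(x, num_groups=2, scale=[1.0, 1.0], shift=[0.0, 0.0])
```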

        self.max_shape = {"X": [5, 4, 10], "Index": [5, 4, 10]}

    def test_trt_result(self):
        self.check_trt_result()
Contributor

Please also test fp16.

        self.max_shape = {"x": [5, 3]}

    def test_trt_result(self):
        self.check_trt_result()
Contributor

Please also test fp16.

        self.max_shape = {"x": [6, 32, 64, 64]}

    def test_trt_result(self):
        self.check_trt_result()
Contributor

Please also test fp16.
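One reason fp16 runs get their own tests: rounding to half precision moves values by more than fp32-level tolerances allow, so an fp16 test typically needs looser rtol/atol. A stdlib sketch of the effect (struct's `"e"` format is IEEE 754 binary16, standing in here for TRT's fp16 execution):

```python
import struct

# Round-trip a value through IEEE 754 half precision.
def to_fp16(v):
    return struct.unpack("e", struct.pack("e", v))[0]

x = 0.1
print(to_fp16(x))                   # ~0.0999755859375
print(abs(x - to_fp16(x)) < 1e-6)   # False: an fp32-style atol is too tight
print(abs(x - to_fp16(x)) < 1e-3)   # True: fp16 checks need a looser tolerance
```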

@lizexu123 lizexu123 merged commit 646f6c6 into PaddlePaddle:develop Jan 16, 2025
31 checks passed