Merge pull request PaddlePaddle#11 from lizexu123/add_trt
Added the NonZero Marker and the pd_op.concat converter
lizexu123 authored Jul 30, 2024
2 parents ed5d356 + c66382b commit 793a606
Showing 6 changed files with 214 additions and 47 deletions.
3 changes: 3 additions & 0 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -55,6 +55,7 @@ DEFINE_GENERAL_PATTERN(Reshape, paddle::dialect::ReshapeOp)
DEFINE_GENERAL_PATTERN(Dropout, paddle::dialect::DropoutOp)
DEFINE_GENERAL_PATTERN(Bmm, paddle::dialect::BmmOp)
DEFINE_GENERAL_PATTERN(Concat, paddle::dialect::ConcatOp)
DEFINE_GENERAL_PATTERN(Nonzero, paddle::dialect::NonzeroOp)

DEFINE_GENERAL_PATTERN(Fused_gemm_epilogue,
paddle::dialect::FusedGemmEpilogueOp)
@@ -751,6 +752,7 @@ class SplitOpPattern : public pir::OpRewritePattern<paddle::dialect::SplitOp> {
output_lengths.push_back(attr.dyn_cast<pir::Int64Attribute>().data());
}
axis += (axis < 0) ? x_shape.size() : 0;

if (x_shape[axis] == -1) {
VLOG(3) << "The (" << axis << ") dim of input should not be -1";
return false;
@@ -794,6 +796,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ADD_PATTERN(Conv2d)
ADD_PATTERN(FusedConv2dAddAct)
ADD_PATTERN(DepthwiseConv2d)
ADD_PATTERN(Nonzero)

#undef ADD_PATTERN
ps.Add(std::make_unique<Pool2dOpPattern>(context));
18 changes: 10 additions & 8 deletions paddle/fluid/pir/transforms/trt_sub_graph_extract_pass.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -20,9 +20,9 @@
#include <string>
#include <unordered_map>

#include "paddle/common/flags.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
#include "paddle/common/flags.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/include/core/builder.h"
@@ -41,16 +41,16 @@ using GroupOpsVec = std::vector<pir::Operation*>;

bool IsSupportedByTRT(const pir::Operation& op) {
if (op.HasAttribute(paddle::dialect::kCanRunTrtAttr) &&
op.attribute<pir::BoolAttribute>(paddle::dialect::kCanRunTrtAttr).data()) {
op.attribute<pir::BoolAttribute>(paddle::dialect::kCanRunTrtAttr)
.data()) {
return true;
}
return false;
}

class TrtSubGraphExtractPass : public pir::Pass {
public:
TrtSubGraphExtractPass()
: pir::Pass("trt_sub_graph_extract_pass", 1) {}
TrtSubGraphExtractPass() : pir::Pass("trt_sub_graph_extract_pass", 1) {}

void Run(pir::Operation* op) override {
auto module_op = op->dyn_cast<pir::ModuleOp>();
@@ -64,11 +64,13 @@ class TrtSubGraphExtractPass : public pir::Pass {
::pir::SubgraphDetector(&block, IsSupportedByTRT)();
AddStatistics(groups.size());
for (auto& group_ops : groups) {
if(group_ops.size() < FLAGS_trt_min_group_size) {
VLOG(4) << "current group_ops.size(): " << group_ops.size() << ", will fallback to paddle original graph";
if (group_ops.size() < FLAGS_trt_min_group_size) {
VLOG(4) << "current group_ops.size(): " << group_ops.size()
<< ", will fallback to paddle original graph";
continue;
}
VLOG(4) << "current group_ops.size(): " << group_ops.size() << ", will lower to TensorRT graph";
VLOG(4) << "current group_ops.size(): " << group_ops.size()
<< ", will lower to TensorRT graph";
::pir::ReplaceWithGroupOp(&block, group_ops);
}
}
32 changes: 17 additions & 15 deletions python/paddle/pp_tensorrt/impls/core.py
@@ -215,9 +215,6 @@ def layernorm_converter(network, paddle_op, inputs):
f"{bias_tensor.name}_broadcast",
len(input_a.shape) - len(bias_tensor.shape),
)
# _logger.info(
# f"!!! layernorm, {input_a.shape}, {scale_tensor.shape}, {bias_tensor.shape}"
# )

layer_norm = network.add_normalization(
input_a, scale_tensor, bias_tensor, axes
@@ -251,21 +248,25 @@ def conv2d_converter(network, paddle_op, inputs):

return conv_layer


@converter_registry.register("pd_op.nonzero", trt_version="8.x")
def non_zero_converter(network, paddle_op, inputs):
input_tensor = inputs[0]
cast_layer = network.add_cast(input_tensor, trt.float32)
non_zero_layer = network.add_non_zero(cast_layer.get_output(0))

return non_zero_layer


@converter_registry.register("pd_op.gather_nd", trt_version="8.x")
def gather_nd_converter(network, paddle_op, inputs):
input_tensor, indices_tensor = inputs
shuffle_layer = network.add_shuffle(indices_tensor)
shuffle_layer.first_transpose = trt.Permutation([1, 0])
# import pdb;pdb.set_trace()
non_zero_layer = network.add_gather_v2(input_tensor, shuffle_layer.get_output(0), trt.GatherMode.ND)
non_zero_layer = network.add_gather_v2(
input_tensor, shuffle_layer.get_output(0), trt.GatherMode.ND
)
return non_zero_layer


@@ -391,16 +392,6 @@ def batch_norm_converter(network, paddle_op, inputs):
return batch_norm_layer


@converter_registry.register("pd_op.full")
def full_converter(network, paddle_op, inputs):
shape = paddle_op.attrs()["shape"]
value = paddle_op.attrs().get("value", 1.0)  # default value is 1.0
full_tensor = network.add_constant(
shape, np.full(shape, value, dtype=np.float32)
)
return full_tensor


@converter_registry.register("pd_op.flatten", trt_version="8.x")
def flatten_converter(network, paddle_op, inputs):
input_val = inputs[0]
@@ -483,3 +474,14 @@ def flatten_converter(network, paddle_op, inputs):
flatten_layer.set_input(1, final_shape_layer.get_output(0))

return flatten_layer


@converter_registry.register("pd_op.concat")
def concat_converter(network, paddle_op, inputs):
input_tensor, axis = inputs
concat_layer = network.add_concatenation(inputs=input_tensor)
if axis < 0:
axis = len(input_tensor.shape) + axis

concat_layer.axis = axis
return concat_layer
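
For reference, here is a minimal standalone sketch of the two TensorRT layers the new converters map to: add_concatenation for pd_op.concat and add_non_zero for pd_op.nonzero. It is not part of this commit; the builder/network setup and the TensorRT version requirement are assumptions.

import tensorrt as trt

# Sketch only: assumes TensorRT 8.5+ (for add_non_zero) and an explicit-batch network.
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (2, 4))
y = network.add_input("y", trt.float32, (3, 4))

# pd_op.concat -> IConcatenationLayer: join x and y along axis 0; a negative
# Paddle axis would first be normalized, as concat_converter does above.
concat = network.add_concatenation([x, y])
concat.axis = 0

# pd_op.nonzero -> INonZeroLayer: indices of the non-zero elements of x.
non_zero = network.add_non_zero(x)

network.mark_output(concat.get_output(0))
network.mark_output(non_zero.get_output(0))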
60 changes: 53 additions & 7 deletions python/paddle/pp_tensorrt/test_converter.py
@@ -16,9 +16,11 @@
from converter import PaddleToTensorRTConverter
from util import (
enforce_op_lower_trt,
forbid_op_lower_trt,
get_bert_program,
get_dummy_program,
get_idg_program,
get_mlp_program,
get_r50_program,
predict_program,
run_pir_pass,
@@ -212,12 +214,6 @@ def test_paddle_to_tensorrt_conversion_idg():

# Step3: run pir pass(including some fusion pass and trt_op_marker_pass)
program = run_pir_pass(program, partition_mode=False)
enforce_op_lower_trt(program, "pd_op.gather_nd")
enforce_op_lower_trt(program, "pd_op.nonzero")
# enforce_op_lower_trt(program, "pd_op.pool2d")
# enforce_op_lower_trt(program, "pd_op.batch_norm_")
# enforce_op_lower_trt(program, "pd_op.flatten")
# forbid_op_lower_trt(program, "pd_op.flatten")

# Step4: run trt_sub_graph_extract_pass()
program_with_pir = run_pir_pass(program, partition_mode=True)
@@ -226,6 +222,7 @@
converter = PaddleToTensorRTConverter(program_with_pir, scope)
converter.convert_program_to_trt()

output_var = program_with_pir.list_vars()[-1]
# Step6: run inference(converted_program)
output_converted = predict_program(
program_with_pir,
@@ -249,8 +246,57 @@ def test_paddle_to_tensorrt_conversion_idg():
print("output_converted", output_converted)


def test_paddle_to_tensorrt_conversion_mlp():
program, scope, param_dict = get_mlp_program()
input_data_min_shape = np.random.randn(1, 512, 1024).astype('float32')
input_data_max_shape = np.random.randn(2, 512, 1024).astype('float32')

# Step1.1: get original results(for tests only)
output_var = program.list_vars()[-1]
output_expected = predict_program(
program, {"input": input_data_min_shape}, [output_var]
)

# Step2: run warmup for collecting shape
warmup_shape_infer(
program,
min_shape_feed={"input": input_data_min_shape},
max_shape_feed={"input": input_data_max_shape},
)

# Step3: run pir pass(including some fusion pass and trt_op_marker_pass)
program = run_pir_pass(program, partition_mode=False)
# forbid_op_lower_trt(program,"pd_op.concat")

# Step4: run trt_sub_graph_extract_pass()
program_with_pir = run_pir_pass(program, partition_mode=True)

# Step5: run TRTConverter(would lower group_op into tensorrt_engine_op)
converter = PaddleToTensorRTConverter(program_with_pir, scope)
converter.convert_program_to_trt()

output_var = program_with_pir.list_vars()[-1]
# Step6: run inference(converted_program)
output_converted = predict_program(
program_with_pir, {"input": input_data_min_shape}, [output_var]
)

# Check that the results are close to each other within a tolerance of 1e-3
np.testing.assert_allclose(
output_expected[0],
output_converted[0],
rtol=1e-3,
atol=1e-3,
err_msg="Outputs are not within the 1e-3 tolerance",
)

print("output_expected", output_expected)
print("output_converted", output_converted)


if __name__ == "__main__":
# test_paddle_to_tensorrt_conversion_dummy()
# test_paddle_to_tensorrt_conversion_bert()
# test_paddle_to_tensorrt_conversion_r50()
test_paddle_to_tensorrt_conversion_idg()
# test_paddle_to_tensorrt_conversion_idg()
test_paddle_to_tensorrt_conversion_mlp()
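
If pd_op.concat should stay on the Paddle side while debugging, the commented-out forbid_op_lower_trt call in the MLP test above hints at the pattern. A sketch, reusing the helpers imported from util at the top of this file (their exact semantics are assumed here):

# Sketch only: exclude an op from TensorRT lowering before partitioning.
program = run_pir_pass(program, partition_mode=False)
forbid_op_lower_trt(program, "pd_op.concat")  # assumed to clear the TRT marker for matching ops
program_with_pir = run_pir_pass(program, partition_mode=True)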
