Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#1 from Layssy/lw_trt
Browse files Browse the repository at this point in the history
添加castop (Add CastOp TensorRT marker support)
  • Loading branch information
lizexu123 authored Jul 19, 2024
2 parents 8f8046d + bae9c52 commit fd0959e
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 12 deletions.
30 changes: 30 additions & 0 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,35 @@ class FlattenOpPattern
return true;
}
};
// Marks pd_op.cast as convertible to TensorRT by setting kCanRunTrtAttr.
// The op is skipped when it is already marked, when the linked TensorRT is
// older than 7.0, or when the input tensor's rank is 0 or 1 (unsupported in
// TensorRT static-shape mode).
class CastOpPattern : public pir::OpRewritePattern<paddle::dialect::CastOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::CastOp>::OpRewritePattern;
  bool MatchAndRewrite(paddle::dialect::CastOp op,
                       pir::PatternRewriter &rewriter) const override {
    // Already marked as TRT-runnable: nothing left to rewrite.
    if (op->HasAttribute(kCanRunTrtAttr) &&
        op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
      return false;
    }
#if !IS_TRT_VERSION_GE(7000)
    // Cast conversion requires TensorRT >= 7.0.
    return false;
#endif
    pir::Value input = op.operand_source(0);
    auto input_tensor_type =
        input.type().dyn_cast<paddle::dialect::DenseTensorType>();
    // dyn_cast yields a null type when the input is not a dense tensor;
    // dereferencing it via .dims() would crash, so bail out explicitly.
    if (!input_tensor_type) {
      VLOG(3) << "cast op input is not a DenseTensorType; skipping TRT "
                 "marking.";
      return false;
    }
    auto input_dims = input_tensor_type.dims();
    if (input_dims.size() == 0 || input_dims.size() == 1) {
      VLOG(3) << " cast op does not support input's dim is 1 or 0 in tensorrt "
                 "static shape mode.";
      return false;
    }
    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
    return true;
  }
};

class TrtOpMarkerPass : public pir::PatternRewritePass {
public:
Expand Down Expand Up @@ -846,6 +875,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ps.Add(std::make_unique<SliceOpPattern>(context));
ps.Add(std::make_unique<IndexSelectOpPattern>(context));
ps.Add(std::make_unique<FlattenOpPattern>(context));
ps.Add(std::make_unique<CastOpPattern>(context));
return ps;
}
};
Expand Down
90 changes: 78 additions & 12 deletions test/ir/pir/fused_pass/test_pir_trt_op_marker_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,27 +352,54 @@ def test_check_output(self):
class TestIndexSelectTRTPattern(PassTest):
    """Checks that trt_op_marker_pass handles a pd_op.index_select program.

    Reconstructed post-commit version: the scraped diff had old (flatten/
    arange) and new (index_select) lines interleaved, which is not valid
    Python; only the added lines are kept here.
    """

    def is_program_valid(self, program=None):
        return True

    def sample_program(self):
        with paddle.pir_utils.IrGuard():
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.pir.core.program_guard(main_prog, start_prog):
                x = paddle.static.data(
                    name='x', shape=[3, 4], dtype='int32'
                )
                index = paddle.static.data(
                    name='index', shape=[3], dtype='int32'
                )
                index_select_out = paddle.index_select(x, index)
                # assign keeps the result live so the pass output is fetchable
                out = paddle.assign(index_select_out)
                self.pass_attr_list = [{'trt_op_marker_pass': {}}]
                self.feeds = {
                    "x": np.random.random([3, 4]).astype("int32"),
                    "index": np.random.random([3]).astype("int32"),
                }

                self.fetch_list = [out]
                self.valid_op_map = {
                    "pd_op.fusion_transpose_flatten_concat": 0,
                }
                yield [main_prog, start_prog], False

    def setUp(self):
        # The marker pass targets TensorRT, so only run on CUDA builds.
        if core.is_compiled_with_cuda():
            self.places.append(paddle.CUDAPlace(0))

    def test_check_output(self):
        self.check_pass_correct()

class TestCastTRTPattern(PassTest):
    """Checks that trt_op_marker_pass handles a pd_op.cast program.

    Reconstructed post-commit version: the scraped diff interleaved leftover
    reshape/flatten lines and a duplicate "x" feed entry; only the added cast
    lines are kept. NOTE(review): valid_op_map/setUp were hidden behind a
    diff fold — reconstructed from the sibling test's pattern; confirm
    against the upstream file.
    """

    def is_program_valid(self, program=None):
        return True

    def sample_program(self):
        with paddle.pir_utils.IrGuard():
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.pir.core.program_guard(main_prog, start_prog):
                x = paddle.static.data(
                    name='x', shape=[3, 4], dtype='float64'
                )
                cast_out = paddle.cast(x, 'uint8')
                # assign keeps the result live so the pass output is fetchable
                out = paddle.assign(cast_out)
                self.pass_attr_list = [{'trt_op_marker_pass': {}}]
                self.feeds = {
                    "x": np.random.random([3, 4]).astype("float64"),
                }

                self.fetch_list = [out]
                self.valid_op_map = {
                    "pd_op.fusion_transpose_flatten_concat": 0,
                }
                yield [main_prog, start_prog], False

    def setUp(self):
        # The marker pass targets TensorRT, so only run on CUDA builds.
        if core.is_compiled_with_cuda():
            self.places.append(paddle.CUDAPlace(0))

    def test_check_output(self):
        self.check_pass_correct()

# class TestIndexSelectTRTPattern(PassTest):
# def is_program_valid(self, program=None):
# return True

# def sample_program(self):
# with paddle.pir_utils.IrGuard():
# main_prog = paddle.static.Program()
# start_prog = paddle.static.Program()
# with paddle.pir.core.program_guard(main_prog, start_prog):
# image_shape = paddle.static.data(
# name='x', shape=[1, 128, 1, 1], dtype='float32'
# )
# x = paddle.arange(
# end=image_shape[0]
# * image_shape[1]
# * image_shape[2]
# * image_shape[3]
# )
# img = paddle.reshape(x, image_shape)
# flatten_out = paddle.nn.flatten(start_axis=1, stop_axis=3)
# out = paddle.assign(flatten_out)
# self.pass_attr_list = [{'trt_op_marker_pass': {}}]
# self.feeds = {
# "x": np.random.random([2, 3, 4, 4]).astype("int32"),
# }

# self.fetch_list = [out]
# self.valid_op_map = {
# "pd_op.fusion_transpose_flatten_concat": 0,
# }
# yield [main_prog, start_prog], False

# def setUp(self):
# if core.is_compiled_with_cuda():
# self.places.append(paddle.CUDAPlace(0))

# def test_check_output(self):
# self.check_pass_correct()


if __name__ == "__main__":
unittest.main()

0 comments on commit fd0959e

Please sign in to comment.