Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,22 @@ bool OpTeller::Tell(const framework::ir::Node* node,
}
}
#endif
// In fact, this should include all conv variants, not only conv2d.
// Rejects conv2d for TensorRT when the filter is not a persistable
// (constant weight) tensor: TRT needs conv weights at engine-build time.
if (op_type == "conv2d") {
  auto* block = desc.Block();
  // Without the enclosing BlockDesc we cannot look up the filter's
  // VarDesc, so conservatively refuse TensorRT conversion.
  if (block == nullptr) {
    VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
               "Developers need to check whether block_desc is passed in "
               "the pass.";
    return false;
  }
  auto* filter_var_desc = block->FindVar(desc.Input("Filter")[0]);
  // A non-persistable filter is an intermediate tensor computed at
  // runtime; TensorRT only supports constant convolution weights.
  if (!filter_var_desc->Persistable()) {
    VLOG(3) << "Trt not support filter is a intermediate tensor in "
               "conv2d op.";
    return false;
  }
}
}

if (op_type == "deformable_conv") {
Expand Down Expand Up @@ -912,6 +928,19 @@ bool OpTeller::Tell(const framework::ir::Node* node,
return false;
}
}
auto* block = desc.Block();
// Without the BlockDesc we cannot query the input's dtype; refuse TRT.
if (block == nullptr) {
  VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
             "Developers need to check whether block_desc is passed in "
             "the pass.";
  return false;
}
auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
auto dtype = x_var_desc->GetDataType();
// At present, forbid int64_t into trt.
// NOTE(review): 3 is proto::VarType::INT64 in Paddle's framework.proto —
// prefer the enum constant over this magic number; confirm against the proto.
if (dtype == 3) {
  return false;
}
}

if (op_type == "unsqueeze2") {
Expand All @@ -931,6 +960,19 @@ bool OpTeller::Tell(const framework::ir::Node* node,
return false;
}
}
auto* block = desc.Block();
// Without the BlockDesc we cannot query X's dtype; refuse TRT conversion.
if (block == nullptr) {
  VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
             "Developers need to check whether block_desc is passed in "
             "the pass.";
  return false;
}
auto* x_var_desc = block->FindVar(desc.Input("X")[0]);
auto dtype = x_var_desc->GetDataType();
// At present, forbid int64_t into trt.
// NOTE(review): 3 is proto::VarType::INT64 per framework.proto; an enum
// constant would be clearer than the literal — confirm and replace.
if (dtype == 3) {
  return false;
}
}

if (op_type == "batch_norm") {
Expand Down Expand Up @@ -1073,6 +1115,11 @@ bool OpTeller::Tell(const framework::ir::Node* node,
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
auto dtype = x_var_desc->GetDataType();
// At present, only support float32 or float16 into trt.
// NOTE(review): per framework.proto, 5 == proto::VarType::FP32 and
// 4 == proto::VarType::FP16 — prefer the enum constants to magic numbers.
if (!(dtype == 5 || dtype == 4)) {
  return false;
}
if (!with_dynamic_shape && x_shape.size() == 1) {
VLOG(3) << "Scale op does not support 1-dimensional input in tensorrt";
return false;
Expand Down Expand Up @@ -1163,6 +1210,20 @@ bool OpTeller::Tell(const framework::ir::Node* node,
return false;
}
}

auto* block = desc.Block();
// Without the BlockDesc we cannot query the input's dtype; refuse TRT.
if (block == nullptr) {
  VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
             "Developers need to check whether block_desc is passed in "
             "the pass.";
  return false;
}
auto* x_var_desc = block->FindVar(desc.Input("Input")[0]);
auto dtype = x_var_desc->GetDataType();
// At present, forbid int64_t into trt.
// NOTE(review): 3 is proto::VarType::INT64 (framework.proto); use the enum
// constant instead of the literal once confirmed.
if (dtype == 3) {
  return false;
}
}

if (op_type == "elementwise_add" || op_type == "elementwise_mul" ||
Expand Down