Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1437,6 +1437,16 @@ PDNode *patterns::AnakinDetectionPattern::operator()(
->assert_is_op_output("box_coder")
->AsIntermediate();

auto transpose_before_nms =
pattern->NewNode(GetNodeName("transpose_before_nms"))
->assert_is_op("transpose2");

auto transpose_before_nms_out =
pattern->NewNode(GetNodeName("transpose_before_nms_out"))
->assert_is_op_output("transpose2")
->assert_is_op_input("multiclass_nms", "Scores")
->AsIntermediate();

auto multiclass_nms_op = pattern->NewNode(GetNodeName("multiclass_nms"))
->assert_is_op("multiclass_nms")
->assert_op_has_n_inputs("multiclass_nms", 2);
Expand Down Expand Up @@ -1487,8 +1497,10 @@ PDNode *patterns::AnakinDetectionPattern::operator()(
{concat_out1, concat_out2, conv_in[kBoxCoderThirdInputOffset]});
box_coder_out->LinksFrom({box_coder_op});

multiclass_nms_op
->LinksFrom({box_coder_out, conv_in[kMultiClassSecondInputNmsOffset]})
transpose_before_nms->LinksFrom({conv_in[kMultiClassSecondInputNmsOffset]});
transpose_before_nms_out->LinksFrom({transpose_before_nms});

multiclass_nms_op->LinksFrom({box_coder_out, transpose_before_nms_out})
.LinksTo({multiclass_nms_out});

return multiclass_nms_out;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ std::unique_ptr<ir::Graph> SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(

input_nodes.push_back(gpd.mutable_pattern()
->NewNode("x" + std::to_string(times + 1))
->assert_is_op_input("multiclass_nms", "Scores")
->assert_is_op_input("transpose2")
->AsInput());

patterns::AnakinDetectionPattern pattern(gpd.mutable_pattern(), pattern_name);
Expand Down Expand Up @@ -106,6 +106,11 @@ std::unique_ptr<ir::Graph> SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
Node *box_coder_out = subgraph.at(pattern.GetPDNode("box_coder_out"));

Node *multiclass_nms_second_input = subgraph.at(input_nodes[times + 1]);
Node *transpose_before_nms =
subgraph.at(pattern.GetPDNode("transpose_before_nms"));
Node *transpose_before_nms_out =
subgraph.at(pattern.GetPDNode("transpose_before_nms_out"));

Node *multiclass_nms = subgraph.at(pattern.GetPDNode("multiclass_nms"));
Node *multiclass_nms_out =
subgraph.at(pattern.GetPDNode("multiclass_nms_out"));
Expand Down Expand Up @@ -133,11 +138,11 @@ std::unique_ptr<ir::Graph> SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
nodes[i * kNumFields + kPriorBoxLocOffset]->Name());
}

int axis = boost::get<int>(concat_op1->Op()->GetAttr("axis"));
// int axis = boost::get<int>(concat_op1->Op()->GetAttr("axis"));
framework::OpDesc concat1_desc;
concat1_desc.SetType("concat");
concat1_desc.SetInput("X", concat1_input_names);
concat1_desc.SetAttr("axis", axis);
concat1_desc.SetAttr("axis", 2);
concat1_desc.SetOutput("Out", {concat_out1->Name()});

auto *new_add_concat_op = graph->CreateOpNode(&concat1_desc);
Expand Down Expand Up @@ -184,6 +189,8 @@ std::unique_ptr<ir::Graph> SimplifyAnakinDetectionPatternPass<times>::ApplyImpl(
delete_nodes.insert(concat_out2);
delete_nodes.insert(box_coder_op);
delete_nodes.insert(box_coder_out);
delete_nodes.insert(transpose_before_nms);
delete_nodes.insert(transpose_before_nms_out);
delete_nodes.insert(multiclass_nms);

new_add_concat_op->outputs.push_back(concat_out1);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/inference/anakin/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
cc_library(anakin_engine SRCS engine.cc)
nv_library(anakin_op_teller SRCS op_teller.cc DEPS framework_proto)
cc_library(anakin_op_teller SRCS op_teller.cc DEPS framework_proto)
target_link_libraries(anakin_engine anakin anakin_saber_common)
cc_test(test_anakin_engine SRCS test_anakin_engine.cc DEPS anakin_engine)
add_subdirectory(convert)
2 changes: 2 additions & 0 deletions paddle/fluid/inference/anakin/convert/batch_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,13 @@ void BatchNormOpConverter::operator()(const framework::proto::OpDesc &op,
auto output = op_desc.Output("Y").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Y").front();
auto epsilon = boost::get<float>(op_desc.GetAttr("epsilon"));
// auto momentum = boost::get<float>(op_desc.GetAttr("momentum"));

auto bn_op_name = op_name + ":bn";
auto bn_output = bn_op_name + "_output";
engine_->AddOp(bn_op_name, "BatchNorm", {inputs["X"]}, {bn_output});
engine_->AddOpAttr(bn_op_name, "epsilon", epsilon);
engine_->AddOpAttr(bn_op_name, "momentum", static_cast<float>(1.0));

auto scale_op_name = op_name + ":scale";
auto get_lod_tensor = [this, &scope, &op_name](const std::string &var_name,
Expand Down
31 changes: 21 additions & 10 deletions paddle/fluid/inference/anakin/convert/density_prior_box.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ namespace paddle {
namespace inference {
namespace anakin {

void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc &op,
const framework::Scope &scope,
void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) {
framework::OpDesc op_desc(op, nullptr);
auto input_name = op_desc.Input("Input").front();
Expand All @@ -42,34 +42,45 @@ void DensityPriorBoxOpConverter::operator()(const framework::proto::OpDesc &op,
auto fixed_ratios =
boost::get<std::vector<float>>(op_desc.GetAttr("fixed_ratios"));
auto densities = boost::get<std::vector<int>>(op_desc.GetAttr("densities"));
std::vector<float> dens;
for (auto& ele : densities) {
dens.push_back(static_cast<float>(ele));
}

// lack flip
auto clip = boost::get<bool>(op_desc.GetAttr("clip"));
// auto clip = boost::get<bool>(op_desc.GetAttr("clip"));
auto variances = boost::get<std::vector<float>>(op_desc.GetAttr("variances"));
for (auto& ele : variances) {
LOG(INFO) << ele;
}

// lack img_h, img_w
auto step_h = boost::get<float>(op_desc.GetAttr("step_h"));
auto step_w = boost::get<float>(op_desc.GetAttr("step_w"));
auto offset = boost::get<float>(op_desc.GetAttr("offset"));
std::vector<std::string> order = {"MIN", "COM", "MAX"};
PTuple<std::string> t_order;
t_order.push_back("MIN");
t_order.push_back("COM");
t_order.push_back("MAX");

std::vector<float> temp_v = {};

engine_->AddOp(op_name, "PriorBox", {input_name, image_name}, {output_name});
engine_->AddOpAttr<PTuple<float>>(op_name, "min_size", temp_v);
engine_->AddOpAttr<PTuple<float>>(op_name, "max_size", temp_v);
engine_->AddOpAttr<PTuple<float>>(op_name, "aspect_ratio", temp_v);
engine_->AddOpAttr<PTuple<float>>(op_name, "fixed_sizes", fixed_sizes);
engine_->AddOpAttr<PTuple<float>>(op_name, "fixed_ratios", fixed_ratios);
engine_->AddOpAttr<PTuple<int>>(op_name, "density", densities);
engine_->AddOpAttr(op_name, "is_flip", false);
engine_->AddOpAttr(op_name, "is_clip", clip);
engine_->AddOpAttr<PTuple<float>>(op_name, "fixed_size", fixed_sizes);
engine_->AddOpAttr<PTuple<float>>(op_name, "fixed_ratio", fixed_ratios);
engine_->AddOpAttr<PTuple<float>>(op_name, "density", dens);
engine_->AddOpAttr(op_name, "is_flip", static_cast<bool>(false));
engine_->AddOpAttr(op_name, "is_clip", static_cast<bool>(false));
engine_->AddOpAttr<PTuple<float>>(op_name, "variance", variances);
engine_->AddOpAttr(op_name, "img_h", static_cast<int>(0));
engine_->AddOpAttr(op_name, "img_w", static_cast<int>(0));
engine_->AddOpAttr(op_name, "step_h", step_h);
engine_->AddOpAttr(op_name, "step_w", step_w);
engine_->AddOpAttr(op_name, "offset", offset);
engine_->AddOpAttr<PTuple<std::string>>(op_name, "order", order);
engine_->AddOpAttr<PTuple<std::string>>(op_name, "order", t_order);
}

} // namespace anakin
Expand Down
30 changes: 30 additions & 0 deletions paddle/fluid/inference/anakin/convert/op_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "framework/core/types.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
Expand Down Expand Up @@ -68,6 +69,35 @@ class AnakinOpConverter {
ConvertOp(op, parameters, scope, engine);
}
}

// The scope here should be inited with the parameter vars.
void ConvertBlockToAnakinEngine(
framework::BlockDesc *block_desc, const framework::Scope &scope,
const std::vector<std::string> &inputs,
const std::unordered_set<std::string> &parameters,
const std::vector<std::string> &outputs, AnakinNvEngine *engine) {
framework::proto::BlockDesc *block_proto = block_desc->Proto();
ConvertBlock(*block_proto, parameters, scope, engine);
engine->Freeze();
for (auto &input : inputs) {
if (parameters.count(input)) continue;
auto *var = block_desc->FindVar(input);
PADDLE_ENFORCE(var, "no variable called %s", input);

auto var_shape = var->GetShape();
PADDLE_ENFORCE(var_shape.size() == 4);
std::vector<int> input_shape;
for (int i = 0; i < var_shape.size(); i++) {
input_shape.push_back(var_shape[i]);
}
input_shape[0] = 1;

engine->SetInputShape(input, input_shape);
}
engine->Optimize();
engine->InitGraph();
}

void SetEngine(AnakinNvEngine *engine) { engine_ = engine; }
virtual ~AnakinOpConverter() {}

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/inference/anakin/convert/pool2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void Pool2dOpConverter::operator()(const framework::proto::OpDesc &op,
if (pool_type == "max") {
anakin_pool_type = "MAX";
} else if (pool_type == "avg") {
anakin_pool_type = "AVG";
anakin_pool_type = "AVGEXC";
} else {
PADDLE_THROW("TensorRT unsupported pooling type!");
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/inference/anakin/convert/softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ void SoftMaxOpConverter::operator()(const framework::proto::OpDesc &op,
auto output = op_desc.Output("Out").front();
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
engine_->AddOp(op_name, "Softmax", {input}, {output});
engine_->AddOpAttr(op_name, "axis", 1);
engine_->AddOpAttr(op_name, "axis", 2);
}

} // namespace anakin
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/inference/anakin/convert/test_batch_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ TEST(batch_norm_op, test) {
desc.SetOutput("SavedVariance", {"batch_norm_save_variance"});

float eps = 1e-5f;
bool is_test = true;
desc.SetAttr("epsilon", eps);
desc.SetAttr("is_test", true);
desc.SetAttr("is_test", is_test);

validator.SetOp(*desc.Proto());

Expand Down
41 changes: 41 additions & 0 deletions paddle/fluid/inference/anakin/convert/test_pool2d_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,52 @@ void test_pool2d(bool global_pooling, bool ceil_mode,
validator.Execute(1);
}

void test_pool2d2(bool global_pooling, bool ceil_mode,
std::string pool_type = "max") {
auto* pool2d_converter =
Registry<AnakinOpConverter>::Global().Lookup("pool2d");
ASSERT_TRUE(pool2d_converter);

framework::Scope scope;
std::unordered_set<std::string> parameters;
AnakinConvertValidation validator(parameters, scope);

// The ITensor's Dims should not contain the batch size.
// So, the ITensor's Dims of input and output should be C * H * W.
validator.DeclInputVar("pool2d_x", {1, 1, 17, 17});
validator.DeclOutputVar("pool2d_out", {1, 1, 17, 17});

// Prepare Op description
framework::OpDesc desc;
desc.SetType("pool2d");
desc.SetInput("X", {"pool2d_x"});
desc.SetOutput("Out", {"pool2d_out"});

std::vector<int> ksize({3, 3});
std::vector<int> strides({1, 1});
std::vector<int> paddings({1, 1});
std::string pooling_t = pool_type;

desc.SetAttr("pooling_type", pooling_t);
desc.SetAttr("ksize", ksize);
desc.SetAttr("strides", strides);
desc.SetAttr("paddings", paddings);
desc.SetAttr("global_pooling", global_pooling);
desc.SetAttr("ceil_mode", true);

LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto());
LOG(INFO) << "execute";

validator.Execute(1);
}

TEST(Pool2dOpConverter, normal) { test_pool2d(false, false); }
TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true, false); }

TEST(Pool2dOpConverter, max_ceil_test) { test_pool2d(false, true); }
TEST(Pool2dOpConverter, avg_ceil_test) { test_pool2d(false, true, "avg"); }
TEST(Pool2dOpConverter, avg_ceil_test2) { test_pool2d2(false, true, "avg"); }

} // namespace anakin
} // namespace inference
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/inference/anakin/convert/ut_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class AnakinConvertValidation {
outputs.insert({output, tensor});
}

engine_->Execute(inputs, outputs);
engine_->Execute(inputs, outputs, stream_);
int i_output = 0;
for (const auto& output : op_desc_->OutputArgumentNames()) {
if (neglected_output.count(output)) continue;
Expand Down
48 changes: 31 additions & 17 deletions paddle/fluid/inference/anakin/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@ namespace inference {
namespace anakin {

template <typename TargetT, Precision PrecisionType, OpRunType RunType>
AnakinEngine<TargetT, PrecisionType, RunType>::AnakinEngine(bool need_summary)
AnakinEngine<TargetT, PrecisionType, RunType>::AnakinEngine(bool need_summary,
int device)
: graph_(new AnakinGraphT<TargetT, PrecisionType>()),
net_(new AnakinNetT<TargetT, PrecisionType, RunType>(need_summary)) {}
net_(new AnakinNetT<TargetT, PrecisionType, RunType>(need_summary)) {
device_ = device;
}

template <typename TargetT, Precision PrecisionType, OpRunType RunType>
AnakinEngine<TargetT, PrecisionType, RunType>::~AnakinEngine() {}
Expand Down Expand Up @@ -63,33 +66,44 @@ void AnakinEngine<TargetT, PrecisionType, RunType>::AddOp(
template <typename TargetT, Precision PrecisionType, OpRunType RunType>
void AnakinEngine<TargetT, PrecisionType, RunType>::Execute(
const std::map<std::string, framework::LoDTensor *> &inputs,
const std::map<std::string, framework::LoDTensor *> &outputs) {
const std::map<std::string, framework::LoDTensor *> &outputs,
cudaStream_t stream) {
for (const auto &input : inputs) {
auto *tensor = input.second;
auto *data = tensor->data<float>();
auto shape = framework::vectorize2int(tensor->dims());
auto fluid_input_shape = framework::vectorize2int(tensor->dims());

auto *anakin_input = net_->get_in(input.first);
auto anakin_input_shape = anakin_input->valid_shape();
PADDLE_ENFORCE(tensor->numel(), anakin_input_shape.count(),
"the fluid input size should be equal to anakin");
auto net_shape = anakin_input->shape();
if (tensor->numel() > net_shape.count()) {
graph_->Reshape(input.first, fluid_input_shape);
net_.reset(new AnakinNetT<TargetT, PrecisionType, RunType>(true));
net_->init(*graph_);
anakin_input = net_->get_in(input.first);
}

anakin_input->reshape(fluid_input_shape);
net_shape = anakin_input->shape();
::anakin::saber::Tensor<TargetT> tmp_anakin_tensor(data, TargetT(), 0,
anakin_input_shape);
anakin_input->copy_from(tmp_anakin_tensor);
net_shape);
anakin_input->share_from(tmp_anakin_tensor);
}

net_->prediction();
for (const auto &output : outputs) {
platform::CUDAPlace gpu_place(device_);
auto *tensor = output.second;
auto *data = tensor->data<float>();
auto shape = framework::vectorize2int(tensor->dims());
auto *anakin_output = net_->get_out(output.first);
auto *anakin_data = anakin_output->data();
auto anakin_output_shape = anakin_output->valid_shape();
PADDLE_ENFORCE(tensor->numel(), anakin_output_shape.count(),
"the fluid output size should be equal to anakin");
::anakin::saber::Tensor<TargetT> tmp_anakin_tensor(data, TargetT(), 0,
anakin_output_shape);
anakin_output->share_from(tmp_anakin_tensor);
tensor->Resize(framework::make_ddim(anakin_output_shape));
auto *fluid_data = tensor->mutable_data<float>(gpu_place);

memory::Copy(gpu_place, static_cast<void *>(fluid_data), gpu_place,
static_cast<void *>(anakin_data),
tensor->numel() * sizeof(float), stream);
}
net_->prediction();

cudaDeviceSynchronize();
}

Expand Down
Loading