fea/init tensorrt engine #10003

Merged
merged 34 commits into from
Apr 25, 2018

Commits
34 commits
a3140d3
add tensorrt
Superjomn Apr 13, 2018
a60189f
set tensorrt on as default
Superjomn Apr 13, 2018
b95d819
add cudnn dependency
Superjomn Apr 13, 2018
87fc090
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into fea/add…
Superjomn Apr 13, 2018
8dda580
nvtest
Superjomn Apr 13, 2018
92480b5
add tensorrt dynamic loader
Superjomn Apr 13, 2018
5891896
add tensorrt as dyload
Superjomn Apr 13, 2018
1b475b3
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into fea/add…
Superjomn Apr 15, 2018
9d617b8
finish test
Superjomn Apr 15, 2018
0e8e85f
remove tensorrt.cmake
Superjomn Apr 15, 2018
63b6a74
fix pip upgrade pip error
Superjomn Apr 15, 2018
d492547
add flag definition for tensorrt_dir
Superjomn Apr 16, 2018
4f0a2ab
code clean
Superjomn Apr 16, 2018
e220226
add default so search path
Superjomn Apr 16, 2018
5132a2b
update
Superjomn Apr 16, 2018
dc23dc5
Merge branch 'fea/add_tensorrt' into fea/tensorrt_engine
Superjomn Apr 16, 2018
1fe9f63
change cmake config
Superjomn Apr 16, 2018
cf4f092
Merge branch 'fea/add_tensorrt' into fea/tensorrt_engine
Superjomn Apr 16, 2018
f1b5040
init
Superjomn Apr 17, 2018
9699574
init
Superjomn Apr 17, 2018
aa7ab53
update
Superjomn Apr 17, 2018
1d13858
finish coding
Superjomn Apr 18, 2018
4da8cbd
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into fea/ten…
Superjomn Apr 18, 2018
5463325
fix conflict on Dockerfile
Superjomn Apr 18, 2018
74ea1f6
add new get output apis
Superjomn Apr 18, 2018
610f290
format code
Superjomn Apr 18, 2018
f273eef
Update networks.py
Superjomn Apr 18, 2018
57c0ddb
add inference namespace
Superjomn Apr 19, 2018
25397ca
Merge branch 'fea/tensorrt_engine' of github.com:Superjomn/Paddle int…
Superjomn Apr 19, 2018
6d89b54
fix copyright
Superjomn Apr 19, 2018
97a34ac
engine add namespace
Superjomn Apr 19, 2018
5b8de3b
change according to review
Superjomn Apr 24, 2018
bbf19cb
wrap test
Superjomn Apr 24, 2018
4c0ce9d
add helper
Superjomn Apr 25, 2018
53 changes: 53 additions & 0 deletions paddle/fluid/inference/engine.h
@@ -0,0 +1,53 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/framework.pb.h"

namespace paddle {
namespace inference {

/*
* EngineBase is the base class of all inference engines. An inference engine
* takes a paddle program as input and outputs the result in fluid Tensor
* format. It can be used to optimize the performance of computation
* sub-blocks, for example, by breaking the original block down into sub-blocks
* and executing each sub-block on a different engine.
*
* For example:
* During inference, most of the resnet50 model can be put into a subgraph and
* run on a TensorRT engine.
*
* There are several engines such as TensorRT and other frameworks, so
* EngineBase is introduced to provide a unified interface for all the
* different engine implementations.
*/
class EngineBase {
public:
using DescType = ::paddle::framework::proto::BlockDesc;

// Build the model and do some preparation, for example, in TensorRT, run
// createInferBuilder, buildCudaEngine.
virtual void Build(const DescType& paddle_model) = 0;

// Execute the engine; this will run the inference network.
virtual void Execute(int batch_size) = 0;

virtual ~EngineBase() {}

}; // class EngineBase

} // namespace inference
} // namespace paddle
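As a rough illustration of the contract above (this sketch is not part of the PR; DummyEngine and its members are hypothetical), a concrete engine derives from EngineBase and fills in Build and Execute:

// Hypothetical example of a concrete engine; names are made up for illustration.
#include "paddle/fluid/inference/engine.h"

namespace paddle {
namespace inference {

class DummyEngine : public EngineBase {
 public:
  // Build would translate the BlockDesc into the engine's own network form.
  void Build(const DescType& paddle_model) override {
    num_ops_ = paddle_model.ops_size();
  }

  // Execute would run the prepared network for the given batch size.
  void Execute(int batch_size) override { last_batch_size_ = batch_size; }

 private:
  int num_ops_{0};
  int last_batch_size_{0};
};  // class DummyEngine

}  // namespace inference
}  // namespace paddle
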
5 changes: 4 additions & 1 deletion paddle/fluid/inference/tensorrt/CMakeLists.txt
@@ -1 +1,4 @@
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
if(WITH_TESTING)
Contributor

The if(WITH_TESTING) guard can be omitted here, because nv_test already performs this check internally. It can be changed in a later PR.

Contributor Author

OK

nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
endif()
134 changes: 134 additions & 0 deletions paddle/fluid/inference/tensorrt/engine.cc
@@ -0,0 +1,134 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/engine.h"

#include <NvInfer.h>
#include <cuda.h>
#include <glog/logging.h>
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {
namespace tensorrt {

void TensorRTEngine::Build(const DescType& paddle_model) {
PADDLE_ENFORCE(false, "not implemented");
}

void TensorRTEngine::Execute(int batch_size) {
infer_context_->enqueue(batch_size, buffers_.data(), *stream_, nullptr);
cudaStreamSynchronize(*stream_);
}

TensorRTEngine::~TensorRTEngine() {
// clean buffer
for (auto& buffer : buffers_) {
if (buffer != nullptr) {
PADDLE_ENFORCE_EQ(0, cudaFree(buffer));
buffer = nullptr;
}
}
}

void TensorRTEngine::FreezeNetwork() {
PADDLE_ENFORCE(infer_builder_ != nullptr,
"Call InitNetwork first to initialize network.");
PADDLE_ENFORCE(infer_network_ != nullptr,
"Call InitNetwork first to initialize network.");
// build engine.
infer_builder_->setMaxBatchSize(max_batch_);
infer_builder_->setMaxWorkspaceSize(max_workspace_);

infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");

infer_context_.reset(infer_engine_->createExecutionContext());

// allocate GPU buffers.
buffers_.resize(buffer_sizes_.size(), nullptr);
for (auto& item : buffer_sizes_) {
if (item.second == 0) {
auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str());
item.second = kDataTypeSize[static_cast<int>(
infer_engine_->getBindingDataType(slot_offset))] *
AccumDims(infer_engine_->getBindingDimensions(slot_offset));
}
PADDLE_ENFORCE_EQ(0, cudaMalloc(&buffer(item.first), item.second));
}
}

nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name,
nvinfer1::DataType dtype,
const nvinfer1::Dims& dim) {
PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s",
name);

PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");
auto* input = infer_network_->addInput(name.c_str(), dtype, dim);
PADDLE_ENFORCE(input, "infer network add input %s failed", name);

buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] * AccumDims(dim);
return input;
}

void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset,
const std::string& name) {
PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
name);

auto* output = layer->getOutput(offset);
PADDLE_ENFORCE(output != nullptr);
output->setName(name.c_str());
infer_network_->markOutput(*output);
// The output buffers' sizes can only be decided later, so set them to zero
// here as a marker; they will be reset later.
buffer_sizes_[name] = 0;
}

void* TensorRTEngine::GetOutputInGPU(const std::string& name) {
return buffer(name);
}

void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst,
size_t max_size) {
// determine data size
auto it = buffer_sizes_.find(name);
PADDLE_ENFORCE(it != buffer_sizes_.end());
PADDLE_ENFORCE_GT(it->second, 0);
PADDLE_ENFORCE_GE(max_size, it->second);

PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buffer(name), it->second,
cudaMemcpyDeviceToHost, *stream_));
}

void*& TensorRTEngine::buffer(const std::string& name) {
PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first.");
auto it = buffer_sizes_.find(name);
PADDLE_ENFORCE(it != buffer_sizes_.end());
auto slot_offset = infer_engine_->getBindingIndex(name.c_str());
return buffers_[slot_offset];
}

void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
size_t size) {
void* buf = buffer(name);
PADDLE_ENFORCE_EQ(
0, cudaMemcpyAsync(buf, data, size, cudaMemcpyHostToDevice, *stream_));
}

} // namespace tensorrt
} // namespace inference
} // namespace paddle
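To show how the pieces above fit together, here is a minimal usage sketch (loosely modeled on the accompanying test; the ReLU layer, the tensor names "x"/"y", and the sizes are illustrative assumptions, not code taken from this PR):

// Build a tiny network by hand and run it through the engine.
cudaStream_t stream;
cudaStreamCreate(&stream);

TensorRTEngine engine(/*max_batch=*/1, /*max_workspace=*/1 << 20, &stream);
engine.InitNetwork();

// Declare a 3-element float input named "x".
nvinfer1::ITensor* x = engine.DeclareInput(
    "x", nvinfer1::DataType::kFLOAT, nvinfer1::DimsCHW(3, 1, 1));
// Add one TensorRT layer and mark its 0-th output as network output "y".
auto* relu =
    engine.network()->addActivation(*x, nvinfer1::ActivationType::kRELU);
engine.DeclareOutput(relu, 0, "y");

engine.FreezeNetwork();  // builds the ICudaEngine and allocates GPU buffers

float input[3] = {-1.f, 0.f, 2.f};
engine.SetInputFromCPU("x", input, sizeof(input));
engine.Execute(/*batch_size=*/1);

float output[3];
engine.GetOutputInCPU("y", output, sizeof(output));
cudaStreamSynchronize(stream);  // GetOutputInCPU issues an async copy
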
144 changes: 144 additions & 0 deletions paddle/fluid/inference/tensorrt/engine.h
@@ -0,0 +1,144 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Contributor

Could engine.h be renamed to tensorrt_engine.h? Otherwise it has the same name as inference/engine.h; renaming it would make things clearer.

Contributor Author

The header files only appear in include statements, e.g.

  • #include "paddle/fluid/inference/engine.h"
  • #include "paddle/fluid/inference/tensorrt/engine.h"

so the two can be distinguished.


Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <NvInfer.h>
#include <memory>
#include <unordered_map>
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"

namespace paddle {
namespace inference {
namespace tensorrt {

/*
* TensorRT Engine.
*
* There are two alternative ways to use it: one is to build it from a paddle
* protobuf model, the other is to manually construct the network.
*/
class TensorRTEngine : public EngineBase {
public:
// Weight is model parameter.
class Weight {
public:
Weight(nvinfer1::DataType dtype, void* value, int num_elem) {
w_.type = dtype;
w_.values = value;
w_.count = num_elem;
}
const nvinfer1::Weights& get() { return w_; }

private:
nvinfer1::Weights w_;
};
Contributor

Is it appropriate to put the Weight class inside the TensorRTEngine class? This class may also need to be called by the convert classes.
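
For reference, a hypothetical snippet (the buffer and element count are made-up values, not from this PR) showing how the Weight wrapper would be filled from a host buffer and unwrapped into nvinfer1::Weights:

// Assumed: fc_w is a host buffer holding 200 fully-connected weights.
std::vector<float> fc_w(200, 0.f);
TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, fc_w.data(), 200);
const nvinfer1::Weights& trt_w = weight.get();  // usable by an add-layer call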


TensorRTEngine(int max_batch, int max_workspace, cudaStream_t* stream,
nvinfer1::ILogger& logger = NaiveLogger::Global())
: max_batch_(max_batch),
max_workspace_(max_workspace),
stream_(stream),
logger_(logger) {}

virtual ~TensorRTEngine();

// TODO(Superjomn) implement it later when graph segmentation is supported.
virtual void Build(const DescType& paddle_model) override;

virtual void Execute(int batch_size) override;

// Initialize the inference network, so that TensorRT layers can be added to
// this network.
void InitNetwork() {
infer_builder_.reset(createInferBuilder(logger_));
infer_network_.reset(infer_builder_->createNetwork());
}
// After all ops have been added, freeze this network and create the execution
// environment.
void FreezeNetwork();

// Add an input and set its name, data type and dimension.
nvinfer1::ITensor* DeclareInput(const std::string& name,
nvinfer1::DataType dtype,
const nvinfer1::Dims& dim);
// Set the offset-th output from a layer as the network's output, and set its
// name.
void DeclareOutput(const nvinfer1::ILayer* layer, int offset,
const std::string& name);

// GPU memory address for an ITensor with a specific name. One can operate on
// this memory directly for acceleration; for example, write converted data
// directly into the buffer to save the data-copy overhead.
// NOTE this should be used after calling `FreezeNetwork`.
void*& buffer(const std::string& name);

// Fill an input from CPU memory with name and size.
void SetInputFromCPU(const std::string& name, void* data, size_t size);
// TODO(Superjomn) is this method necessary given that buffer(xxx) can be
// accessed directly? Fill an input from GPU memory with name and size.
void SetInputFromGPU(const std::string& name, void* data, size_t size);
// Get the output named `name`; TensorRT outputs live on the GPU, so this
// method just returns the output's GPU memory address.
void* GetOutputInGPU(const std::string& name);
// LOW EFFICIENCY! Copy an output to CPU memory; this triggers a memory copy
// from GPU to CPU.
void GetOutputInCPU(const std::string& name, void* dst, size_t max_size);

nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); }
nvinfer1::INetworkDefinition* network() { return infer_network_.get(); }

private:
// the max batch size
int max_batch_;
// the max memory size the engine uses
int max_workspace_;
cudaStream_t* stream_;
nvinfer1::ILogger& logger_;

std::vector<void*> buffers_;
// max data size for the buffers.
std::unordered_map<std::string /*name*/, size_t /*max size*/> buffer_sizes_;

// TensorRT related internal members
template <typename T>
struct Destroyer {
void operator()(T* x) { x->destroy(); }
};
template <typename T>
using infer_ptr = std::unique_ptr<T, Destroyer<T>>;
infer_ptr<nvinfer1::IBuilder> infer_builder_;
infer_ptr<nvinfer1::INetworkDefinition> infer_network_;
infer_ptr<nvinfer1::ICudaEngine> infer_engine_;
infer_ptr<nvinfer1::IExecutionContext> infer_context_;
}; // class TensorRTEngine

// Add a layer__ into engine__ with args ARGS.
// For example:
// TRT_ENGINE_ADD_LAYER(xxx, FullyConnected, input, dim, weights, bias)
//
// Reference
// https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#charRNN_define_network
//
// will add a fully connected layer into the engine.
// TensorRT has too many layer types, so it is not wise to add a member
// function for each of them; a macro like this is more extensible when the
// underlying TensorRT library adds support for new layers.
#define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \
engine__->network()->add##layer__(ARGS);
Contributor

Could this macro definition be removed?

  • Calling the original functions directly is also quite clear;
  • Since the convert classes also need to add different layers, they would have to include the engine header, which does not seem quite reasonable.

Contributor Author

This macro

  • provides a unified interface for adding layers, so we do not need to add a separate function for each layer type, such as addFullyConnected.
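
As a hypothetical illustration (assuming engine is a TensorRTEngine* and input is an nvinfer1::ITensor* already declared on it), the macro is just shorthand for the corresponding network()->addXxx call:

// Expands to: engine->network()->addActivation(*input, nvinfer1::ActivationType::kRELU);
auto* relu = TRT_ENGINE_ADD_LAYER(engine, Activation, *input,
                                  nvinfer1::ActivationType::kRELU);
engine->DeclareOutput(relu, 0, "relu_out");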


} // namespace tensorrt
} // namespace inference
} // namespace paddle