Skip to content

Commit fbfedc8

Browse files
authored
Merge pull request #16116 from NHZlX/anakin_merge_develop
add trt update for this branch
2 parents e77dce4 + 59ac472 commit fbfedc8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+1005
-535
lines changed

Dockerfile

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,9 @@ RUN curl -s -q https://glide.sh/get | sh
7575
# and its size is only one-third of the official one.
7676
# 2. Manually add ~IPluginFactory() in IPluginFactory class of NvInfer.h, otherwise, it couldn't work in paddle.
7777
# See https://github.com/PaddlePaddle/Paddle/issues/10129 for details.
78-
RUN wget -qO- http://paddlepaddledeps.cdn.bcebos.com/TensorRT-4.0.0.3.Ubuntu-16.04.4.x86_64-gnu.cuda-8.0.cudnn7.0.tar.gz | \
79-
tar -xz -C /usr/local && \
78+
79+
RUN wget -q https://paddlepaddledeps.cdn.bcebos.com/TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz --no-check-certificate && \
80+
tar -zxf TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz -C /usr/local && \
8081
cp -rf /usr/local/TensorRT/include /usr && \
8182
cp -rf /usr/local/TensorRT/lib /usr
8283

paddle/fluid/framework/ir/fuse_pass_base.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#pragma once
1616

17+
#include <string>
1718
#include "paddle/fluid/framework/ir/graph.h"
1819
#include "paddle/fluid/framework/ir/pass.h"
1920
#include "paddle/fluid/framework/scope.h"
@@ -24,6 +25,10 @@ namespace ir {
2425

2526
static const char kParamScopeAttr[] = "__param_scope__";
2627
static const char kFuseStatisAttr[] = "__fuse_statis__";
28+
// When TensorRT or another third-party library is used, the parameters are
29+
// managed by that library rather than by fluid. We record them here to avoid
30+
// allocating them a second time.
31+
static const char kRepetitiveParamAttr[] = "__repetitive_param__";
2732

2833
enum FuseOptions {
2934
DO_NOT_FUSE, // fusing will not be done

paddle/fluid/inference/analysis/argument.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,12 @@
2323

2424
#pragma once
2525

26+
#include <memory>
2627
#include <string>
28+
#include <unordered_map>
29+
#include <unordered_set>
2730
#include <vector>
31+
2832
#include "paddle/fluid/framework/ir/graph.h"
2933
#include "paddle/fluid/framework/program_desc.h"
3034
#include "paddle/fluid/framework/scope.h"
@@ -133,6 +137,8 @@ struct Argument {
133137
DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
134138
DECL_ARGUMENT_FIELD(tensorrt_precision_mode, TensorRtPrecisionMode,
135139
AnalysisConfig::Precision);
140+
DECL_ARGUMENT_FIELD(tensorrt_use_static_engine, TensorRtUseStaticEngine,
141+
bool);
136142

137143
// Memory optimized related.
138144
DECL_ARGUMENT_FIELD(enable_memory_optim, EnableMemoryOptim, bool);

paddle/fluid/inference/analysis/helper.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@ limitations under the License. */
1717
#include <sys/stat.h>
1818
#include <cstdio>
1919
#include <fstream>
20+
#include <memory>
2021
#include <set>
2122
#include <string>
2223
#include <typeindex>
2324
#include <unordered_map>
25+
#include <utility>
2426
#include <vector>
2527

2628
#include "paddle/fluid/framework/framework.pb.h"
@@ -217,6 +219,35 @@ static std::string GetTrtCalibTableData(const std::string &model_opt_cache_dir,
217219
return "";
218220
}
219221

222+
// Builds the path of the serialized TRT engine file for `engine_key`.
//
// The result is `model_root`, followed by "/trt_serialized_", followed by
// `engine_key`. No normalisation is applied to either argument.
static std::string GetTrtEngineSerializedPath(const std::string &model_root,
                                              const std::string &engine_key) {
  std::string serialized_path(model_root);
  serialized_path += "/trt_serialized_";
  serialized_path += engine_key;
  return serialized_path;
}
226+
227+
static std::string GetTrtEngineSerializedData(
228+
const std::string &model_opt_cache_dir, const std::string &engine_key) {
229+
std::string trt_serialized_path =
230+
GetTrtEngineSerializedPath(model_opt_cache_dir, engine_key);
231+
if (FileExists(trt_serialized_path)) {
232+
VLOG(3) << "Trt serialized file: " << trt_serialized_path
233+
<< "is found here";
234+
std::ifstream infile(trt_serialized_path, std::ios::in);
235+
std::stringstream buffer;
236+
buffer << infile.rdbuf();
237+
std::string trt_engine_serialized_data(buffer.str());
238+
return trt_engine_serialized_data;
239+
}
240+
return "";
241+
}
242+
243+
// Writes `engine_serialized_data` to the file at `trt_serialized_path`,
// replacing any previous contents.
//
// The data is treated as an opaque byte blob: the file is opened in binary
// mode and the bytes are written verbatim, so the file round-trips exactly
// when read back in binary mode. Failures to open or write are not reported
// to the caller.
static void SaveTrtEngineSerializedDataToFile(
    const std::string &trt_serialized_path,
    const std::string &engine_serialized_data) {
  std::ofstream outfile(trt_serialized_path,
                        std::ios::out | std::ios::binary | std::ios::trunc);
  outfile.write(engine_serialized_data.data(),
                static_cast<std::streamsize>(engine_serialized_data.size()));
  outfile.close();
}
250+
220251
} // namespace analysis
221252
} // namespace inference
222253
} // namespace paddle

paddle/fluid/inference/analysis/ir_pass_manager.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ void IRPassManager::CreatePasses(Argument *argument,
8989
pass->Set(
9090
"model_opt_cache_dir",
9191
new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir)));
92+
pass->Set("gpu_device_id", new int(argument->gpu_device_id()));
93+
pass->Set("use_static_engine",
94+
new bool(argument->tensorrt_use_static_engine()));
9295
}
9396

9497
pre_pass = pass_name;

paddle/fluid/inference/analysis/ir_pass_manager.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222

2323
#pragma once
2424

25+
#include <memory>
2526
#include <string>
27+
#include <unordered_set>
28+
#include <utility>
2629
#include <vector>
2730
#include "paddle/fluid/framework/ir/graph.h"
2831
#include "paddle/fluid/framework/ir/pass.h"

0 commit comments

Comments
 (0)