Commit aa1f1e5

Support TensorRT reformat-free I/O
1 parent e66dd4c commit aa1f1e5

3 files changed: +55 -15 lines changed

src/backends/tensorrt/plan_backend.cc

Lines changed: 7 additions & 9 deletions
@@ -490,12 +490,11 @@ PlanBackend::Context::InitializeInputBinding(
 
   MemoryFormat fmt =
       ConvertTrtFmtToFmt(engine_->getBindingFormat(binding_index));
-  if (fmt != MemoryFormat::LINEAR) {
+  if (fmt == MemoryFormat::INVALID) {
     return Status(
         RequestStatusCode::INVALID_ARG,
-        "unexpected tensor format " + MemoryFormat_Name(fmt) +
-            " for input '" + input_name +
-            "'. Only LINEAR memory format is supported at present.");
+        "unexpected tensor format " + MemoryFormat_Name(fmt) + " for input '" +
+            input_name + "'.");
   }
 
   nvinfer1::Dims engine_dims = engine_->getBindingDimensions(binding_index);
@@ -507,7 +506,7 @@ PlanBackend::Context::InitializeInputBinding(
   if (!(is_control && is_dynamic_)) {
     RETURN_IF_ERROR(CompareDimsSupported(
         name_, input_name, engine_dims, model_config_dims, support_batching,
-        is_dynamic_));
+        is_dynamic_, fmt));
   } else {
     Status status = ValidateControlDimsDynamic(engine_dims, support_batching);
     if (!status.IsOk()) {
@@ -704,12 +703,11 @@ PlanBackend::Context::InitializeConfigOutputBindings(
 
   MemoryFormat fmt =
       ConvertTrtFmtToFmt(engine_->getBindingFormat(binding_index));
-  if (fmt != MemoryFormat::LINEAR) {
+  if (fmt == MemoryFormat::INVALID) {
     return Status(
         RequestStatusCode::INVALID_ARG,
         "unexpected tensor format " + MemoryFormat_Name(fmt) +
-            " for output '" + io.name() +
-            "'. Only LINEAR memory format is supported at present.");
+            " for output '" + io.name() + "'.");
   }
 
   const DimsList& model_config_dims =
@@ -718,7 +716,7 @@ PlanBackend::Context::InitializeConfigOutputBindings(
   nvinfer1::Dims engine_dims = engine_->getBindingDimensions(binding_index);
   RETURN_IF_ERROR(CompareDimsSupported(
       name_, io.name(), engine_dims, model_config_dims, support_batching,
-      is_dynamic_));
+      is_dynamic_, fmt));
 
   int64_t byte_size;
   if (!is_dynamic_) {

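For reference, a standalone sketch (not part of this commit) of what the relaxed check in InitializeInputBinding() and InitializeConfigOutputBindings() changes: it uses a local stand-in for the MemoryFormat enum from plan_utils.h and prints which binding formats the old check (reject anything non-LINEAR) and the new check (reject only INVALID) would accept.

#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Local stand-in mirroring the MemoryFormat enum in plan_utils.h.
enum class MemoryFormat { LINEAR, CHW2, CHW4, HWC8, CHW16, CHW32, INVALID };

int
main()
{
  const std::vector<std::pair<MemoryFormat, std::string>> formats = {
      {MemoryFormat::LINEAR, "LINEAR"},
      {MemoryFormat::CHW4, "CHW4"},
      {MemoryFormat::HWC8, "HWC8"},
      {MemoryFormat::INVALID, "INVALID"}};
  for (const auto& f : formats) {
    // Old check: any non-LINEAR binding format was an error.
    const bool old_reject = (f.first != MemoryFormat::LINEAR);
    // New check: only a format that cannot be mapped is an error.
    const bool new_reject = (f.first == MemoryFormat::INVALID);
    std::cout << f.second << ": before=" << (old_reject ? "reject" : "accept")
              << ", after=" << (new_reject ? "reject" : "accept") << std::endl;
  }
  return 0;
}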
src/backends/tensorrt/plan_utils.cc

Lines changed: 44 additions & 4 deletions
@@ -56,7 +56,7 @@ ConvertTrtFmtToFmt(nvinfer1::TensorFormat trt_fmt)
     case nvinfer1::TensorFormat::kCHW4:
       return MemoryFormat::CHW4;
     case nvinfer1::TensorFormat::kHWC8:
-      return MemoryFormat::HCW8;
+      return MemoryFormat::HWC8;
     case nvinfer1::TensorFormat::kCHW16:
       return MemoryFormat::CHW16;
     case nvinfer1::TensorFormat::kCHW32:
@@ -76,8 +76,8 @@ MemoryFormat_Name(MemoryFormat fmt)
       return "CHW2";
     case MemoryFormat::CHW4:
       return "CHW4";
-    case MemoryFormat::HCW8:
-      return "HCW8";
+    case MemoryFormat::HWC8:
+      return "HWC8";
     case MemoryFormat::CHW16:
       return "CHW16";
     case MemoryFormat::CHW32:
@@ -89,6 +89,36 @@ MemoryFormat_Name(MemoryFormat fmt)
   return "INVALID";
 }
 
+int
+MemoryFormat_VectorSize(MemoryFormat fmt)
+{
+  unsigned int vector_size = 1;
+  switch (fmt) {
+    case MemoryFormat::LINEAR:
+      vector_size = 1;
+      break;
+    case MemoryFormat::CHW2:
+      vector_size = 2;
+      break;
+    case MemoryFormat::CHW4:
+      vector_size = 4;
+      break;
+    case MemoryFormat::HWC8:
+      vector_size = 8;
+      break;
+    case MemoryFormat::CHW16:
+      vector_size = 16;
+      break;
+    case MemoryFormat::CHW32:
+      vector_size = 32;
+      break;
+    default:
+      vector_size = 1;  // In the default case, assume LINEAR
+      break;
+  }
+  return vector_size;
+}
+
 std::pair<bool, nvinfer1::DataType>
 ConvertDataTypeToTrtType(const DataType& dtype)
 {
@@ -132,7 +162,8 @@ Status
 CompareDimsSupported(
     const std::string& model_name, const std::string& binding_name,
     const nvinfer1::Dims& model_dims, const DimsList& dims,
-    const bool supports_batching, const bool is_dynamic)
+    const bool supports_batching, const bool is_dynamic,
+    const MemoryFormat fmt)
 {
   // If the model configuration expects batching support in the model,
   // then the first dimension must be -1.
@@ -166,6 +197,15 @@ CompareDimsSupported(
       continue;
     }
 
+    // Pad channel dimension if necessary.
+    if (i == dims.size() - 3) {
+      int vector_size = MemoryFormat_VectorSize(fmt);
+      if (vector_size > 1) {
+        model_dim =
+            model_dim + vector_size - ((model_dim + vector_size) % vector_size);
+      }
+    }
+
     if (model_dim != dims[i]) {
       return Status(
           RequestStatusCode::INVALID_ARG,

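The channel-padding step added to CompareDimsSupported() can be exercised on its own. Below is a standalone sketch, not part of the commit; the helper name PadChannelDim is made up for illustration, and vector_size plays the role of MemoryFormat_VectorSize(fmt).

#include <iostream>

// Evaluates the padding expression used in CompareDimsSupported() for the
// channel dimension of a vectorized format.
static int
PadChannelDim(int model_dim, int vector_size)
{
  if (vector_size > 1) {
    model_dim = model_dim + vector_size - ((model_dim + vector_size) % vector_size);
  }
  return model_dim;
}

int
main()
{
  std::cout << PadChannelDim(3, 4) << std::endl;    // 4  : C=3 under CHW4
  std::cout << PadChannelDim(13, 32) << std::endl;  // 32 : C=13 under CHW32
  std::cout << PadChannelDim(5, 1) << std::endl;    // 5  : LINEAR, left unchanged
  return 0;
}

As the outputs show, 3 is compared as 4 under CHW4 and 13 as 32 under CHW32, while LINEAR values pass through unchanged.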
src/backends/tensorrt/plan_utils.h

Lines changed: 4 additions & 2 deletions
@@ -41,7 +41,7 @@ enum class MemoryFormat {
   // Four wide channel vectorized row major format.
   CHW4,
   // Eight channel format where C is padded to a multiple of 8.
-  HCW8,
+  HWC8,
   // Sixteen wide channel vectorized row major format.
   CHW16,
   // Thirty-two wide channel vectorized row major format.
@@ -53,6 +53,7 @@ enum class MemoryFormat {
 MemoryFormat ConvertTrtFmtToFmt(nvinfer1::TensorFormat trt_fmt);
 
 const std::string MemoryFormat_Name(MemoryFormat fmt);
+int MemoryFormat_VectorSize(MemoryFormat fmt);
 
 DataType ConvertTrtTypeToDataType(nvinfer1::DataType trt_type);
 
@@ -72,7 +73,8 @@ Status ValidateDimension(
 Status CompareDimsSupported(
     const std::string& model_name, const std::string& tensor_name,
     const nvinfer1::Dims& model_dims, const DimsList& dims,
-    const bool supports_batching, const bool is_dynamic);
+    const bool supports_batching, const bool is_dynamic,
+    const MemoryFormat fmt);
 
 Status ValidateControlDimsDynamic(
     const nvinfer1::Dims& dims, const bool support_batching);

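The header comment on the renamed HWC8 entry says C is padded to a multiple of 8. Below is a standalone sketch, not part of the commit, of what that padding implies for the element count of such a buffer; the helper name and the round-up expression are assumptions based on that comment, not repository code.

#include <cstdint>
#include <iostream>

// Element count of a tensor bound with the HWC8 format, where the channel
// dimension is rounded up to a multiple of 8 (per the plan_utils.h comment).
static int64_t
Hwc8ElementCount(int64_t n, int64_t c, int64_t h, int64_t w)
{
  const int64_t padded_c = ((c + 7) / 8) * 8;
  return n * h * w * padded_c;
}

int
main()
{
  // A 1x3x224x224 (NCHW) tensor bound as HWC8 holds 1 * 224 * 224 * 8 elements.
  std::cout << Hwc8ElementCount(1, 3, 224, 224) << std::endl;  // 401408
  return 0;
}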