
Commit fcb6890

Merge pull request #1851 from pytorch/exp_size
feat: Wrap dynamic size handling in a compilation flag
2 parents ac3ab77 + ed7fd99 commit fcb6890

15 files changed, +90 -17 lines changed
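The user-visible change is strictly opt-in: allow_shape_tensors defaults to False in every API touched below, and only when it is set will aten::size be evaluated through TensorRT's IShapeLayer under dynamic input shapes. A minimal sketch of the TorchScript front-end usage (the module and shapes are illustrative, not taken from the commit):

    import torch
    import torch_tensorrt

    class FlattenByBatch(torch.nn.Module):
        def forward(self, x):
            b = x.size(0)  # dynamic dim: stays symbolic only as a shape tensor
            return x.reshape(b, -1)

    model = torch.jit.script(FlattenByBatch().eval()).cuda()
    trt_mod = torch_tensorrt.ts.compile(
        model,
        inputs=[torch_tensorrt.Input(min_shape=[1, 3, 32, 32],
                                     opt_shape=[8, 3, 32, 32],
                                     max_shape=[16, 3, 32, 32])],
        # New flag from this PR; the default (False) preserves old behavior.
        allow_shape_tensors=True,
    )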

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ struct BuilderSettings {
   bool refit = false;
   bool debug = false;
   bool truncate_long_and_double = false;
+  bool allow_shape_tensors = false;
   ir::Device device;
   nvinfer1::EngineCapability capability = TRT_ENGINE_CAPABILITY_STANDARD;
   nvinfer1::IInt8Calibrator* calibrator = nullptr;
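This BuilderSettings field is the single internal switch the rest of the change plumbs into: the C++ CompileSpec, the Python frontend, the torchtrtc CLI, and the test utilities each expose an allow_shape_tensors knob and copy it into these engine settings before conversion.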

core/conversion/evaluators/aten.cpp

Lines changed: 14 additions & 3 deletions
@@ -270,7 +270,12 @@ auto aten_registrations TORCHTRT_UNUSED =
           if (tensor_var.isITensor()) {
             auto tensor = tensor_var.ITensor();
             if (ctx->input_is_dynamic) {
-              return dynamic_size_layer(ctx, n, args);
+              if (ctx->settings.allow_shape_tensors) {
+                return dynamic_size_layer(ctx, n, args);
+              } else {
+                LOG_WARNING(
+                    "There may be undefined behavior using dynamic shape and aten::size without setting allow_shape_tensors");
+              }
             }
             return util::toVec(tensor->getDimensions());
           } else if (tensor_var.IValue()->isTensor()) {
@@ -286,7 +291,12 @@ auto aten_registrations TORCHTRT_UNUSED =
           auto dim = args.at(n->input(1)).unwrapToInt();
           if (tensor_var.isITensor()) {
            if (ctx->input_is_dynamic) {
-              return dynamic_size_layer(ctx, n, args);
+              if (ctx->settings.allow_shape_tensors) {
+                return dynamic_size_layer(ctx, n, args);
+              } else {
+                LOG_WARNING(
+                    "There may be undefined behavior using dynamic shape and aten::size without setting allow_shape_tensors");
+              }
             }
             auto tensor = tensor_var.ITensor();
             auto dims = util::toVec(tensor->getDimensions());
@@ -605,7 +615,8 @@ auto aten_registrations TORCHTRT_UNUSED =
       .evaluator(
           {c10::Symbol::fromQualString("aten::numel"),
            [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-             LOG_WARNING("There may be undefined behavior using dynamic shape and aten::numel");
+             LOG_WARNING(
+                 "There may be undefined behavior using dynamic shape and aten::numel without setting allow_shape_tensors");
              auto tensor_var = args.at(n->input(0));
              if (tensor_var.isITensor()) {
                auto tensor = tensor_var.ITensor();
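Both aten::size overloads now take the same guarded path: dynamic_size_layer runs only when the flag is set; otherwise the evaluator logs the new warning and falls through to the static result from util::toVec(tensor->getDimensions()). From Python, the fallback is simply what you get by omitting the new argument (hypothetical repro, reusing the module from the first sketch):

    # Flag left at its default (False): compilation still proceeds, but the
    # converter logs "There may be undefined behavior using dynamic shape and
    # aten::size without setting allow_shape_tensors" and evaluates the size
    # from the static (possibly -1) dimension values.
    trt_mod_static = torch_tensorrt.ts.compile(
        model,
        inputs=[torch_tensorrt.Input(min_shape=[1, 3, 32, 32],
                                     opt_shape=[8, 3, 32, 32],
                                     max_shape=[16, 3, 32, 32])],
    )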

core/conversion/evaluators/eval_util.cpp

Lines changed: 25 additions & 7 deletions
@@ -32,7 +32,9 @@ nvinfer1::ITensor* index_layer(
 c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
   LOG_DEBUG("Using dynamic version of aten::size evaluator");
   auto in = args.at(n->input(0)).ITensorOrFreeze(ctx);
-  LOG_DEBUG("Input dimensions: " << in->getDimensions());
+  auto input_dims = in->getDimensions();
+  LOG_DEBUG("Input dimensions: " << input_dims);
+
   auto shape_layer = ctx->net->addShape(*in);
   TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
   auto shape_1d_tensor = shape_layer->getOutput(0);
@@ -44,15 +46,31 @@ c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
     dim = dim < 0 ? dim + maxDim : dim;
     LOG_DEBUG("Dimension to select: " << dim);
     shape_1d_tensor = index_layer(ctx, n, shape_1d_tensor, dim);
-  }
+    LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());
 
-  LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());
+    auto tensor_holder = TensorContainer();
+    tensor_holder.hold_tensor(shape_1d_tensor);
+    auto shape_1d_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
 
-  auto tensor_holder = TensorContainer();
-  tensor_holder.hold_tensor(shape_1d_tensor);
-  auto shape_1d_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+    return shape_1d_ivalue;
 
-  return shape_1d_ivalue;
+  } else {
+    auto input_size = c10::impl::GenericList(c10::AnyType::get());
+    // Only express the dynamic dimension with a shape layer output.
+    // The static dimensions are preserved in the input size.
+    for (int32_t i = 0; i < input_dims.nbDims; i++) {
+      if (input_dims.d[i] == -1) {
+        auto dynamic_dim_tensor = index_layer(ctx, n, shape_1d_tensor, i);
+        auto dynamic_dim_holder = TensorContainer();
+        dynamic_dim_holder.hold_tensor(dynamic_dim_tensor);
+        auto dynamic_dim_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(dynamic_dim_holder)));
+        input_size.emplace_back(std::move(dynamic_dim_ivalue));
+      } else {
+        input_size.emplace_back(input_dims.d[i]);
+      }
+    }
+    return c10::IValue(input_size);
+  }
 }
 
 int64_t normalizeIndex(int64_t idx, int64_t list_size) {
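The second half of the change is in dynamic_size_layer itself. For aten::size without a dim argument, the evaluator no longer returns the raw 1-D shape tensor; it builds a generic list in which only dimensions that are actually dynamic (-1 in the input spec) come from the shape layer plus a gather, while static dimensions stay as plain integers, so downstream converters keep their exact values. An illustrative TorchScript module that exercises this path (assumed, not part of the commit):

    import torch

    class ViewByBatch(torch.nn.Module):
        def forward(self, x):
            # aten::size with no dim argument: if only the batch dimension is
            # dynamic, the evaluator now yields [<shape tensor>, C, H, W]
            # rather than one 4-element shape tensor.
            b, c, h, w = x.size()
            return x.view(b, c * h * w)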

cpp/bin/torchtrtc/main.cpp

Lines changed: 10 additions & 0 deletions
@@ -168,6 +168,12 @@ int main(int argc, char** argv) {
       "Truncate weights that are provided in 64bit to 32bit (Long, Double to Int, Float)",
       {"truncate", "truncate-long-double", "truncate-64bit"});
 
+  args::Flag allow_shape_tensors(
+      parser,
+      "allow-shape-tensors",
+      "(Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT",
+      {"allow-shape-tensors"});
+
   args::Flag save_engine(
       parser,
       "save_engine",
@@ -443,6 +449,10 @@ int main(int argc, char** argv) {
     compile_settings.truncate_long_and_double = true;
   }
 
+  if (allow_shape_tensors) {
+    compile_settings.allow_shape_tensors = true;
+  }
+
   torch::jit::Module mod;
   try {
     // Deserialize the ScriptModule from a file using torch::jit::load().
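On the CLI the same switch is the experimental --allow-shape-tensors flag; when present, torchtrtc sets compile_settings.allow_shape_tensors = true before compiling, and every other argument behaves as before.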

cpp/include/torch_tensorrt/torch_tensorrt.h

Lines changed: 5 additions & 0 deletions
@@ -791,6 +791,11 @@ struct CompileSpec {
    */
   bool truncate_long_and_double = false;
 
+  /**
+   * Allow shape tensors (from IShape layer) in the graph
+   */
+  bool allow_shape_tensors = false;
+
   /**
    * Target Device
    */

cpp/src/compile_spec.cpp

Lines changed: 1 addition & 0 deletions
@@ -90,6 +90,7 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   internal.convert_info.engine_settings.refit = external.refit;
   internal.convert_info.engine_settings.debug = external.debug;
   internal.convert_info.engine_settings.truncate_long_and_double = external.truncate_long_and_double;
+  internal.convert_info.engine_settings.allow_shape_tensors = external.allow_shape_tensors;
   internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback;
   internal.lower_info.target_device.allow_gpu_fallback = external.device.allow_gpu_fallback;
   internal.partitioning_info.target_device.allow_gpu_fallback = external.device.allow_gpu_fallback;

py/torch_tensorrt/csrc/register_tensorrt_classes.cpp

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ void RegisterTRTCompileSpec() {
       TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, dla_global_dram_size);
   ADD_FIELD_GET_SET_REGISTRATION(
       TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, truncate_long_and_double);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, allow_shape_tensors);
 }
 
 struct TRTTSRegistrations {

py/torch_tensorrt/csrc/tensorrt_classes.cpp

Lines changed: 2 additions & 0 deletions
@@ -373,6 +373,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
   info.partitioning_info.truncate_long_and_double = truncate_long_and_double;
   info.lower_info.forced_fallback_modules = torch_fallback.forced_fallback_modules;
   info.convert_info.engine_settings.truncate_long_and_double = truncate_long_and_double;
+  info.convert_info.engine_settings.allow_shape_tensors = allow_shape_tensors;
 
   info.convert_info.engine_settings.capability = toTRTEngineCapability(capability);
   TORCHTRT_CHECK(num_avg_timing_iters >= 0, "num_avg_timing_iters must be 0 or greater");
@@ -423,6 +424,7 @@ std::string CompileSpec::stringify() {
   ss << " \"DLA Local DRAM Size\": " << dla_local_dram_size << std::endl;
   ss << " \"DLA Global DRAM Size\": " << dla_global_dram_size << std::endl;
   ss << " \"Truncate long and double\": " << truncate_long_and_double << std::endl;
+  ss << " \"Allow Shape tensors\": " << allow_shape_tensors << std::endl;
   ss << " \"Torch Fallback\": " << torch_fallback.to_str();
   ss << "}";
   return ss.str();

py/torch_tensorrt/csrc/tensorrt_classes.h

Lines changed: 2 additions & 0 deletions
@@ -167,6 +167,7 @@ struct CompileSpec : torch::CustomClassHolder {
   ADD_FIELD_GET_SET(dla_local_dram_size, int64_t);
   ADD_FIELD_GET_SET(dla_global_dram_size, int64_t);
   ADD_FIELD_GET_SET(truncate_long_and_double, bool);
+  ADD_FIELD_GET_SET(allow_shape_tensors, bool);
   ADD_FIELD_GET_SET(device, Device);
   ADD_FIELD_GET_SET(torch_fallback, TorchFallback);
   ADD_FIELD_GET_SET(ptq_calibrator, nvinfer1::IInt8Calibrator*);
@@ -180,6 +181,7 @@
   bool refit = false;
   bool debug = false;
   bool truncate_long_and_double = false;
+  bool allow_shape_tensors = false;
   Device device;
   TorchFallback torch_fallback;
   EngineCapability capability = EngineCapability::kDEFAULT;

py/torch_tensorrt/csrc/torch_tensorrt_py.cpp

Lines changed: 2 additions & 1 deletion
@@ -371,7 +371,8 @@ PYBIND11_MODULE(_C, m) {
       .def_readwrite("dla_local_dram_size", &CompileSpec::dla_local_dram_size)
       .def_readwrite("dla_global_dram_size", &CompileSpec::dla_global_dram_size)
       .def_readwrite("torch_fallback", &CompileSpec::torch_fallback)
-      .def_readwrite("truncate_long_and_double", &CompileSpec::truncate_long_and_double);
+      .def_readwrite("truncate_long_and_double", &CompileSpec::truncate_long_and_double)
+      .def_readwrite("allow_shape_tensors", &CompileSpec::allow_shape_tensors);
 
   py::class_<TorchFallback>(ts_sub_mod, "TorchFallback")
       .def(py::init<>())

py/torch_tensorrt/ts/_compile_spec.py

Lines changed: 8 additions & 0 deletions
@@ -298,6 +298,10 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
         assert isinstance(compile_spec["debug"], bool)
         info.debug = compile_spec["debug"]
 
+    if "allow_shape_tensors" in compile_spec:
+        assert isinstance(compile_spec["allow_shape_tensors"], bool)
+        info.allow_shape_tensors = compile_spec["allow_shape_tensors"]
+
     if "device" in compile_spec:
         info.device = _parse_device(compile_spec["device"])
 
@@ -354,6 +358,7 @@ def TensorRTCompileSpec(
     dla_global_dram_size=536870912,
     truncate_long_and_double=False,
     calibrator=None,
+    allow_shape_tensors=False,
 ) -> torch.classes.tensorrt.CompileSpec:
     """Utility to create a formated spec dictionary for using the PyTorch TensorRT backend
 
@@ -388,6 +393,7 @@
         workspace_size (int): Maximum size of workspace given to TensorRT
         truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32
         calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
+        allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
 
     Returns:
         torch.classes.tensorrt.CompileSpec: List of methods and formated spec objects to be provided to ``torch._C._jit_to_tensorrt``
@@ -410,6 +416,7 @@
         "dla_global_dram_size": dla_global_dram_size,  # Host RAM used by DLA to store weights and metadata for execution
         "calibrator": calibrator,
         "truncate_long_and_double": truncate_long_and_double,
+        "allow_shape_tensors": allow_shape_tensors,
     }
 
     parsed_spec = _parse_compile_spec(compile_spec)
@@ -461,6 +468,7 @@ def TensorRTCompileSpec(
     backend_spec._set_dla_local_dram_size(parsed_spec.dla_local_dram_size)
     backend_spec._set_dla_global_dram_size(parsed_spec.dla_global_dram_size)
     backend_spec._set_truncate_long_and_double(parsed_spec.truncate_long_and_double)
+    backend_spec._set_allow_shape_tensors(parsed_spec.allow_shape_tensors)
     backend_spec._set_ptq_calibrator(parsed_spec._get_calibrator_handle())
 
     return backend_spec
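For the torch.jit to_backend path, the flag rides through TensorRTCompileSpec and the new _set_allow_shape_tensors setter. A hedged sketch (other arguments elided; the _jit_to_backend hand-off follows the pattern referenced in this function's docstring and is assumed here, reusing model from the first sketch):

    spec = torch_tensorrt.ts.TensorRTCompileSpec(
        inputs=[torch_tensorrt.Input(min_shape=[1, 3, 32, 32],
                                     opt_shape=[8, 3, 32, 32],
                                     max_shape=[16, 3, 32, 32])],
        allow_shape_tensors=True,  # forwarded via _set_allow_shape_tensors
    )
    trt_backend_mod = torch._C._jit_to_backend("tensorrt", model, {"forward": spec})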

py/torch_tensorrt/ts/_compiler.py

Lines changed: 6 additions & 0 deletions
@@ -31,6 +31,7 @@ def compile(
     min_block_size=3,
     torch_executed_ops=[],
     torch_executed_modules=[],
+    allow_shape_tensors=False,
 ) -> torch.jit.ScriptModule:
     """Compile a TorchScript module for NVIDIA GPUs using TensorRT
 
@@ -94,6 +95,7 @@ def compile(
         min_block_size (int): The minimum number of contiguous TensorRT convertable operations in order to run a set of operations in TensorRT
         torch_executed_ops (List[str]): List of aten operators that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
         torch_executed_modules (List[str]): List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
+        allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
 
     Returns:
         torch.jit.ScriptModule: Compiled TorchScript Module, when run it will execute via TensorRT
@@ -131,6 +133,7 @@ def compile(
             "forced_fallback_modules": torch_executed_modules,
             "min_block_size": min_block_size,
         },
+        "allow_shape_tensors": allow_shape_tensors,
     }
 
     compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
@@ -156,6 +159,7 @@ def convert_method_to_trt_engine(
     dla_global_dram_size=536870912,
     truncate_long_and_double=False,
     calibrator=None,
+    allow_shape_tensors=False,
 ) -> bytearray:
     """Convert a TorchScript module method to a serialized TensorRT engine
 
@@ -214,6 +218,7 @@ def convert_method_to_trt_engine(
         dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
         truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32
         calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
+        allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
 
     Returns:
         bytearray: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
@@ -236,6 +241,7 @@ def convert_method_to_trt_engine(
         "workspace_size": workspace_size,  # Maximum size of workspace given to TensorRT
         "calibrator": calibrator,
        "truncate_long_and_double": truncate_long_and_double,
+        "allow_shape_tensors": allow_shape_tensors,
     }
 
     engine_str = _C.convert_graph_to_trt_engine(
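convert_method_to_trt_engine gains the identical keyword, so a standalone serialized engine can also be built with shape tensors enabled (same illustrative shapes as above):

    engine_bytes = torch_tensorrt.ts.convert_method_to_trt_engine(
        model,  # scripted module from the first sketch
        "forward",
        inputs=[torch_tensorrt.Input(min_shape=[1, 3, 32, 32],
                                     opt_shape=[8, 3, 32, 32],
                                     max_shape=[16, 3, 32, 32])],
        allow_shape_tensors=True,
    )
    with open("dyn_size.engine", "wb") as f:  # output path is arbitrary
        f.write(engine_bytes)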

tests/cpp/test_dynamic_size.cpp

Lines changed: 8 additions & 4 deletions
@@ -27,7 +27,8 @@ TEST(Converters, ATenResizeDynamicShapeCorrectly) {
 
   auto trt_in = at::clone(in);
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+  auto trt_results =
+      torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true, /*allow_shape_tensors=*/true);
 
   auto trt = trt_results[0].reshape(jit_results[0].sizes());
 
@@ -53,7 +54,8 @@ TEST(Converters, ATenResizeDynamicInputCorrectly) {
 
   auto trt_in = at::clone(in);
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+  auto trt_results =
+      torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true, /*allow_shape_tensors=*/true);
 
   auto trt = trt_results[0].reshape(jit_results[0].sizes());
 
@@ -83,7 +85,8 @@ TEST(Converters, ATenResizeGetItemDynShapeCorrectly) {
 
   auto trt_in = at::clone(in);
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+  auto trt_results =
+      torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true, /*allow_shape_tensors=*/true);
 
   auto trt = trt_results[0].reshape(jit_results[0].sizes());
 
@@ -115,7 +118,8 @@ TEST(Converters, ATenResizeGetItemDynShapeMulCorrectly) {
 
   auto trt_in = at::clone(in);
   params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
-  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+  auto trt_results =
+      torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true, /*allow_shape_tensors=*/true);
 
   auto trt = trt_results[0].reshape(jit_results[0].sizes());
 
tests/util/run_graph_engine.cpp

Lines changed: 3 additions & 1 deletion
@@ -94,12 +94,14 @@ std::vector<at::Tensor> RunGraphEngineDynamic(
     std::shared_ptr<torch::jit::Graph>& g,
     core::ir::StaticParams& named_params,
     std::vector<at::Tensor> inputs,
-    bool dynamic_batch) {
+    bool dynamic_batch = false,
+    bool allow_shape_tensors = false) {
   LOG_DEBUG("Running TRT version");
   auto var_ins = get_var_inputs(g->inputs(), named_params);
   auto in = core::ir::pair_input_vals_with_specs(var_ins, toInputsDynamic(inputs, dynamic_batch));
   auto info = core::conversion::ConversionInfo();
   info.inputs = std::move(in);
+  info.engine_settings.allow_shape_tensors = allow_shape_tensors;
   std::string eng = core::conversion::ConvertBlockToEngine(g->block(), info, named_params);
   return RunEngine(eng, inputs);
 }

tests/util/util.h

Lines changed: 2 additions & 1 deletion
@@ -57,7 +57,8 @@ std::vector<at::Tensor> RunGraphEngineDynamic(
     std::shared_ptr<torch::jit::Graph>& g,
     core::ir::StaticParams& named_params,
     std::vector<at::Tensor> inputs,
-    bool dynamic_batch = false);
+    bool dynamic_batch = false,
+    bool allow_shape_tensors = false);
 
 // Run the forward method of a module and return results
 torch::jit::IValue RunModuleForward(torch::jit::Module& mod, std::vector<torch::jit::IValue> inputs);
