Commit 91306c9

fix: Add test case, move config condition
- Add test case to elicit behavior where full compilation is requested but TRT engine size falls below the default `min_block_size=3`
- Move `min_block_size` condition to a narrower scope
1 parent 5ab8e17 commit 91306c9
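
Note: the pre-fix failure mode is easiest to see with a model whose graph contains a single TensorRT-convertible op. The lone TRT segment then has fewer nodes than the default min_block_size of 3, so it is left to run in Torch and full compilation cannot be honored. A minimal sketch of that scenario (assuming torch_tensorrt imported as torchtrt, as in the test suite below):

import torch
import torch_tensorrt as torchtrt

class Sample(torch.nn.Module):
    def forward(self, x, y):
        # The addition is the only TensorRT-convertible op, so the resulting
        # TRT segment is smaller than the default min_block_size=3
        return (x + y,)

scripted_mod = torch.jit.script(Sample().eval().to("cuda"))
inputs = [torchtrt.Input((5, 5), dtype=torch.float), torchtrt.Input((5, 5), dtype=torch.float)]

# Prior to this commit, the segment fell below min_block_size and full
# compilation failed; partitioning now runs with min_block_size=1 in this case
trt_mod = torchtrt.ts.compile(scripted_mod, inputs=inputs, require_full_compilation=True)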

File tree

4 files changed: +66 -24 lines changed

core/compiler.cpp

Lines changed: 19 additions & 22 deletions
@@ -143,19 +143,14 @@ partitioning::GraphAndMapping BuildHybridGraph(
   auto convert_info = cfg.convert_info;
   auto partitioning_info = cfg.partitioning_info;
 
-  // Any nonzero block size is valid if full compilation to TRT is desired
-  if (expect_full_compilation) {
-    partitioning_info.min_block_size = 1;
-  }
-
   auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
   partitioning_ctx.input_types_map = first_use_types;
 
   // Generate a dictionary of input torch::jit::Value's to their min, opt, max tensors and store in ctx
   // TODO: Combine this within partition call
   partitioning::populateInputIValues(&partitioning_ctx);
 
-  partitioning::partition(&partitioning_ctx);
+  partitioning::partition(&partitioning_ctx, expect_full_compilation);
 
   for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
     partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
@@ -197,9 +192,11 @@ partitioning::GraphAndMapping BuildHybridGraph(
       if (expect_full_compilation) {
         for (auto torch_node : seg_block.block()->nodes()) {
           if (partitioning::CollectionNodeKinds.find(torch_node->kind()) == partitioning::CollectionNodeKinds.end()) {
-            LOG_ERROR(
-                "Full compilation specified but node " << torch_node->kind().toQualString()
-                                                       << " was executed in Torch.");
+            TORCHTRT_THROW_ERROR(
+                "Full compilation specified but node "
+                << *torch_node
+                << " is set to run in PyTorch due to either lack of support in TensorRT or graph partitioning rules."
+                << " Try recompiling with require_full_compilation=False.");
           }
         }
       }
@@ -209,10 +206,9 @@ partitioning::GraphAndMapping BuildHybridGraph(
     // If full compilation is expected, cannot have more than 2 Torch segments
     // (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
     if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1)) {
-      LOG_ERROR(
-          "Full compilation specified but number of torch segments was "
-          << num_torch_segments << " and number of trt segments was " << num_trt_segments
-          << ". Was expecting at most 2 Torch segments and 1 TRT segment.");
+      TORCHTRT_THROW_ERROR(
+          "Full compilation was requested but unable to convert all operations to TensorRT."
+          << " Try recompiling with require_full_compilation=False.");
     }
   }
 
@@ -224,7 +220,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
     ir::CollectionTypeMap& first_use_type_map,
-    bool expect_full_compilation = false) {
+    bool requires_collection_handling = false) {
   cfg.convert_info.collection_input_spec_map =
       std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
   cfg.partitioning_info.collection_input_spec_map =
@@ -259,7 +255,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
             "Cannot infer input type from calcuations in graph for input "
             << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
         spec[i].dtype = at::kFloat;
-      } else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info.enabled || expect_full_compilation)) {
+      } else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info.enabled || requires_collection_handling)) {
         if (!est_type_opt[i]) {
           LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
           std::stringstream ss;
@@ -352,10 +348,10 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // whether full compilation can be expected
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
   auto outputIsCollection = conversion::OutputIsCollection(g->block());
-  auto nearly_full_compilation = (isBlockConvertible && outputIsCollection);
+  auto requires_collection_handling = (isBlockConvertible && outputIsCollection);
 
   // Extract map of IValue to DType
-  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types, nearly_full_compilation);
+  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types, requires_collection_handling);
 
   // Check whether any of the input types are Long
   bool user_requested_long = false;
@@ -377,13 +373,14 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   }
 
   if ((cfg.partitioning_info.enabled &&
-       (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
-          cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
-        outputIsCollection || user_requested_long)) ||
-      nearly_full_compilation) {
+       (cfg.lower_info.forced_fallback_modules.size() != 0 ||
+        cfg.partitioning_info.forced_fallback_operators.size() != 0 || !isBlockConvertible || outputIsCollection ||
+        user_requested_long)) ||
+      requires_collection_handling) {
     // If the model is fully-compilable and the user has specified full compilation, run partitioning
     // to generate collection-processing code in Torch
-    auto expect_full_compilation = (nearly_full_compilation && !cfg.partitioning_info.enabled);
+    auto expect_full_compilation = (requires_collection_handling && !cfg.partitioning_info.enabled);
+
     auto graph_and_mapping =
         BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types, expect_full_compilation);
     new_g = graph_and_mapping.first;
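
Note on the changes above: both full-compilation diagnostics are promoted from LOG_ERROR to TORCHTRT_THROW_ERROR, so an unsatisfiable require_full_compilation=True now aborts compilation instead of logging and continuing with Torch fallback. A hedged caller-side sketch (the C++ error typically surfaces in Python as a RuntimeError, though the exact type depends on the binding layer):

try:
    trt_mod = torchtrt.ts.compile(scripted_mod, inputs=inputs, require_full_compilation=True)
except RuntimeError:
    # Retry with the remedy the new error message suggests
    trt_mod = torchtrt.ts.compile(scripted_mod, inputs=inputs, require_full_compilation=False)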

core/partitioning/partitioning.cpp

Lines changed: 15 additions & 1 deletion
@@ -564,7 +564,21 @@ void populateInputIValues(PartitioningCtx* ctx) {
   }
 }
 
-void partition(PartitioningCtx* ctx) {
+void partition(PartitioningCtx* ctx, bool expect_full_compilation) {
+  // If full compilation is expected, overwrite minimum block size
+  // Any nonzero block size is valid if full compilation to TRT is desired
+  // Override the default min_block_size to ensure all TRT-supported operations are
+  // executed in TRT, regardless of the size of the graph
+  if (expect_full_compilation) {
+    // If minimum block size is different from the default, the user must have specified it
+    if (ctx->settings.min_block_size != 3) {
+      LOG_WARNING(
+          "Detected user-specified min_block_size with require_full_compilation=True "
+          << "disregarding min_block_size.");
+    }
+    ctx->settings.min_block_size = 1;
+  }
+
   LOG_DEBUG(ctx->settings);
 
   // Go through all the blocks to do the partitioning
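
The min_block_size override now lives at the top of partition() rather than in BuildHybridGraph, and it warns when it discards a user-specified value. Note the detection heuristic: any value other than the compiled-in default of 3 is treated as user-specified, so an explicitly passed min_block_size=3 is overridden silently. A sketch of a call that triggers the new warning (min_block_size=5 is an arbitrary non-default value):

trt_mod = torchtrt.ts.compile(
    scripted_mod,
    inputs=inputs,
    min_block_size=5,  # warned about and disregarded; partitioner uses 1
    require_full_compilation=True,
)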

core/partitioning/partitioning.h

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block);
 
 GraphAndMapping stitch(PartitioningCtx* ctx, torch::jit::Block* block);
 
-void partition(PartitioningCtx* ctx);
+void partition(PartitioningCtx* ctx, bool expect_full_compilation = false);
 
 } // namespace partitioning
 } // namespace core

tests/py/api/test_e2e_behavior.py

Lines changed: 31 additions & 0 deletions
@@ -146,6 +146,37 @@ def forward(self, x, y, z):
             trt_output, torch_output
         ), "Found differing output formatting between Torch-TRT and Torch"
 
+    def test_tuple_output_with_full_compilation(self):
+        class Sample(torch.nn.Module):
+            def __init__(self):
+                super(Sample, self).__init__()
+
+            def forward(self, x, y):
+                a = x + y
+                return (a,)
+
+        self.model = Sample().eval().to("cuda")
+        self.input_1 = torch.zeros((5, 5), dtype=torch.float, device="cuda:0")
+        self.input_2 = torch.ones((5, 5), dtype=torch.float, device="cuda:0")
+        scripted_mod = torch.jit.script(self.model)
+
+        inputs = [
+            torchtrt.Input((5, 5), dtype=torch.float),
+            torchtrt.Input((5, 5), dtype=torch.float),
+        ]
+
+        trt_mod = torchtrt.ts.compile(
+            scripted_mod,
+            inputs=inputs,
+            require_full_compilation=True,
+            enabled_precisions={torch.float, torch.half},
+        )
+        trt_output = trt_mod(self.input_1, self.input_2)
+        torch_output = self.model(self.input_1, self.input_2)
+        assert same_output_format(
+            trt_output, torch_output
+        ), "Found differing output formatting between Torch-TRT and Torch"
+
 
 if __name__ == "__main__":
     unittest.main()
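
The new test mirrors the surrounding collection-output tests: it appears to assert only structural parity of the outputs via same_output_format, not numerical closeness, and it exercises exactly the single-op scenario that previously fell below the default min_block_size. Given a CUDA-capable environment with torch_tensorrt installed, it can be run directly through the file's unittest entry point: python tests/py/api/test_e2e_behavior.py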
