@@ -143,19 +143,14 @@ partitioning::GraphAndMapping BuildHybridGraph(
143
143
auto convert_info = cfg.convert_info ;
144
144
auto partitioning_info = cfg.partitioning_info ;
145
145
146
- // Any nonzero block size is valid if full compilation to TRT is desired
147
- if (expect_full_compilation) {
148
- partitioning_info.min_block_size = 1 ;
149
- }
150
-
151
146
auto partitioning_ctx = partitioning::PartitioningCtx (block, partitioning_info);
152
147
partitioning_ctx.input_types_map = first_use_types;
153
148
154
149
// Generate a dictionary of input torch::jit::Value's to their min, opt, max tensors and store in ctx
155
150
// TODO: Combine this within partition call
156
151
partitioning::populateInputIValues (&partitioning_ctx);
157
152
158
- partitioning::partition (&partitioning_ctx);
153
+ partitioning::partition (&partitioning_ctx, expect_full_compilation );
159
154
160
155
for (auto & partitioned_block : partitioning_ctx.partitioned_blocks ) {
161
156
partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second ;
@@ -197,9 +192,11 @@ partitioning::GraphAndMapping BuildHybridGraph(
197
192
if (expect_full_compilation) {
198
193
for (auto torch_node : seg_block.block ()->nodes ()) {
199
194
if (partitioning::CollectionNodeKinds.find (torch_node->kind ()) == partitioning::CollectionNodeKinds.end ()) {
200
- LOG_ERROR (
201
- " Full compilation specified but node " << torch_node->kind ().toQualString ()
202
- << " was executed in Torch." );
195
+ TORCHTRT_THROW_ERROR (
196
+ " Full compilation specified but node "
197
+ << *torch_node
198
+ << " is set to run in PyTorch due to either lack of support in TensorRT or graph partitioning rules."
199
+ << " Try recompiling with require_full_compilation=False." );
203
200
}
204
201
}
205
202
}
@@ -209,10 +206,9 @@ partitioning::GraphAndMapping BuildHybridGraph(
209
206
// If full compilation is expected, cannot have more than 2 Torch segments
210
207
// (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
211
208
if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1 )) {
212
- LOG_ERROR (
213
- " Full compilation specified but number of torch segments was "
214
- << num_torch_segments << " and number of trt segments was " << num_trt_segments
215
- << " . Was expecting at most 2 Torch segments and 1 TRT segment." );
209
+ TORCHTRT_THROW_ERROR (
210
+ " Full compilation was requested but unable to convert all operations to TensorRT."
211
+ << " Try recompiling with require_full_compilation=False." );
216
212
}
217
213
}
218
214
@@ -224,7 +220,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
224
220
std::shared_ptr<torch::jit::Graph>& g,
225
221
ir::StaticParams& static_params,
226
222
ir::CollectionTypeMap& first_use_type_map,
227
- bool expect_full_compilation = false ) {
223
+ bool requires_collection_handling = false ) {
228
224
cfg.convert_info .collection_input_spec_map =
229
225
std::move (ir::associate_specs_with_collection_inputs (g, cfg.graph_inputs , static_params));
230
226
cfg.partitioning_info .collection_input_spec_map =
@@ -259,7 +255,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
259
255
" Cannot infer input type from calcuations in graph for input "
260
256
<< in->debugName () << " . Assuming it is Float32. If not, specify input type explicity" );
261
257
spec[i].dtype = at::kFloat ;
262
- } else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info .enabled || expect_full_compilation )) {
258
+ } else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info .enabled || requires_collection_handling )) {
263
259
if (!est_type_opt[i]) {
264
260
LOG_INFO (" Cannot infer input tensor dtype in graph, compiler is going to use the user setting" );
265
261
std::stringstream ss;
@@ -352,10 +348,10 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
352
348
// whether full compilation can be expected
353
349
auto isBlockConvertible = conversion::VerifyConverterSupportForBlock (g->block (), true );
354
350
auto outputIsCollection = conversion::OutputIsCollection (g->block ());
355
- auto nearly_full_compilation = (isBlockConvertible && outputIsCollection);
351
+ auto requires_collection_handling = (isBlockConvertible && outputIsCollection);
356
352
357
353
// Extract map of IValue to DType
358
- auto type_map = MapInputsAndDetermineDTypes (cfg, g, static_params, first_use_types, nearly_full_compilation );
354
+ auto type_map = MapInputsAndDetermineDTypes (cfg, g, static_params, first_use_types, requires_collection_handling );
359
355
360
356
// Check whether any of the input types are Long
361
357
bool user_requested_long = false ;
@@ -377,13 +373,14 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
377
373
}
378
374
379
375
if ((cfg.partitioning_info .enabled &&
380
- (!( cfg.lower_info .forced_fallback_modules .size () == 0 &&
381
- cfg.partitioning_info .forced_fallback_operators .size () == 0 && isBlockConvertible) ||
382
- outputIsCollection || user_requested_long)) ||
383
- nearly_full_compilation ) {
376
+ (cfg.lower_info .forced_fallback_modules .size () != 0 ||
377
+ cfg.partitioning_info .forced_fallback_operators .size () != 0 || ! isBlockConvertible || outputIsCollection ||
378
+ user_requested_long)) ||
379
+ requires_collection_handling ) {
384
380
// If the model is fully-compilable and the user has specified full compilation, run partitioning
385
381
// to generate collection-processing code in Torch
386
- auto expect_full_compilation = (nearly_full_compilation && !cfg.partitioning_info .enabled );
382
+ auto expect_full_compilation = (requires_collection_handling && !cfg.partitioning_info .enabled );
383
+
387
384
auto graph_and_mapping =
388
385
BuildHybridGraph (new_mod, g->block (), cfg, static_params, first_use_types, expect_full_compilation);
389
386
new_g = graph_and_mapping.first ;
0 commit comments