@@ -143,19 +143,14 @@ partitioning::GraphAndMapping BuildHybridGraph(
   auto convert_info = cfg.convert_info;
   auto partitioning_info = cfg.partitioning_info;
 
-  // Any nonzero block size is valid if full compilation to TRT is desired
-  if (expect_full_compilation) {
-    partitioning_info.min_block_size = 1;
-  }
-
   auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
   partitioning_ctx.input_types_map = first_use_types;
 
   // Generate a dictionary of input torch::jit::Value's to their min, opt, max tensors and store in ctx
   // TODO: Combine this within partition call
   partitioning::populateInputIValues(&partitioning_ctx);
 
-  partitioning::partition(&partitioning_ctx);
+  partitioning::partition(&partitioning_ctx, expect_full_compilation);
 
   for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
     partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
@@ -197,9 +192,11 @@ partitioning::GraphAndMapping BuildHybridGraph(
         if (expect_full_compilation) {
           for (auto torch_node : seg_block.block()->nodes()) {
             if (partitioning::CollectionNodeKinds.find(torch_node->kind()) == partitioning::CollectionNodeKinds.end()) {
-              LOG_ERROR(
-                  "Full compilation specified but node " << torch_node->kind().toQualString()
-                                                         << " was executed in Torch.");
+              TORCHTRT_THROW_ERROR(
+                  "Full compilation specified but node "
+                  << *torch_node
+                  << " is set to run in PyTorch due to either lack of support in TensorRT or graph partitioning rules."
+                  << " Try recompiling with require_full_compilation=False.");
             }
           }
         }
@@ -209,10 +206,9 @@ partitioning::GraphAndMapping BuildHybridGraph(
     // If full compilation is expected, cannot have more than 2 Torch segments
     // (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
     if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1)) {
-      LOG_ERROR(
-          "Full compilation specified but number of torch segments was "
-          << num_torch_segments << " and number of trt segments was " << num_trt_segments
-          << ". Was expecting at most 2 Torch segments and 1 TRT segment.");
+      TORCHTRT_THROW_ERROR(
+          "Full compilation was requested but unable to convert all operations to TensorRT."
+          << " Try recompiling with require_full_compilation=False.");
     }
   }
 
@@ -224,7 +220,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
     ir::CollectionTypeMap& first_use_type_map,
-    bool expect_full_compilation = false) {
+    bool requires_collection_handling = false) {
   cfg.convert_info.collection_input_spec_map =
       std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
   cfg.partitioning_info.collection_input_spec_map =
@@ -259,7 +255,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
259255 " Cannot infer input type from calcuations in graph for input "
260256 << in->debugName () << " . Assuming it is Float32. If not, specify input type explicity" );
261257 spec[i].dtype = at::kFloat ;
262- } else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info .enabled || expect_full_compilation )) {
258+ } else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info .enabled || requires_collection_handling )) {
263259 if (!est_type_opt[i]) {
264260 LOG_INFO (" Cannot infer input tensor dtype in graph, compiler is going to use the user setting" );
265261 std::stringstream ss;
@@ -352,10 +348,10 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // whether full compilation can be expected
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
   auto outputIsCollection = conversion::OutputIsCollection(g->block());
-  auto nearly_full_compilation = (isBlockConvertible && outputIsCollection);
+  auto requires_collection_handling = (isBlockConvertible && outputIsCollection);
 
   // Extract map of IValue to DType
-  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types, nearly_full_compilation);
+  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types, requires_collection_handling);
 
   // Check whether any of the input types are Long
   bool user_requested_long = false;
@@ -380,10 +376,11 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
       (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
          cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
        outputIsCollection || user_requested_long)) ||
-      nearly_full_compilation) {
+      requires_collection_handling) {
     // If the model is fully-compilable and the user has specified full compilation, run partitioning
     // to generate collection-processing code in Torch
-    auto expect_full_compilation = (nearly_full_compilation && !cfg.partitioning_info.enabled);
+    auto expect_full_compilation = (requires_collection_handling && !cfg.partitioning_info.enabled);
+
     auto graph_and_mapping =
         BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types, expect_full_compilation);
     new_g = graph_and_mapping.first;
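Note: the first hunk removes the min_block_size override from BuildHybridGraph and instead forwards expect_full_compilation to partitioning::partition. The standalone sketch below illustrates the idea of letting the callee relax the block-size threshold itself; the struct layout and field names (PartitioningInfo, PartitioningCtx::settings) are simplified stand-ins for illustration only, not the actual Torch-TensorRT definitions.

#include <cstdint>

// Hypothetical, simplified stand-ins for the real partitioning types.
struct PartitioningInfo {
  uint64_t min_block_size = 3; // minimum number of ops required to form a TensorRT segment
};

struct PartitioningCtx {
  PartitioningInfo settings;
};

// Sketch: partition() honors the flag internally, so callers such as
// BuildHybridGraph no longer need to mutate min_block_size before calling it.
void partition(PartitioningCtx* ctx, bool expect_full_compilation = false) {
  if (expect_full_compilation) {
    // Any nonzero block size is valid if full compilation to TRT is desired.
    ctx->settings.min_block_size = 1;
  }
  // ... graph segmentation would proceed here, driven by ctx->settings ...
}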