fix: Allow full model compilation with collection outputs #1599
@@ -138,7 +138,8 @@ partitioning::GraphAndMapping BuildHybridGraph(
    torch::jit::Block* block,
    CompileSpec cfg,
    ir::StaticParams static_params,
-    ir::CollectionTypeMap first_use_types) {
+    ir::CollectionTypeMap first_use_types,
+    bool expect_full_compilation = false) {
  auto convert_info = cfg.convert_info;
  auto partitioning_info = cfg.partitioning_info;

@@ -149,17 +150,20 @@ partitioning::GraphAndMapping BuildHybridGraph(
  // TODO: Combine this within partition call
  partitioning::populateInputIValues(&partitioning_ctx);

-  partitioning::partition(&partitioning_ctx);
+  partitioning::partition(&partitioning_ctx, expect_full_compilation);

  for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
    partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
+    int num_torch_segments = 0;
+    int num_trt_segments = 0;

    for (auto& seg_block : segmented_blocks) {
      LOG_INFO("Block segment:" << seg_block);
      std::ostringstream trt_engine_id;
      trt_engine_id << reinterpret_cast<const int*>(&seg_block);

      if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+        num_trt_segments++;
        auto inputs = seg_block.construct_inputs_spec();
        // update the input ranges for each segments
        convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);

@@ -180,8 +184,32 @@ partitioning::GraphAndMapping BuildHybridGraph(
            true);

        seg_block.update_graph(temp_g);
+      } else {
+        num_torch_segments++;
+
+        // If full compilation is expected, ensure that all operators in Torch blocks are
+        // for collections processing
+        if (expect_full_compilation) {
+          for (auto torch_node : seg_block.block()->nodes()) {
+            if (partitioning::CollectionNodeKinds.find(torch_node->kind()) == partitioning::CollectionNodeKinds.end()) {
+              TORCHTRT_THROW_ERROR(
+                  "Full compilation specified but node "
+                  << *torch_node
+                  << " is set to run in PyTorch due to either lack of support in TensorRT or graph partitioning rules."
+                  << " Try recompiling with require_full_compilation=False.");
+            }
+          }
+        }
      }
    }
+
+    // If full compilation is expected, cannot have more than 2 Torch segments
+    // (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
+    if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1)) {
+      TORCHTRT_THROW_ERROR(
+          "Full compilation was requested but unable to convert all operations to TensorRT."
+          << " Try recompiling with require_full_compilation=False.");
+    }
  }

  return partitioning::stitch(&partitioning_ctx, block);

@@ -191,7 +219,8 @@ ir::TypeMap MapInputsAndDetermineDTypes(
    CompileSpec& cfg,
    std::shared_ptr<torch::jit::Graph>& g,
    ir::StaticParams& static_params,
-    ir::CollectionTypeMap& first_use_type_map) {
+    ir::CollectionTypeMap& first_use_type_map,
+    bool requires_collection_handling = false) {
  cfg.convert_info.collection_input_spec_map =
      std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
  cfg.partitioning_info.collection_input_spec_map =

@@ -226,7 +255,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
            "Cannot infer input type from calcuations in graph for input "
            << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
        spec[i].dtype = at::kFloat;
-      } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
+      } else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info.enabled || requires_collection_handling)) {
        if (!est_type_opt[i]) {
          LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
          std::stringstream ss;

@@ -297,6 +326,11 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
  return engine;
}

+bool userRequestedFallback(CompileSpec& cfg) {
+  return cfg.lower_info.forced_fallback_modules.size() != 0 ||
+      cfg.partitioning_info.forced_fallback_operators.size() != 0;
+}
+

Comment on lines +329 to +332: Added helper function to determine if the user's input specifications imply fallback.

torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) {
  torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");

@@ -315,8 +349,17 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
  // Infer the type of an input from the weights of the calculation
  auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());

+  // Determine if the block is convertible/has collection output, and based on the result,
+  // whether full compilation can be expected
+  auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
+  auto outputIsCollection = conversion::OutputIsCollection(g->block());
+  auto requires_collection_handling = (isBlockConvertible && outputIsCollection);
+
+  // Determine whether user specifications necessitate partitioning
+  auto isFallbackRequested = userRequestedFallback(cfg);
+
  // Extract map of IValue to DType
-  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types, requires_collection_handling);

  // Check whether any of the input types are Long
  bool user_requested_long = false;

@@ -330,20 +373,28 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
    user_requested_long &= (casts_inserted > 0);
  }

-  auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
-  auto outputIsCollection = conversion::OutputIsCollection(g->block());
-  if (cfg.partitioning_info.enabled && !user_requested_long &&
-      (cfg.lower_info.forced_fallback_modules.size() == 0 &&
-       cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
-      !outputIsCollection) {
+  // Partitioning is required if:
+  // 1. User requested some modules/operators fallback
+  // 2. The block (graph) cannot be converted due to operator coverage
+  // 3. The output of the graph is a collection
+  // 4. The user requested a non-TRT data type input
+  auto isPartitioningRequired =
+      (isFallbackRequested || !isBlockConvertible || outputIsCollection || user_requested_long);

Comment on lines +376 to +382: Coalesced partitioning logic for readability.

+  // The user did not require full compilation, but the model can be fully compiled
+  if (cfg.partitioning_info.enabled && !isPartitioningRequired) {
    LOG_INFO("Skipping partitioning since model is fully supported");
  }

-  if (cfg.partitioning_info.enabled &&
-      (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
-         cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
-       outputIsCollection || user_requested_long)) {
-    auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
+  // The user did not require full compilation, and the model can be fully compiled
+  // or, the user required full compilation but the I/O of the graph use collections
+  if ((cfg.partitioning_info.enabled && isPartitioningRequired) || requires_collection_handling) {
+    // If the model is fully-compilable and the user has specified full compilation, run partitioning
+    // to generate collection-processing code in Torch
+    auto expect_full_compilation = (requires_collection_handling && !cfg.partitioning_info.enabled);
+
+    auto graph_and_mapping =
+        BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types, expect_full_compilation);
    new_g = graph_and_mapping.first;
    // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
    for (size_t i = 0; i < new_g->inputs().size(); ++i) {

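For context on the new requires_collection_handling path above: conversion::OutputIsCollection reports whether the graph returns a collection (tuple/list/dict) rather than plain tensors, which is what previously forced partitioning even under require_full_compilation=True. A minimal sketch of such a check over the TorchScript IR, not necessarily the exact implementation in the repository, could look like:

#include <torch/csrc/jit/ir/ir.h>

// Sketch: true if any block output is a collection type (tuple/list/dict).
// Models returning e.g. a tuple of two tensors fall into this category.
bool block_output_is_collection(const torch::jit::Block* b) {
  for (auto out : b->outputs()) {
    auto kind = out->type()->kind();
    if (kind == c10::TypeKind::TupleType || kind == c10::TypeKind::ListType ||
        kind == c10::TypeKind::DictType) {
      return true;
    }
  }
  return false;
}

When such a check returns true and every operator in the block is convertible, the new code path requests partitioning solely to generate the Torch code that packs and unpacks the collection.
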
@@ -564,7 +564,21 @@ void populateInputIValues(PartitioningCtx* ctx) {
  }
}

-void partition(PartitioningCtx* ctx) {
+void partition(PartitioningCtx* ctx, bool expect_full_compilation) {
+  // If full compilation is expected, overwrite minimum block size
+  // Any nonzero block size is valid if full compilation to TRT is desired
+  // Override the default min_block_size to ensure all TRT-supported operations are
+  // executed in TRT, regardless of the size of the graph
+  if (expect_full_compilation) {
+    // If minimum block size is different from the default, the user must have specified it
+    if (ctx->settings.min_block_size != 3) {

Comment: Create an issue to centralize defaults somewhere in the core.

Comment: What if a user sets min_block_size to 3, the current default? It looks like they would not get this warning.

Reply: No, the user would not get a warning message in that case. We currently don't have a way of knowing whether the user inputs a value or not, since the defaults are not centralized. There is an issue #1644 to address this, but as of now, your statement is correct. Additionally, it is worth noting that prior to this PR, if a user specified ...

+      LOG_WARNING(
+          "Detected user-specified min_block_size with require_full_compilation=True "
+          << "disregarding min_block_size.");
+    }
+    ctx->settings.min_block_size = 1;
+  }
+
  LOG_DEBUG(ctx->settings);

  // Go through all the blocks to do the partitioning

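Regarding the default-detection problem raised in the thread above: the min_block_size != 3 test cannot distinguish a user who explicitly passed 3 from one who never set the option. A small sketch of the direction issue #1644 points toward, using hypothetical names rather than the repository's actual API, where the setting stays unset until resolved against a single centralized default:

#include <cstdint>
#include <iostream>
#include <optional>

// Hypothetical settings struct: min_block_size stays unset unless the user provides it.
struct PartitioningSettingsSketch {
  std::optional<uint64_t> min_block_size; // std::nullopt => "use the default"
};

constexpr uint64_t kDefaultMinBlockSize = 3; // single, centralized default

uint64_t effective_min_block_size(const PartitioningSettingsSketch& s, bool expect_full_compilation) {
  if (expect_full_compilation) {
    // Warn only when the user genuinely overrode the value, which the optional makes detectable.
    if (s.min_block_size.has_value() && *s.min_block_size != 1) {
      std::cerr << "warning: min_block_size is disregarded when full compilation is required\n";
    }
    return 1; // every TRT-supported block, however small, must be converted
  }
  return s.min_block_size.value_or(kDefaultMinBlockSize);
}

With settings shaped like this, the warning in the diff above would fire only for genuine user overrides, regardless of whether the chosen value happens to equal the default.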
Comment: Do we have edge cases like 2 torch_segments for inputs/outputs? Does merge_adjacent_segments_of_same_type always merge them into one?

Reply: There should not be a case where multiple Torch segments appear for inputs/outputs, since merge_adjacent_segments_of_same_type addresses this case, as you had mentioned. Since the tensors in question are inputs, it should not arise that segment.do_not_merge() is True, since the only approved operators falling into these segments are for collection construction, and only the prim::If or prim::Loop operators can induce a non-merge situation.
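To illustrate the merging behavior referenced in this thread, here is a conceptual sketch of merging adjacent segments that share a target. The types and function below are hypothetical stand-ins, not the actual merge_adjacent_segments_of_same_type implementation in core/partitioning:

#include <string>
#include <vector>

// Hypothetical segment representation for illustration only.
enum class Target { kTorch, kTensorRT };

struct SegmentSketch {
  Target target;
  std::vector<std::string> node_names; // stand-in for the segment's nodes
  bool do_not_merge = false;           // e.g., segments holding prim::If / prim::Loop
};

// Collapse consecutive segments with the same target into one, unless either side is marked non-mergeable.
std::vector<SegmentSketch> merge_adjacent_same_target(const std::vector<SegmentSketch>& segments) {
  std::vector<SegmentSketch> merged;
  for (const auto& seg : segments) {
    if (!merged.empty() && merged.back().target == seg.target && !merged.back().do_not_merge &&
        !seg.do_not_merge) {
      // Fold this segment's nodes into the previous segment of the same target
      auto& prev = merged.back();
      prev.node_names.insert(prev.node_names.end(), seg.node_names.begin(), seg.node_names.end());
    } else {
      merged.push_back(seg);
    }
  }
  return merged;
}

Under this view, the collection-construction nodes left in Torch never set do_not_merge, so consecutive Torch segments for input unpacking or output packing collapse into single segments, which is consistent with the num_torch_segments <= 2 bound enforced in BuildHybridGraph.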