Skip to content

Commit 5f325ec

Browse files
committed
Merge branch 'fix_loop_fallback' into monai_segresnet
2 parents 32f0dac + 6f3ec49 commit 5f325ec

File tree

1 file changed

+18
-30
lines changed

1 file changed

+18
-30
lines changed

core/partitioning/partitioning.cpp

+18-30
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,21 @@ void setInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* blo
4747
}
4848
}
4949

50+
// Need to check if this makes sense might be a root cause of some issues of over aggressive fallback
51+
bool checkLoopEvaluatable(torch::jit::Node* n) {
52+
bool compile_to_trt = true;
53+
for (auto bn : n->blocks()[0]->nodes()) {
54+
if (bn->kind() == torch::jit::prim::Loop) {
55+
compile_to_trt = compile_to_trt && checkLoopEvaluatable(bn);
56+
} else if (bn->kind() == torch::jit::prim::If) {
57+
compile_to_trt = compile_to_trt && containNonTensorOutputs(bn);
58+
} else {
59+
compile_to_trt = compile_to_trt && core::conversion::evaluators::shouldEvalAtConversionTime(bn);
60+
}
61+
}
62+
return compile_to_trt;
63+
}
64+
5065
// Find and set all explicit fallback nodes (nodes that are unsupported or forced fallback)
5166
// we use a map to indicate the reason why it's fallback to torch
5267
// For any node that's not explicitly fallback, we set it to run in TensorRT for now
@@ -59,7 +74,9 @@ void setExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
5974
continue;
6075
}
6176

62-
if (!conversion::OpSupported(n)) {
77+
if (n->kind() == torch::jit::prim::Loop && checkLoopEvaluatable(n)) {
78+
ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kCONVERT);
79+
} else if (!conversion::OpSupported(n)) {
6380
// If the op is not supported by the conversion phase it should run in PyTorch
6481
ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kUNSUPPORTED);
6582
} else if (ctx->forced_fallback_ops.find(n->kind().toQualString()) != ctx->forced_fallback_ops.end()) {
@@ -336,21 +353,6 @@ void registerSegmentsOutputs(PartitioningCtx* ctx, torch::jit::Block* block) {
336353
return;
337354
}
338355

339-
// Need to check if this makes sense might be a root cause of some issues of over aggressive fallback
340-
bool checkLoopEvaluatable(torch::jit::Node* n) {
341-
bool compile_to_trt = true;
342-
for (auto bn : n->blocks()[0]->nodes()) {
343-
if (bn->kind() == torch::jit::prim::Loop) {
344-
compile_to_trt = compile_to_trt && checkLoopEvaluatable(bn);
345-
} else if (bn->kind() == torch::jit::prim::If) {
346-
compile_to_trt = compile_to_trt && containNonTensorOutputs(bn);
347-
} else {
348-
compile_to_trt = compile_to_trt && core::conversion::evaluators::shouldEvalAtConversionTime(bn);
349-
}
350-
}
351-
return compile_to_trt;
352-
}
353-
354356
void finalizeNewBlock(
355357
PartitionedGraph& g,
356358
SegmentedBlock::SegmentedBlockTarget kind,
@@ -499,20 +501,6 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
499501
finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, cond_node);
500502
segmented_blocks.back().do_not_merge(true);
501503
continue;
502-
} else if (n->kind() == torch::jit::prim::Loop) {
503-
if (!in_prog_pyt_blk_nodes.empty()) {
504-
finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
505-
cur_pyt_nodes_uses.clear();
506-
}
507-
if (checkLoopEvaluatable(n)) {
508-
in_prog_trt_blk_nodes.push_back(n);
509-
cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
510-
} else {
511-
auto loop_node = std::vector<torch::jit::Node*>{n};
512-
finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, loop_node);
513-
segmented_blocks.back().do_not_merge(true);
514-
}
515-
continue;
516504
}
517505
in_prog_pyt_blk_nodes.push_back(n);
518506
cur_pyt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());

0 commit comments

Comments
 (0)