@@ -47,6 +47,21 @@ void setInputsOutputsConnectedNodes(PartitioningCtx* ctx, torch::jit::Block* blo
   }
 }
 
+// Need to check if this makes sense; it might be a root cause of some over-aggressive fallback issues
+bool checkLoopEvaluatable(torch::jit::Node* n) {
+  bool compile_to_trt = true;
+  for (auto bn : n->blocks()[0]->nodes()) {
+    if (bn->kind() == torch::jit::prim::Loop) {
+      compile_to_trt = compile_to_trt && checkLoopEvaluatable(bn);
+    } else if (bn->kind() == torch::jit::prim::If) {
+      compile_to_trt = compile_to_trt && containNonTensorOutputs(bn);
+    } else {
+      compile_to_trt = compile_to_trt && core::conversion::evaluators::shouldEvalAtConversionTime(bn);
+    }
+  }
+  return compile_to_trt;
+}
+
 // Find and set all explicit fallback nodes (nodes that are unsupported or forced fallback)
 // we use a map to indicate the reason why it's fallback to torch
 // For any node that's not explicitly fallback, we set it to run in TensorRT for now
@@ -59,7 +74,9 @@ void setExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
       continue;
     }
 
-    if (!conversion::OpSupported(n)) {
+    if (n->kind() == torch::jit::prim::Loop && checkLoopEvaluatable(n)) {
+      ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kCONVERT);
+    } else if (!conversion::OpSupported(n)) {
       // If the op is not supported by the conversion phase it should run in PyTorch
       ctx->setNodeExecutorDecision(n, NodeExecutorDecision::kUNSUPPORTED);
     } else if (ctx->forced_fallback_ops.find(n->kind().toQualString()) != ctx->forced_fallback_ops.end()) {
@@ -336,21 +353,6 @@ void registerSegmentsOutputs(PartitioningCtx* ctx, torch::jit::Block* block) {
   return;
 }
 
-// Need to check if this makes sense might be a root cause of some issues of over aggressive fallback
-bool checkLoopEvaluatable(torch::jit::Node* n) {
-  bool compile_to_trt = true;
-  for (auto bn : n->blocks()[0]->nodes()) {
-    if (bn->kind() == torch::jit::prim::Loop) {
-      compile_to_trt = compile_to_trt && checkLoopEvaluatable(bn);
-    } else if (bn->kind() == torch::jit::prim::If) {
-      compile_to_trt = compile_to_trt && containNonTensorOutputs(bn);
-    } else {
-      compile_to_trt = compile_to_trt && core::conversion::evaluators::shouldEvalAtConversionTime(bn);
-    }
-  }
-  return compile_to_trt;
-}
-
 void finalizeNewBlock(
     PartitionedGraph& g,
     SegmentedBlock::SegmentedBlockTarget kind,
@@ -499,20 +501,6 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
       finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, cond_node);
       segmented_blocks.back().do_not_merge(true);
       continue;
-    } else if (n->kind() == torch::jit::prim::Loop) {
-      if (!in_prog_pyt_blk_nodes.empty()) {
-        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
-        cur_pyt_nodes_uses.clear();
-      }
-      if (checkLoopEvaluatable(n)) {
-        in_prog_trt_blk_nodes.push_back(n);
-        cur_trt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());
-      } else {
-        auto loop_node = std::vector<torch::jit::Node*>{n};
-        finalizeNewBlock(segmented_blocks, SegmentedBlock::kTorch, loop_node);
-        segmented_blocks.back().do_not_merge(true);
-      }
-      continue;
     }
     in_prog_pyt_blk_nodes.push_back(n);
     cur_pyt_nodes_uses.insert(dependent_nodes.begin(), dependent_nodes.end());