Skip to content

Commit d4c9c06

Browse files
authored
fix: Replace EliminateExceptions lowering pass (#1859)
1 parent 7810df6 commit d4c9c06

File tree

4 files changed

+286
-2
lines changed

4 files changed

+286
-2
lines changed

core/lowering/lowering.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include "torch/csrc/jit/passes/lower_graph.h"
1010
#include "torch/csrc/jit/passes/lower_tuples.h"
1111
#include "torch/csrc/jit/passes/peephole.h"
12-
#include "torch/csrc/jit/passes/remove_exceptions.h"
1312
#include "torch/csrc/jit/passes/remove_mutation.h"
1413

1514
#include "core/lowering/lowering.h"
@@ -105,7 +104,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
105104
torch::jit::InlineFunctionalGraphs(g);
106105
torch::jit::PeepholeOptimize(g, false);
107106
torch::jit::FuseLinear(g);
108-
torch::jit::EliminateExceptions(g);
107+
passes::EliminateExceptionsSafe(g);
109108
if (!lower_info.disable_cse) {
110109
torch::jit::EliminateCommonSubexpression(g);
111110
}

core/lowering/passes/exception_elimination.cpp

+66
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "torch/csrc/jit/ir/alias_analysis.h"
22
#include "torch/csrc/jit/jit_log.h"
3+
#include "torch/csrc/jit/passes/constant_pooling.h"
34
#include "torch/csrc/jit/passes/constant_propagation.h"
45
#include "torch/csrc/jit/passes/dead_code_elimination.h"
56
#include "torch/csrc/jit/passes/guard_elimination.h"
@@ -108,6 +109,71 @@ void EliminateExceptionOrPassPattern(std::shared_ptr<Graph> graph) {
108109
}
109110
}
110111

/*
Below is a fork of the torch::jit::EliminateExceptions pass, with node replacement
using replaceAllUsesDominatedByNodeWith instead of replaceAllUsesWith,
so as to not invalidate the IR in challenging cases, such as nested Ifs

Original Source from which it was adapted:
https://github.com/pytorch/pytorch/blob/c29ab84115f40614d04e4557ea2e1ac40b7aa75c/torch/csrc/jit/passes/remove_exceptions.cpp
*/
// Returns true when the given block is guaranteed to raise: i.e. it
// directly contains a prim::RaiseException node. Nested subblocks are
// deliberately not inspected here.
bool certainlyThrows(Block* block) {
  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
    if ((*it)->kind() == prim::RaiseException) {
      return true;
    }
  }
  return false;
}
131+
132+
// Recursively rewrites every prim::If node in `block` whose taken branch
// would certainly raise an exception: the If's boolean condition is replaced
// with a constant selecting the non-throwing branch, so downstream passes
// (constant propagation / DCE) can fold the exception path away without
// invalidating the IR.
void EliminateExceptionsSafe(Block* block) {
  auto graph = block->owningGraph();
  // Generate false and true constant placeholders used as replacement
  // conditions for exception-raising If nodes
  Value* false_const = graph->insertConstant(IValue(false));
  Value* true_const = graph->insertConstant(IValue(true));

  // For each prim::If node, if either block certainly throws an exception,
  // replace input conditional of the node input with the logical opposite
  for (Node* n : block->nodes()) {
    if (n->kind() == prim::If) {
      Block* true_block = n->blocks()[0];
      Block* false_block = n->blocks()[1];
      bool removed_exception = false;
      // Initialize to nullptr so an uninitialized pointer can never be
      // inserted as a node input
      Value* input_value_replacement = nullptr;

      // If the block throws an exception, replace input with logical opposite
      if (certainlyThrows(true_block)) {
        removed_exception = true;
        input_value_replacement = false_const;
      } else if (certainlyThrows(false_block)) {
        removed_exception = true;
        input_value_replacement = true_const;
      }

      // Log node and perform input replacement: prepend the constant as the
      // new condition, then drop the original condition input
      if (removed_exception) {
        LOG_WARNING("Detected and removing exception in TorchScript IR for node: " << util::node_info(n));
        n->insertInput(0, input_value_replacement);
        n->removeInput(1);
      }
    }

    // Inspect and replace all instances within subblocks of the current node
    for (Block* subblock : n->blocks()) {
      EliminateExceptionsSafe(subblock);
    }
  }
}
170+
171+
// Graph-level entry point: run the safe exception-elimination rewrite over
// the whole graph, then fold the rewritten conditions and deduplicate the
// constant placeholders the rewrite inserted.
void EliminateExceptionsSafe(std::shared_ptr<Graph>& graph) {
  EliminateExceptionsSafe(graph->block());
  ConstantPropagation(graph);
  ConstantPooling(graph);
}
176+
111177
} // namespace passes
112178
} // namespace lowering
113179
} // namespace core

core/lowering/passes/passes.h

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
2020
void ConvTransposed3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
2121
void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
2222
void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph);
23+
void EliminateExceptionsSafe(std::shared_ptr<torch::jit::Graph>& graph);
2324
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
2425
void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph);
2526
void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph);

tests/core/lowering/test_exception_elimination_pass.cpp

+218
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
#include "core/lowering/passes/passes.h"
22
#include "gtest/gtest.h"
3+
#include "tests/util/util.h"
34
#include "torch/csrc/jit/ir/irparser.h"
5+
#include "torch/csrc/jit/passes/canonicalize.h"
6+
#include "torch/csrc/jit/passes/constant_pooling.h"
7+
#include "torch/csrc/jit/passes/remove_exceptions.h"
48

59
TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block0) {
610
// parseIR does not support " = prim::If(%51)" with no return value
@@ -169,3 +173,217 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Negative) {
169173
}
170174
EXPECT_EQ(1, if_count);
171175
}
176+
177+
TEST(LoweringPasses, EliminateExceptionsSafeIfBlock) {
  // Source graph (for reference only — parseIR cannot express the
  // value-less "= prim::RaiseException(...)" statement):
  /*std::string source_graph = R"IR(
    graph(%x, %y):
      %dim : int = aten::dim(%x)
      %48 : int = prim::Constant[value=2]()
      %66 : bool = aten::eq(%48, %dim)
      %45 : str = prim::Constant[value="EXCEPTION"]()
      %4 : Tensor = prim::If(%66)
        block0():
          = prim::RaiseException(%45)
          -> (%x)
        block1():
          %res = aten::mul(%x, %y)
          -> (%res)
      return (%4))IR";*/

  // Expected graph after the pass removes the throwing true-branch
  std::string target_graph = R"IR(
    graph(%x : Tensor,
          %y : Tensor):
      %6 : Tensor = aten::mul(%x, %y)
      return (%6))IR";

  // Construct graph via manual commands, to avoid IR parsing issues with
  // unassigned variables (such as prim::RaiseException)
  auto g = std::make_shared<torch::jit::Graph>();
  auto x = g->insertInput(0, "x");
  auto y = g->insertInput(1, "y");
  auto none_const_val = g->insertConstant(torch::jit::IValue());
  auto two_const_val = g->insertConstant(torch::jit::IValue(2));
  auto x_dims = g->create(torch::jit::aten::dim, {x}, 1);
  g->appendNode(x_dims);
  x_dims->output()->setType(torch::jit::IntType::get());
  auto eq = g->create(torch::jit::aten::eq, {two_const_val, x_dims->output()}, 1);
  g->appendNode(eq);
  eq->output()->setType(torch::jit::BoolType::get());
  torch::jit::IValue except("EXCEPTION");
  auto except_val = g->insertConstant(except);

  // If node: block0 raises, block1 computes x * y
  auto if_node = g->create(torch::jit::prim::If, {eq->output()}, 1);
  auto if_block0 = if_node->addBlock();
  auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
  if_block0->appendNode(exception_node);
  if_block0->registerOutput(x);

  auto if_block1 = if_node->addBlock();
  auto sum_node = g->create(torch::jit::aten::mul, {x, y}, 1);
  if_block1->appendNode(sum_node);
  if_block1->registerOutput(sum_node->output());

  g->insertNode(if_node);
  g->registerOutput(if_node->output());

  // Apply lowering pass and canonicalization to the graph
  torch_tensorrt::core::lowering::passes::EliminateExceptionsSafe(g);
  g = torch::jit::Canonicalize(g, false);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  torch::jit::ConstantPooling(tg);
  tg = torch::jit::Canonicalize(tg, false);

  // Validate identical graphs after pooling constants and canonicalizing
  ASSERT_TRUE((tg->toString() == g->toString()));
}
242+
243+
TEST(LoweringPasses, EliminateExceptionsSafeElseBlock) {
  // Source graph (for reference only — parseIR cannot express the
  // value-less "= prim::RaiseException(...)" statement):
  /*std::string source_graph = R"IR(
    graph(%x, %y):
      %dim : int = aten::dim(%x)
      %48 : int = prim::Constant[value=2]()
      %66 : bool = aten::eq(%48, %dim)
      %45 : str = prim::Constant[value="EXCEPTION"]()
      %4 : Tensor = prim::If(%66)
        block0():
          %res = aten::matmul(%x, %y)
          -> (%res)
        block1():
          = prim::RaiseException(%45)
          -> (%x)
      return (%4))IR";*/

  // Expected graph after the pass removes the throwing else-branch
  std::string target_graph = R"IR(
    graph(%x : Tensor,
          %y : Tensor):
      %6 : Tensor = aten::matmul(%x, %y)
      return (%6))IR";

  // Construct graph via manual commands, to avoid IR parsing issues with
  // unassigned variables (such as prim::RaiseException)
  auto g = std::make_shared<torch::jit::Graph>();
  auto x = g->insertInput(0, "x");
  auto y = g->insertInput(1, "y");
  auto none_const_val = g->insertConstant(torch::jit::IValue());
  auto two_const_val = g->insertConstant(torch::jit::IValue(2));
  auto x_dims = g->create(torch::jit::aten::dim, {x}, 1);
  g->appendNode(x_dims);
  x_dims->output()->setType(torch::jit::IntType::get());
  auto eq = g->create(torch::jit::aten::eq, {two_const_val, x_dims->output()}, 1);
  g->appendNode(eq);
  eq->output()->setType(torch::jit::BoolType::get());
  torch::jit::IValue except("EXCEPTION");
  auto except_val = g->insertConstant(except);

  // If node: block0 computes matmul(x, y), block1 raises
  auto if_node = g->create(torch::jit::prim::If, {eq->output()}, 1);
  auto if_block0 = if_node->addBlock();
  auto sum_node = g->create(torch::jit::aten::matmul, {x, y}, 1);
  if_block0->appendNode(sum_node);
  if_block0->registerOutput(sum_node->output());

  auto if_block1 = if_node->addBlock();
  auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
  if_block1->appendNode(exception_node);
  if_block1->registerOutput(x);

  g->insertNode(if_node);
  g->registerOutput(if_node->output());

  // Apply lowering pass and canonicalization to the graph
  torch_tensorrt::core::lowering::passes::EliminateExceptionsSafe(g);
  g = torch::jit::Canonicalize(g, false);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  torch::jit::ConstantPooling(tg);
  tg = torch::jit::Canonicalize(tg, false);

  // Validate identical graphs after pooling constants and canonicalizing
  ASSERT_TRUE((tg->toString() == g->toString()));
}
308+
309+
TEST(LoweringPasses, EliminateExceptionsSafeNestedIfBlock) {
  // Source graph (for reference only — parseIR cannot express the
  // value-less "= prim::RaiseException(...)" statement). This is the
  // nested-If case the safe pass exists for:
  /*std::string source_graph = R"IR(
    graph(%x, %y):
      %false : bool = prim::Constant[value=0]()
      %dim : int = aten::dim(%x)
      %48 : int = prim::Constant[value=2]()
      %66 : bool = aten::eq(%48, %dim)
      %45 : str = prim::Constant[value="EXCEPTION"]()
      %4 : Tensor = prim::If(%66)
        block0():
          %45 : str = prim::Constant[value="EXCEPTION"]()
          = prim::If(%false)
            block0():
              -> ()
            block1():
              = prim::RaiseException(%45)
              -> ()
          = prim::RaiseException(%45)
          -> (%x)
        block1():
          %res = aten::mul(%x, %y)
          -> (%res)
      return (%4))IR";*/

  // Expected graph after the pass removes both levels of throwing branches
  std::string target_graph = R"IR(
    graph(%x : Tensor,
          %y : Tensor):
      %6 : Tensor = aten::mul(%x, %y)
      return (%6))IR";

  // Construct graph via manual commands, to avoid IR parsing issues with
  // unassigned variables (such as prim::RaiseException)
  auto g = std::make_shared<torch::jit::Graph>();
  auto x = g->insertInput(0, "x");
  auto y = g->insertInput(1, "y");
  auto none_const_val = g->insertConstant(torch::jit::IValue());
  auto false_const_val = g->insertConstant(torch::jit::IValue(false));
  auto two_const_val = g->insertConstant(torch::jit::IValue(2));
  auto x_dims = g->create(torch::jit::aten::dim, {x}, 1);
  g->appendNode(x_dims);
  x_dims->output()->setType(torch::jit::IntType::get());
  auto eq = g->create(torch::jit::aten::eq, {two_const_val, x_dims->output()}, 1);
  g->appendNode(eq);
  eq->output()->setType(torch::jit::BoolType::get());
  torch::jit::IValue except("EXCEPTION");
  auto except_val = g->insertConstant(except);

  // Construct nested-If substructure in graph
  auto if_node = g->create(torch::jit::prim::If, {eq->output()}, 1);
  auto if_block0 = if_node->addBlock();
  auto if_if_node = g->create(torch::jit::prim::If, {false_const_val}, 0);
  if_block0->appendNode(if_if_node);
  /* auto if_if_block0 = */ if_if_node->addBlock();
  auto if_if_block1 = if_if_node->addBlock();
  auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
  if_if_block1->appendNode(exception_node);
  auto exception_node_2 = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
  if_block0->appendNode(exception_node_2);
  if_block0->registerOutput(x);

  auto if_block1 = if_node->addBlock();
  auto sum_node = g->create(torch::jit::aten::mul, {x, y}, 1);
  if_block1->appendNode(sum_node);
  if_block1->registerOutput(sum_node->output());

  g->insertNode(if_node);
  g->registerOutput(if_node->output());

  // Apply lowering pass and canonicalization to the graph
  torch_tensorrt::core::lowering::passes::EliminateExceptionsSafe(g);
  g = torch::jit::Canonicalize(g, false);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, tg.get());

  torch::jit::ConstantPooling(tg);
  tg = torch::jit::Canonicalize(tg, false);

  // Validate identical graphs after pooling constants and canonicalizing
  ASSERT_TRUE((tg->toString() == g->toString()));
}

0 commit comments

Comments
 (0)