[JIT][SR] Introduce prim::IfThenElse (pytorch#72587)
Summary:
Pull Request resolved: pytorch#72587

The following pattern appears frequently in a few graphs:

```
%result = prim::If(%condition)
  block0():
    -> (%a)
  block1():
    -> (%b)
```

This is slow, particularly in static runtime. Static runtime creates a memory planner and a block runner for each sub-block, which eats up a lot of memory and introduces a lot of extra overhead for this relatively simple operation.

This diff introduces a new op, along with a graph pass that replaces nodes like the above with a single node acting as a ternary operator:

```
%result = prim::IfThenElse(%condition, %a, %b)
```
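
In other words, the new op is a pure ternary select: both sub-blocks above are empty, so %a and %b are already computed and there is nothing to short-circuit. A minimal C++ sketch of the intended semantics (illustrative only, not the interpreter code):

```cpp
#include <utility>

// prim::IfThenElse(cond, a, b) behaves like (cond ? a : b), with both
// operands evaluated before the op runs.
template <typename T>
T ifThenElse(bool cond, T a, T b) {
  return cond ? std::move(a) : std::move(b);
}
```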

Test Plan: New unit tests

Reviewed By: eellison

Differential Revision: D34091789

fbshipit-source-id: eb6a8c460c39b4c019a1f4ab1f3f1e5b6edc400c
Mike Iovine authored and facebook-github-bot committed Feb 17, 2022
1 parent b3a1923 commit 0f1b335
Showing 13 changed files with 175 additions and 2 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -96,6 +96,7 @@ namespace c10 {
_(prim, With) \
_(prim, Enter) \
_(prim, Exit) \
_(prim, IfThenElse) \
_(aten, Bool) \
_(aten, Int) \
_(aten, FloatImplicit) \
16 changes: 16 additions & 0 deletions benchmarks/static_runtime/test_static_runtime.cc
@@ -2720,3 +2720,19 @@ TEST(StaticRuntime, ToList) {
)JIT";
testStaticRuntime(src, {at::randn({2, 2})});
}

TEST(StaticRuntime, IfThenElse) {
const auto src = R"IR(
graph(%cond: bool, %a: Tensor, %b: Tensor):
%none: NoneType = prim::Constant()
%c: Tensor = prim::IfThenElse(%cond, %a, %b)
%d: Tensor = aten::clone(%c, %none)
return (%d)
)IR";

std::vector<IValue> args1{true, at::randn({1}), at::randn({1})};
std::vector<IValue> args2{false, at::randn({1}), at::randn({1})};

testStaticRuntime(src, args1);
testStaticRuntime(src, args2);
}
1 change: 1 addition & 0 deletions test/cpp/jit/CMakeLists.txt
@@ -39,6 +39,7 @@ endif()

# Build the cpp gtest binary containing the cpp-only tests.
set(JIT_TEST_SRCS
${JIT_TEST_ROOT}/test_add_if_then_else.cpp
${JIT_TEST_ROOT}/test_alias_analysis.cpp
${JIT_TEST_ROOT}/test_argument_spec.cpp
${JIT_TEST_ROOT}/test_autodiff.cpp
53 changes: 53 additions & 0 deletions test/cpp/jit/test_add_if_then_else.cpp
@@ -0,0 +1,53 @@
#include <gtest/gtest.h>

#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/add_if_then_else.h>

namespace torch {
namespace jit {

TEST(AddIfThenElseOpTest, AddIfThenElseOpSimple) {
const auto src = R"IR(
graph(%cond: bool, %a: Tensor, %b: Tensor):
%result: Tensor = prim::If(%cond)
block0():
-> (%a)
block1():
-> (%b)
return (%result)
)IR";

auto graph = std::make_shared<Graph>();
parseIR(src, graph.get());
EXPECT_TRUE(AddIfThenElseOp(graph));

testing::FileCheck()
.check_count("= prim::IfThenElse", 1, /*exactly*/ true)
->check_count("= prim::If", 0, /*exactly*/ true)
->run(*graph);
}

TEST(AddIfThenElseOpTest, NoIfThenElseOpMultipleOutputs) {
const auto src = R"IR(
graph(%cond: bool, %a: Tensor, %b: Tensor):
%result1: Tensor, %result2: Tensor = prim::If(%cond)
block0():
-> (%a, %b)
block1():
-> (%b, %a)
return (%result1, %result2)
)IR";

auto graph = std::make_shared<Graph>();
parseIR(src, graph.get());
EXPECT_FALSE(AddIfThenElseOp(graph));

testing::FileCheck()
.check_count("= prim::IfThenElse", 0, /*exactly*/ true)
->check_count("= prim::If", 1, /*exactly*/ true)
->run(*graph);
}

} // namespace jit
} // namespace torch
1 change: 1 addition & 0 deletions tools/build_variables.bzl
@@ -213,6 +213,7 @@ core_sources_full_mobile_no_backend_interface = [
"torch/csrc/jit/operator_upgraders/utils.cpp",
"torch/csrc/jit/operator_upgraders/upgraders.cpp",
"torch/csrc/jit/operator_upgraders/upgraders_entry.cpp",
"torch/csrc/jit/passes/add_if_then_else.cpp",
"torch/csrc/jit/passes/annotate_warns.cpp",
"torch/csrc/jit/passes/bailout_graph.cpp",
"torch/csrc/jit/passes/batch_mm.cpp",
55 changes: 55 additions & 0 deletions torch/csrc/jit/passes/add_if_then_else.cpp
@@ -0,0 +1,55 @@
#include <torch/csrc/jit/passes/add_if_then_else.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

namespace torch {
namespace jit {

namespace {

bool hasNoNodes(Block* block) {
auto nodes = block->nodes();
return nodes.begin() == nodes.end();
}

bool hasTrivialSubBlocks(Node* node) {
const auto blocks = node->blocks();
DCHECK_EQ(blocks.size(), 2);

return hasNoNodes(blocks[0]) && hasNoNodes(blocks[1]);
}

} // namespace

bool AddIfThenElseOp(std::shared_ptr<Graph>& graph) {
std::vector<Node*> to_replace;
DepthFirstGraphNodeIterator graph_it(graph);
for (auto* node = graph_it.next(); node != nullptr; node = graph_it.next()) {
if (node->kind() != prim::If) {
continue;
}
if (node->outputs().size() != 1) {
continue;
}
if (hasTrivialSubBlocks(node)) {
to_replace.push_back(node);
}
}

for (auto* node : to_replace) {
auto* if_then_else_node = graph->create(prim::IfThenElse, 1);
if_then_else_node->addInput(node->input());
auto blocks = node->blocks();
if_then_else_node->addInput(blocks[0]->return_node()->input());
if_then_else_node->addInput(blocks[1]->return_node()->input());

if_then_else_node->insertBefore(node);
if_then_else_node->output()->copyMetadata(node->output());

node->output()->replaceAllUsesWith(if_then_else_node->output());
node->destroy();
}
return !to_replace.empty();
}

} // namespace jit
} // namespace torch
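
For reference, a minimal sketch of driving this pass standalone, in the style of the unit tests above (the helper name is mine). Note that the pass collects matching prim::If nodes in one loop and rewrites them in a second, so the depth-first iterator is never invalidated mid-traversal:

```cpp
#include <memory>
#include <string>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/add_if_then_else.h>

// Parse TorchScript IR and run the pass; returns true iff at least one
// single-output prim::If with empty sub-blocks was rewritten.
bool addIfThenElseExample(const std::string& src) {
  auto graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(src, graph.get());
  return torch::jit::AddIfThenElseOp(graph);
}
```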
11 changes: 11 additions & 0 deletions torch/csrc/jit/passes/add_if_then_else.h
@@ -0,0 +1,11 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

TORCH_API bool AddIfThenElseOp(std::shared_ptr<Graph>& graph);

} // namespace jit
} // namespace torch
7 changes: 7 additions & 0 deletions torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp
@@ -2,6 +2,7 @@

#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/add_if_then_else.h>
#include <torch/csrc/jit/passes/bailout_graph.h>
#include <torch/csrc/jit/passes/batch_mm.h>
#include <torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h>
@@ -650,6 +651,7 @@ const ExecutionPlan& ProfilingGraphExecutorImpl::getOptimizedPlanFor(
// replaces a fallback graph inserted by
// specialize_autogradzero if one exists
replaceFallbackGraphWithFallbackFunction(copy->block());
runFinalOptimizations(copy);
GRAPH_DUMP("Optimized Graph: ", copy);
optimized_plan_ =
ExecutionPlan(copy, function_name_, *remaining_bailout_depth_);
@@ -749,5 +751,10 @@ void ProfilingGraphExecutorImpl::replaceFallbackGraphWithFallbackFunction(
}
}

void ProfilingGraphExecutorImpl::runFinalOptimizations(
std::shared_ptr<Graph>& graph) {
AddIfThenElseOp(graph);
}

} // namespace jit
} // namespace torch
1 change: 1 addition & 0 deletions torch/csrc/jit/runtime/profiling_graph_executor_impl.h
@@ -39,6 +39,7 @@ struct TORCH_API ProfilingGraphExecutorImpl : public GraphExecutorImplBase {
std::shared_ptr<Graph>& graph,
size_t remaining_depth);
void replaceFallbackGraphWithFallbackFunction(Block* b);
void runFinalOptimizations(std::shared_ptr<Graph>& graph);
std::unique_ptr<ProfilingRecord> pr_;
c10::optional<ExecutionPlan>
profiling_plan_; // plan to run in order to profile the code
11 changes: 11 additions & 0 deletions torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -700,6 +700,17 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
push(stack, at::stack(inputs, dim));
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
"prim::IfThenElse(bool cond, Any(a) x, Any(b) y) -> Any(a|b)"),
[](Stack& stack) {
const auto cond = stack[stack.size() - 3].toBool();
stack[stack.size() - 3] =
std::move(stack[stack.size() - (cond ? 2 : 1)]);
stack.pop_back();
stack.pop_back();
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
"aten::eq.enum(AnyEnumType a, AnyEnumType b) -> bool"),
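
The schema above takes the condition three slots from the top of the stack, with x and y above it; the lambda overwrites the condition's slot with the selected value and pops the rest. A self-contained sketch of the same stack discipline, with plain ints standing in for IValues:

```cpp
#include <cassert>
#include <utility>
#include <vector>

int main() {
  // Stack layout on entry: [..., cond, x, y], y on top.
  std::vector<int> stack = {/*cond=*/1, /*x=*/42, /*y=*/7};
  const bool cond = stack[stack.size() - 3] != 0;
  // Overwrite the condition's slot with the selected input...
  stack[stack.size() - 3] = std::move(stack[stack.size() - (cond ? 2 : 1)]);
  // ...then pop the two leftover slots, leaving only the result.
  stack.pop_back();
  stack.pop_back();
  assert(stack.size() == 1 && stack.back() == 42);
  return 0;
}
```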
5 changes: 4 additions & 1 deletion torch/csrc/jit/runtime/static/impl.cpp
@@ -12,6 +12,7 @@
#include <caffe2/core/timer.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/add_if_then_else.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/eliminate_no_ops.h>
@@ -173,6 +174,7 @@ void OptimizeGraph(
UseVariadicGroupedAccessor(graph);
EliminateNoOps(
graph, /* custom_ops */ {fromQualString("fb::scale_gradient")});
AddIfThenElseOp(graph);
GRAPH_DUMP("Final graph after optimizations: ", graph);
}

@@ -1846,8 +1848,9 @@ static bool checkNoMemoryOverlap(const at::Tensor& a, const at::Tensor& b) {
}

bool ProcessedNode::verify_no_memory_overlap(bool force_check) const {
const static std::array<c10::Symbol, 5> special_case_ops = {
const static std::array<c10::Symbol, 6> special_case_ops = {
fromQualString("prim::TypeCheck"),
fromQualString("prim::IfThenElse"),
fromQualString("static_runtime::select_tensor"),
fromQualString("static_runtime::VarTupleUnpack"),
fromQualString("static_runtime::dict_unpack"),
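
prim::IfThenElse joins the special-case list above because its output intentionally aliases one of its inputs, so a naive output-vs-input overlap check would always fire. A toy illustration of that false positive (not the actual SR check):

```cpp
#include <cassert>
#include <cstddef>

// Two byte ranges overlap if neither one ends before the other begins.
bool overlaps(const char* a, size_t a_len, const char* b, size_t b_len) {
  return a < b + b_len && b < a + a_len;
}

int main() {
  char buf[8] = {};
  // A borrowed output *is* the input, so the ranges trivially overlap.
  assert(overlaps(/*input=*/buf, 8, /*output=*/buf, 8));
  return 0;
}
```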
3 changes: 2 additions & 1 deletion torch/csrc/jit/runtime/static/impl.h
@@ -58,10 +58,11 @@ TORCH_API inline bool doesNotHeapAllocateWhenStoredInIValue(const Type& type) {
}

TORCH_API inline bool borrowsOutputs(c10::Symbol kind) {
static const std::array<c10::Symbol, 3> symbols_with_borrowed_outputs = {
static const std::array<c10::Symbol, 4> symbols_with_borrowed_outputs = {
c10::Symbol::fromQualString("static_runtime::select_tensor"),
c10::Symbol::fromQualString("static_runtime::dict_unpack"),
c10::Symbol::fromQualString("static_runtime::VarTupleUnpack"),
c10::Symbol::fromQualString("prim::IfThenElse"),
};
return std::find(
symbols_with_borrowed_outputs.begin(),
12 changes: 12 additions & 0 deletions torch/csrc/jit/runtime/static/native_ops.cpp
@@ -946,5 +946,17 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
};
});

// See [Borrowed IValue Outputs]
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::IfThenElse,
prim_IfThenElse,
[](Node*) -> SROperator {
return [](ProcessedNode* pnode) {
const auto condition = pnode->Input(0).toBool();
pnode->Output(0) = condition ? createBorrowedIValue(pnode->Input(1))
: createBorrowedIValue(pnode->Input(2));
};
});

} // namespace jit
} // namespace torch
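
Since the output is always one of the inputs, the static runtime kernel above returns a borrowed IValue instead of a copy (per the [Borrowed IValue Outputs] note), which is also why the op appears in borrowsOutputs and the overlap special cases earlier in this diff. A conceptual sketch of borrowing, with non-owning pointers standing in for borrowed IValues (assumptions mine, not SR internals):

```cpp
#include <cassert>

struct Value {
  int payload;
};

// A borrow is a non-owning reference to an existing value: nothing is
// copied, refcounted, or freed, so the select is effectively free.
const Value* ifThenElse(bool cond, const Value& x, const Value& y) {
  return cond ? &x : &y;
}

int main() {
  Value a{1}, b{2};
  assert(ifThenElse(true, a, b) == &a);  // output aliases an input
  assert(ifThenElse(false, a, b) == &b);
  return 0;
}
```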
