Skip to content

Commit

Permalink
[TF:XLA] Don't compile functions that are marked "noinline".
Browse files Browse the repository at this point in the history
The underlying function mechanism uses LocalExecutor to call the function,
which interacts poorly with the LocalExecutor used by tf2xla to translate
the TF graph into XLA.
Change: 150268961
  • Loading branch information
tensorflower-gardener committed Mar 16, 2017
1 parent 3d489f2 commit b05a839
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 9 deletions.
12 changes: 12 additions & 0 deletions tensorflow/compiler/jit/mark_for_compilation_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,18 @@ bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type,
}
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
CHECK(fbody);
const FunctionDef& fdef = fbody->fdef;
bool noinline = false;
if (GetNodeAttr(AttrSlice(&fdef.attr()), "_noinline", &noinline).ok() &&
noinline) {
// The underlying mechanism that calls non-inlined functions uses
// LocalExecutor, which interacts poorly with the LocalExecutor used by
// tf2xla to translate the TF graph into XLA. So we avoid this for now.
//
// TODO(b/36139787): Create a mechanism to set inlining hints.
VLOG(2) << "Can't compile noinline function: " << fdef.DebugString();
return false;
}

for (Node* node : fbody->graph->nodes()) {
if (node->IsSource() || node->IsSink()) continue;
Expand Down
15 changes: 12 additions & 3 deletions tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,20 @@ TEST(XlaCompilationTest, ConcatWithConstArg) {
}

TEST(XlaCompilationTest, FunctionCalls) {
FunctionDefLibrary flib;
*flib.add_function() = FunctionDefHelper::Define(
FunctionDef compilable = FunctionDefHelper::Define(
"CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
{{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
*flib.add_function() =
FunctionDef uncompilable =
FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
{}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});
FunctionDef noinline = compilable;
noinline.mutable_signature()->set_name("NoInlineFn");
AddAttr("_noinline", bool(true), noinline.mutable_attr());

FunctionDefLibrary flib;
*flib.add_function() = compilable;
*flib.add_function() = uncompilable;
*flib.add_function() = noinline;
FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

std::unique_ptr<Graph> graph(new Graph(&flib_def));
Expand All @@ -202,6 +209,7 @@ TEST(XlaCompilationTest, FunctionCalls) {
Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B"));
Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D"));
ops::BinaryOp("NoInlineFn", c, c, builder.opts().WithName("E"));
TF_EXPECT_OK(builder.ToGraph(graph.get()));
}

Expand All @@ -213,6 +221,7 @@ TEST(XlaCompilationTest, FunctionCalls) {
EXPECT_EQ(clusters["B"], clusters["C"]);
EXPECT_TRUE(clusters.find("A") == clusters.cend());
EXPECT_TRUE(clusters.find("D") == clusters.cend());
EXPECT_TRUE(clusters.find("E") == clusters.cend());
}

// Metadata-only operators such as Shape/Rank/Size may not be the root of a
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/compiler/tests/function_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def Foo(a, b):
result = sess.run(call_f)
self.assertAllClose(result, expected, rtol=1e-3)

def testFunctionsNoInline(self):
# TODO(b/36139787): Re-enable this test when noinline works again.
def DISABLED_testFunctionsNoInline(self):

@function.Defun(dtypes.float32, noinline=True)
def TimesTwo(x):
Expand Down
12 changes: 7 additions & 5 deletions tensorflow/compiler/tests/jit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,14 @@ def AddOnceReturnTwice(x):
# function (say, Bar) which is not inlined. When the compiler compiles
# Foo, it needs to symbolic execute Bar correctly regardless whether
# Bar is inlined or not.
#

# TODO(b/36139787): Re-enable this test when noinline works again.
# Tests compiled=True and noinline=True.
self._compare(
AddOnceReturnTwice, [np.array(
[[[0.5, -1.0]]], dtype=np.float32)],
noinline=True)
# self._compare(
# AddOnceReturnTwice, [np.array(
# [[[0.5, -1.0]]], dtype=np.float32)],
# noinline=True)

# Tests compiled=True and noinline=False.
self._compare(
AddOnceReturnTwice, [np.array(
Expand Down

0 comments on commit b05a839

Please sign in to comment.