Skip to content

Commit efe5376

Browse files
authored
Merge pull request tensorflow#8958 from rohan100jain/branch_152141388
Branch 152141388
2 parents 8908272 + 5ee21f2 commit efe5376

File tree

211 files changed

+5481
-4644
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

211 files changed

+5481
-4644
lines changed

tensorflow/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ filegroup(
202202
"//tensorflow/contrib/boosted_trees:all_files",
203203
"//tensorflow/contrib/boosted_trees/lib:all_files",
204204
"//tensorflow/contrib/boosted_trees/proto:all_files",
205+
"//tensorflow/contrib/boosted_trees/resources:all_files",
205206
"//tensorflow/contrib/cloud:all_files",
206207
"//tensorflow/contrib/cloud/kernels:all_files",
207208
"//tensorflow/contrib/compiler:all_files",
@@ -256,6 +257,7 @@ filegroup(
256257
"//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
257258
"//tensorflow/contrib/training:all_files",
258259
"//tensorflow/contrib/util:all_files",
260+
"//tensorflow/contrib/xla_tf_graph:all_files",
259261
"//tensorflow/core:all_files",
260262
"//tensorflow/core/debug:all_files",
261263
"//tensorflow/core/distributed_runtime:all_files",

tensorflow/compiler/aot/tests/BUILD

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ genrule(
5151
"test_graph_tfgather.pb",
5252
"test_graph_tfmatmul.pb",
5353
"test_graph_tfmatmulandadd.pb",
54+
"test_graph_tffunction.pb",
5455
],
5556
cmd = "$(location :make_test_graphs) --out_dir $(@D)",
5657
tags = ["manual"],
@@ -114,6 +115,15 @@ tf_library(
114115
tags = ["manual"],
115116
)
116117

118+
tf_library(
119+
name = "test_graph_tffunction",
120+
testonly = 1,
121+
config = "test_graph_tffunction.config.pbtxt",
122+
cpp_class = "FunctionComp",
123+
graph = "test_graph_tffunction.pb",
124+
tags = ["manual"],
125+
)
126+
117127
cc_test(
118128
name = "tfcompile_test",
119129
srcs = ["tfcompile_test.cc"],
@@ -122,6 +132,7 @@ cc_test(
122132
":test_graph_tfadd",
123133
":test_graph_tfadd_with_ckpt",
124134
":test_graph_tfadd_with_ckpt_saver",
135+
":test_graph_tffunction",
125136
":test_graph_tfgather",
126137
":test_graph_tfmatmul",
127138
":test_graph_tfmatmulandadd",

tensorflow/compiler/aot/tests/make_test_graphs.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from tensorflow.python.client import session
2626
from tensorflow.python.framework import constant_op
2727
from tensorflow.python.framework import dtypes
28+
from tensorflow.python.framework import function
2829
from tensorflow.python.framework import ops
2930
from tensorflow.python.ops import array_ops
3031
from tensorflow.python.ops import math_ops
@@ -95,6 +96,17 @@ def tfmatmulandadd(_):
9596
math_ops.add(x, y, name='x_y_sum')
9697

9798

99+
def tffunction(_):
100+
101+
@function.Defun(dtypes.int32, dtypes.int32)
102+
def test_func(a, b):
103+
return a + b
104+
105+
x = constant_op.constant([1], name='x_const')
106+
y = constant_op.constant([2], name='y_const')
107+
test_func(x, y, name='func_call') # pylint: disable=unexpected-keyword-arg
108+
109+
98110
def write_graph(build_graph, out_dir):
99111
"""Build a graph using build_graph and write it out."""
100112
g = ops.Graph()
@@ -112,6 +124,7 @@ def main(_):
112124
write_graph(tfgather, FLAGS.out_dir)
113125
write_graph(tfmatmul, FLAGS.out_dir)
114126
write_graph(tfmatmulandadd, FLAGS.out_dir)
127+
write_graph(tffunction, FLAGS.out_dir)
115128

116129

117130
if __name__ == '__main__':
@@ -121,7 +134,6 @@ def main(_):
121134
'--out_dir',
122135
type=str,
123136
default='',
124-
help='Output directory for graphs, checkpoints and savers.'
125-
)
137+
help='Output directory for graphs, checkpoints and savers.')
126138
FLAGS, unparsed = parser.parse_known_args()
127139
app.run(main=main, argv=[sys.argv[0]] + unparsed)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Text form of tensorflow.tfcompile.Config proto.
2+
feed {
3+
id { node_name: "x_const" }
4+
shape {
5+
dim { size: 1 }
6+
}
7+
}
8+
feed {
9+
id { node_name: "y_const" }
10+
shape {
11+
dim { size: 1 }
12+
}
13+
}
14+
fetch {
15+
id { node_name: "func_call" }
16+
}

tensorflow/compiler/aot/tests/tfcompile_test.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
2121
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
2222
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
23+
#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
2324
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
2425
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
2526
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
@@ -376,6 +377,21 @@ TEST(TFCompileTest, MatMulAndAdd1) {
376377
}
377378
}
378379

380+
TEST(TFCompileTest, Function) {
381+
// The function is equivalent to an addition
382+
FunctionComp add_fn;
383+
EXPECT_EQ(add_fn.arg0_data(), add_fn.args()[0]);
384+
EXPECT_EQ(add_fn.arg1_data(), add_fn.args()[1]);
385+
386+
add_fn.arg0() = 1;
387+
add_fn.arg1() = 2;
388+
EXPECT_TRUE(add_fn.Run());
389+
EXPECT_EQ(add_fn.error_msg(), "");
390+
EXPECT_EQ(add_fn.result0(), 3);
391+
EXPECT_EQ(add_fn.result0_data()[0], 3);
392+
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
393+
}
394+
379395
} // namespace
380396
} // namespace tfcompile
381397
} // namespace tensorflow

tensorflow/compiler/jit/mark_for_compilation_pass.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
5050
}
5151

5252
// Make sure we don't recurse infinitely on recursive functions.
53-
const int kMaxRecursionDepth = 5;
53+
const int kMaxRecursionDepth = 10;
5454

5555
bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type,
5656
int depth, FunctionLibraryRuntime* lib_runtime);

tensorflow/compiler/tests/randomized_tests.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2339,6 +2339,14 @@ TEST_F(OpTest, ZerosLike) {
23392339
});
23402340
}
23412341

2342+
TEST_F(OpTest, OnesLike) {
2343+
Repeatedly([this]() {
2344+
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
2345+
ExpectTfAndXlaOutputsAreClose(
2346+
OpTestBuilder("OnesLike").Input(RandomTensor(type)).Attr("T", type));
2347+
});
2348+
}
2349+
23422350
} // anonymous namespace
23432351
} // namespace tensorflow
23442352

tensorflow/compiler/tests/unary_ops_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,11 @@ def testNumericOps(self):
257257
np.array([[4, 3], [2, 1]], dtype=dtype),
258258
expected=np.array([[0, 0], [0, 0]], dtype=dtype))
259259

260+
self._assertOpOutputMatchesExpected(
261+
array_ops.ones_like,
262+
np.array([[4, 3], [2, 1]], dtype=dtype),
263+
expected=np.array([[1, 1], [1, 1]], dtype=dtype))
264+
260265
def testLogicalOps(self):
261266
self._assertOpOutputMatchesExpected(
262267
math_ops.logical_not,

tensorflow/compiler/tf2xla/kernels/shape_op.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,5 +241,19 @@ class ZerosLikeOp : public XlaOpKernel {
241241

242242
REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp);
243243

244+
class OnesLikeOp : public XlaOpKernel {
245+
public:
246+
explicit OnesLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
247+
248+
void Compile(XlaOpKernelContext* ctx) override {
249+
const TensorShape input_shape = ctx->InputShape(0);
250+
251+
auto one = XlaHelpers::One(ctx->builder(), input_type(0));
252+
ctx->SetOutput(0, ctx->builder()->Broadcast(one, input_shape.dim_sizes()));
253+
}
254+
};
255+
256+
REGISTER_XLA_OP(Name("OnesLike"), OnesLikeOp);
257+
244258
} // namespace
245259
} // namespace tensorflow

tensorflow/compiler/tf2xla/xla_compiler.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,15 @@ Status CheckSignature(const DataTypeVector& types,
5959

6060
XlaCompiler::XlaCompiler(XlaCompiler::Options options)
6161
: options_(std::move(options)),
62+
initialization_status_(Status::OK()),
6263
next_step_id_(1),
6364
device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
64-
device_mgr_({device_}) {}
65+
device_mgr_({device_}) {
66+
if (options_.populate_resource_manager) {
67+
initialization_status_ =
68+
(*options_.populate_resource_manager)(device_->resource_manager());
69+
}
70+
}
6571

6672
XlaCompiler::~XlaCompiler() = default;
6773

@@ -379,6 +385,9 @@ Status XlaCompiler::CompileGraph(string const& name,
379385
CompilationResult* result) {
380386
VLOG(1) << "Executing graph symbolically to populate ComputationBuilder.";
381387

388+
// Report the error here if initialization failed.
389+
TF_RETURN_IF_ERROR(initialization_status_);
390+
382391
xla::ComputationBuilder builder(client(), name);
383392
XlaContext* context =
384393
new XlaContext(this, &builder, options_.allow_cpu_custom_calls,

0 commit comments

Comments
 (0)