From 37baf7a5fcfaffe98f1618db0f8ab5654871b151 Mon Sep 17 00:00:00 2001 From: Manjunath Kudlur Date: Tue, 5 Apr 2016 11:44:50 -0800 Subject: [PATCH] Making sure there is a kernel implementation for the constant on the device before replacement. Change: 119079254 --- tensorflow/core/BUILD | 3 ++ .../core/common_runtime/constant_folding.cc | 32 ++++++++++++----- .../common_runtime/constant_folding_test.cc | 35 +++++++++++++++++++ tensorflow/core/graph/node_builder.cc | 3 ++ tensorflow/core/graph/node_builder.h | 3 ++ 5 files changed, 67 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 0936f4c941d3ba..24451af96f68c9 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1214,6 +1214,7 @@ tf_cc_test( name = "common_runtime/constant_folding_test", size = "small", linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), deps = [ ":core", ":core_cpu", @@ -1221,6 +1222,7 @@ tf_cc_test( ":direct_session_internal", ":framework", ":framework_internal", + ":gpu_runtime", ":lib", ":lib_internal", ":ops", @@ -1230,6 +1232,7 @@ tf_cc_test( ":testlib", "//tensorflow/cc:cc_ops", "//tensorflow/core/kernels:bcast_ops", + "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:matmul_op", "//third_party/eigen3", diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index ba561479ddfc36..07f08c55771379 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/memory_types.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/log_memory.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/subgraph.h" @@ -226,6 +227,8 @@ bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device, // constraint, do not replace it. // 2) If the destination tensor is an int32 tensor, but has DEVICE_MEMORY // constraint, do not replace it. + // 3) If the constant op created does not have a kernel implementation + // for the device, do not use it. // TODO(keveman): Consider adding a new constant op that has a kernel // implementation for all types, but with HostMemory constraint on it's // output. @@ -255,12 +258,26 @@ bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device, } string node_name = n->name(); Node* constant_node; - TF_CHECK_OK(NodeBuilder(strings::StrCat(graph->NewName(node_name), "__cf__", - UniqueConstantId()), - "Const") - .Attr("dtype", constant.dtype()) - .Attr("value", constant) - .Finalize(graph, &constant_node)); + auto builder = NodeDefBuilder(strings::StrCat(graph->NewName(node_name), + "__cf__", UniqueConstantId()), + "Const") + .Attr("dtype", constant.dtype()) + .Attr("value", constant); + NodeDef def; + if (!builder.Finalize(&def).ok()) { + return false; + } + const KernelDef* kdef; + if (!FindKernelDef(device_type, def, &kdef, nullptr).ok()) { + return false; + } + + VLOG(1) << "Replacing " << tensor.first->DebugString() + << " :: " << tensor.second << " with a constant"; + + if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) { + return false; + } for (auto edge : edges_to_remove) { graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input()); graph->RemoveEdge(edge); @@ -388,9 +405,6 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts, if (!s.ok() || is_dead) 
{ return c > 0; } - VLOG(1) << "Replacing " << tensors_to_replace[c].first->DebugString() - << " :: " << tensors_to_replace[c].second << " with constant " - << output.DebugString(); if (ReplaceTensorWithConstant(graph, partition_device, tensors_to_replace[c], output)) { ++num_nodes_replaced; diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc index bb367b71a64916..d9db399703ea5d 100644 --- a/tensorflow/core/common_runtime/constant_folding_test.cc +++ b/tensorflow/core/common_runtime/constant_folding_test.cc @@ -187,5 +187,40 @@ TEST_F(ConstantFoldingTest, TwoOutputsFoldOneOutput) { ExpectNodeEqual(*(b1_ident->in_nodes().begin()), {}, {0}); } +TEST_F(ConstantFoldingTest, TestNoReplaceOnGPU) { +#if GOOGLE_CUDA + Device* device = nullptr; + std::vector<Device*> devices; + DeviceFactory::GetFactory(DEVICE_GPU) + ->CreateDevices(SessionOptions{}, "", &devices); + if (devices.size() > 0) { + device = devices[0]; + } + if (!device) { + // Don't run the test if no GPUs are found. + return; + } + Reset(); + Graph* g = g_.get(); + Node* s0 = Constant({42.0f}, {1}); + g->AddControlEdge(g->source_node(), s0); + Node* cast = test::graph::Cast(g, s0, DT_BFLOAT16); + Node* send = test::graph::Send(g, cast, "cast", "sender", 0, "receiver"); + + g->AddControlEdge(send, g->sink_node()); + + // No ops should be replaced, as there is no kernel for BFLOAT16 on GPU. + EXPECT_FALSE(DoConstantFolding(ConstantFoldingOptions{}, device, g)); + + // But constant folding should have replaced the cast op with a constant when + // running on CPU. 
+ EXPECT_TRUE(DoConstantFolding(ConstantFoldingOptions{}, nullptr, g)); + + for (auto d : devices) { + delete d; + } +#endif // GOOGLE_CUDA +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc index df59950a1610ac..acf5e00e832c69 100644 --- a/tensorflow/core/graph/node_builder.cc +++ b/tensorflow/core/graph/node_builder.cc @@ -41,6 +41,9 @@ NodeBuilder::NodeBuilder(StringPiece name, StringPiece op_name, NodeBuilder::NodeBuilder(StringPiece name, const OpDef* op_def) : def_builder_(name, op_def) {} +NodeBuilder::NodeBuilder(const NodeDefBuilder& def_builder) + : def_builder_(def_builder) {} + NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) { inputs_.emplace_back(src_node, src_index); DataType dt; diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h index 50c41e222171d5..b1d6b84e38efc2 100644 --- a/tensorflow/core/graph/node_builder.h +++ b/tensorflow/core/graph/node_builder.h @@ -79,6 +79,9 @@ class NodeBuilder { const OpRegistryInterface* op_registry = OpRegistry::Global()); NodeBuilder(StringPiece name, const OpDef* op_def); + // Create a NodeBuilder from an existing NodeDefBuilder. + NodeBuilder(const NodeDefBuilder& def_builder); + // You must call one Input() function per input_arg in the Op, // *and in the same order as the input_args appear in the OpDef.*