Skip to content

Commit

Permalink
Making sure there is a kernel implementation for the constant on the …
Browse files Browse the repository at this point in the history
…device

before replacement.
Change: 119079254
  • Loading branch information
keveman authored and tensorflower-gardener committed Apr 5, 2016
1 parent 84476b2 commit 37baf7a
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 9 deletions.
3 changes: 3 additions & 0 deletions tensorflow/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1214,13 +1214,15 @@ tf_cc_test(
name = "common_runtime/constant_folding_test",
size = "small",
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags(),
deps = [
":core",
":core_cpu",
":core_cpu_internal",
":direct_session_internal",
":framework",
":framework_internal",
":gpu_runtime",
":lib",
":lib_internal",
":ops",
Expand All @@ -1230,6 +1232,7 @@ tf_cc_test(
":testlib",
"//tensorflow/cc:cc_ops",
"//tensorflow/core/kernels:bcast_ops",
"//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:matmul_op",
"//third_party/eigen3",
Expand Down
32 changes: 23 additions & 9 deletions tensorflow/core/common_runtime/constant_folding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
Expand Down Expand Up @@ -226,6 +227,8 @@ bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device,
// constraint, do not replace it.
// 2) If the destination tensor is an int32 tensor, but has DEVICE_MEMORY
// constraint, do not replace it.
// 3) If the constant op created does not have a kernel implementation
// for the device, do not use it.
// TODO(keveman): Consider adding a new constant op that has a kernel
// implementation for all types, but with HostMemory constraint on it's
// output.
Expand Down Expand Up @@ -255,12 +258,26 @@ bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device,
}
string node_name = n->name();
Node* constant_node;
TF_CHECK_OK(NodeBuilder(strings::StrCat(graph->NewName(node_name), "__cf__",
UniqueConstantId()),
"Const")
.Attr("dtype", constant.dtype())
.Attr("value", constant)
.Finalize(graph, &constant_node));
auto builder = NodeDefBuilder(strings::StrCat(graph->NewName(node_name),
"__cf__", UniqueConstantId()),
"Const")
.Attr("dtype", constant.dtype())
.Attr("value", constant);
NodeDef def;
if (!builder.Finalize(&def).ok()) {
return false;
}
const KernelDef* kdef;
if (!FindKernelDef(device_type, def, &kdef, nullptr).ok()) {
return false;
}

VLOG(1) << "Replacing " << tensor.first->DebugString()
<< " :: " << tensor.second << " with a constant";

if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
return false;
}
for (auto edge : edges_to_remove) {
graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
graph->RemoveEdge(edge);
Expand Down Expand Up @@ -388,9 +405,6 @@ bool DoConstantFolding(const ConstantFoldingOptions& opts,
if (!s.ok() || is_dead) {
return c > 0;
}
VLOG(1) << "Replacing " << tensors_to_replace[c].first->DebugString()
<< " :: " << tensors_to_replace[c].second << " with constant "
<< output.DebugString();
if (ReplaceTensorWithConstant(graph, partition_device,
tensors_to_replace[c], output)) {
++num_nodes_replaced;
Expand Down
35 changes: 35 additions & 0 deletions tensorflow/core/common_runtime/constant_folding_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,5 +187,40 @@ TEST_F(ConstantFoldingTest, TwoOutputsFoldOneOutput) {
ExpectNodeEqual<int>(*(b1_ident->in_nodes().begin()), {}, {0});
}

TEST_F(ConstantFoldingTest, TestNoReplaceOnGPU) {
#if GOOGLE_CUDA
Device* device = nullptr;
std::vector<Device*> devices;
DeviceFactory::GetFactory(DEVICE_GPU)
->CreateDevices(SessionOptions{}, "", &devices);
if (devices.size() > 0) {
device = devices[0];
}
if (!device) {
// Don't run the test if not GPUs found.
return;
}
Reset();
Graph* g = g_.get();
Node* s0 = Constant<float>({42.0f}, {1});
g->AddControlEdge(g->source_node(), s0);
Node* cast = test::graph::Cast(g, s0, DT_BFLOAT16);
Node* send = test::graph::Send(g, cast, "cast", "sender", 0, "receiver");

g->AddControlEdge(send, g->sink_node());

// No ops should be replaced, as there is no kernel for BFLOAT16 on GPU.
EXPECT_FALSE(DoConstantFolding(ConstantFoldingOptions{}, device, g));

// But constant folding should have replaced the cast op with a constant when
// running on CPU.
EXPECT_TRUE(DoConstantFolding(ConstantFoldingOptions{}, nullptr, g));

for (auto d : devices) {
delete d;
}
#endif // GOOGLE_CUDA
}

} // namespace
} // namespace tensorflow
3 changes: 3 additions & 0 deletions tensorflow/core/graph/node_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ NodeBuilder::NodeBuilder(StringPiece name, StringPiece op_name,
NodeBuilder::NodeBuilder(StringPiece name, const OpDef* op_def)
: def_builder_(name, op_def) {}

NodeBuilder::NodeBuilder(const NodeDefBuilder& def_builder)
: def_builder_(def_builder) {}

NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) {
inputs_.emplace_back(src_node, src_index);
DataType dt;
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/core/graph/node_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ class NodeBuilder {
const OpRegistryInterface* op_registry = OpRegistry::Global());
NodeBuilder(StringPiece name, const OpDef* op_def);

// Create a NodeBuilder from an existing NodeDefBuilder.
NodeBuilder(const NodeDefBuilder& def_builder);

// You must call one Input() function per input_arg in the Op,
// *and in the same order as the input_args appear in the OpDef.*

Expand Down

0 comments on commit 37baf7a

Please sign in to comment.