Skip to content

GreaterEqual and Cast fusion #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions tensorflow/core/grappler/optimizers/remapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ struct FusedBatchNorm {
int fused_batch_norm = kMissingIndex;
};

// Matched pattern: a GreaterEqual node whose boolean output feeds exactly one
// Cast node converting back to the comparison's input type. Fields hold node
// indices into the graph; kMissingIndex marks an unmatched slot.
struct GreaterEqualWithCast {
  GreaterEqualWithCast() = default;

  int greater_equal = kMissingIndex;  // index of the GreaterEqual node
  int cast = kMissingIndex;           // index of the consuming Cast node
};

// FusedBatchNorm[$is_training] with fused side input and/or activation.
struct FusedBatchNormEx {
FusedBatchNormEx() = default;
Expand Down Expand Up @@ -1033,6 +1041,41 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
return false;
}

bool FindGreaterEqualWithCast(const RemapperContext& ctx, int node_index,
GreaterEqualWithCast* matched) {
const auto* node_view = ctx.graph_view.GetNode(node_index);
const auto* node_def = node_view->node();

if (!IsCast(*node_def) || HasControlFaninOrFanout(*node_view)) return false;

if (node_view->NumRegularFanins() != 1) return false;
const auto& regular_fanin_0 = node_view->GetRegularFanin(0);
const auto* greater_equal = regular_fanin_0.node_view();
const auto* greater_equal_node_def = greater_equal->node();
if (!IsGreaterEqual(*greater_equal_node_def) ||
HasControlFaninOrFanout(*greater_equal))
return false;

DataType dtype = GetDataTypeFromAttr(*greater_equal_node_def, "T");
DataType src_dtype = GetDataTypeFromAttr(*node_def, "SrcT");
DataType dst_dtype = GetDataTypeFromAttr(*node_def, "DstT");
#if defined(INTEL_MKL) && defined(ENABLE_INTEL_MKL_BFLOAT16)
if (dtype != DT_FLOAT && dtype != DT_BFLOAT16) return false;
#else
if (dtype != DT_FLOAT) return false;
#endif
if ((dtype != dst_dtype) || (src_dtype != DT_BOOL)) return false;

// Check that only one node consumes the 0-th output of a GreaterEqual.
if (!HasAtMostOneDataFanoutAtPort0(*greater_equal) ||
IsInPreserveSet(ctx, greater_equal_node_def))
return false;

matched->cast = node_index;
matched->greater_equal = regular_fanin_0.node_index();
return true;
}

void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d) {
DCHECK(IsConv2D(conv2d)) << "Input node must be a Conv2D";

Expand Down Expand Up @@ -1571,6 +1614,40 @@ Status AddFusedBatchNormExNode(RemapperContext* ctx,
return Status::OK();
}

// Replaces a matched GreaterEqual -> Cast pair with a single fused
// _GreaterEqualWithCast node. The fused node reuses the Cast's name so
// downstream consumers remain valid, and reads the GreaterEqual operands
// directly. The replaced Cast slot is marked invalidated; the now-dead
// GreaterEqual node is scheduled for deletion.
Status AddGreaterEqualWithCastNode(RemapperContext* ctx,
                                   const GreaterEqualWithCast& matched,
                                   std::vector<bool>* invalidated_nodes,
                                   std::vector<bool>* nodes_to_delete) {
  const GraphDef* graph = ctx->graph_view.graph();
  const NodeDef& greater_equal = graph->node(matched.greater_equal);
  const NodeDef& cast = graph->node(matched.cast);

  // Fixed: the original log line emitted a dangling " invalidated=" label
  // with no value.
  VLOG(2) << "Fuse " << cast.op() << " with GreaterEqual:"
          << " cast=" << cast.name()
          << " greater_equal=" << greater_equal.name();

  NodeDef fused_op;
  fused_op.set_op("_GreaterEqualWithCast");
  fused_op.set_name(cast.name());
  fused_op.set_device(greater_equal.device());

  fused_op.add_input(greater_equal.input(0));
  fused_op.add_input(greater_equal.input(1));
  // The fused op's "T" is the comparison input type (also the output type,
  // since the cast back to T is folded into the kernel).
  (*fused_op.mutable_attr())["T"] = greater_equal.attr().at("T");

  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
  Status status;
  mutation->AddNode(std::move(fused_op), &status);
  TF_RETURN_IF_ERROR(status);
  TF_RETURN_IF_ERROR(mutation->Apply());

  // The fused node took over the Cast's slot; the GreaterEqual has no
  // remaining consumers and should be removed rather than merely invalidated
  // (the original code left it as a dead node in the graph and never used
  // `nodes_to_delete`).
  (*invalidated_nodes)[matched.cast] = true;
  (*nodes_to_delete)[matched.greater_equal] = true;

  return Status::OK();
}

Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
const GraphDef* graph = ctx->graph_view.graph();
const NodeDef& fused_node = graph->node(matched.fused_batch_norm);
Expand Down Expand Up @@ -2007,6 +2084,15 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
TF_RETURN_IF_ERROR(AddBatchNormNodes(&ctx, fused_batch_norm));
continue;
}

// Remap GreaterEqual+Cast into the GreaterEqualWithCast.
GreaterEqualWithCast greater_equal_with_cast;
if (allow_non_differentiable_rewrites &&
FindGreaterEqualWithCast(ctx, i, &greater_equal_with_cast)) {
TF_RETURN_IF_ERROR(AddGreaterEqualWithCastNode(
&ctx, greater_equal_with_cast, &invalidated_nodes, &nodes_to_delete));
continue;
}
}

// Remove invalidated nodes.
Expand Down
58 changes: 58 additions & 0 deletions tensorflow/core/grappler/optimizers/remapper_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -851,5 +851,63 @@ TEST_F(RemapperTest, FuseConv2DWithSqueezeAndBias) {
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}

// Verifies that the remapper rewrites GreaterEqual+Cast into a single
// _GreaterEqualWithCast node and that the optimized graph is numerically
// equivalent to the original.
TEST_F(RemapperTest, FuseGreaterEqualWithCast) {
  using ::tensorflow::ops::Placeholder;

  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  const int num_channels = 24;

  auto x = Placeholder(s.WithOpName("x"), DT_FLOAT,
                       ops::Placeholder::Shape({2, 8, 8, num_channels}));
  auto y = Placeholder(s.WithOpName("y"), DT_FLOAT,
                       ops::Placeholder::Shape({2, 8, 8, num_channels}));

  // GreaterEqual emits DT_BOOL; the Cast back to DT_FLOAT is exactly the
  // pattern the remapper fuses.
  auto ge = ops::GreaterEqual(s.WithOpName("greater_equal"), x, y);
  auto cast = ops::Cast(s.WithOpName("cast"), ge.z, DT_FLOAT);
  auto fetch = ops::Identity(s.WithOpName("fetch"), cast);

  auto input1_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
  auto input2_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});

  GrapplerItem item;
  item.fetch = {"fetch"};
  item.feed = {{"x", input1_t}, {"y", input2_t}};
  TF_ASSERT_OK(s.ToGraphDef(&item.graph));

  // Place all nodes on CPU (the fused kernel is registered for CPU).
  // Fixed: the original comment claimed "GPU" while assigning CPU devices.
  for (int i = 0; i < item.graph.node_size(); ++i) {
    item.graph.mutable_node(i)->set_device("/device:CPU:0");
  }

  Remapper optimizer(RewriterConfig::AGGRESSIVE);  // trust placeholders shape
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));

  // Exactly one node -- the former Cast -- must have been rewritten to the
  // fused op, reading the GreaterEqual operands directly.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "cast") {
      EXPECT_EQ(node.op(), "_GreaterEqualWithCast");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "y");
      found++;
    }
  }
  EXPECT_EQ(found, 1);

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);
  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
}

} // namespace grappler
} // namespace tensorflow
2 changes: 2 additions & 0 deletions tensorflow/core/kernels/cwise_op_greater_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ limitations under the License.
namespace tensorflow {
REGISTER9(BinaryOp, CPU, "GreaterEqual", functor::greater_equal, float,
          Eigen::half, double, int32, int64, uint8, int8, int16, bfloat16);
// Fused GreaterEqual+Cast kernel, created only by the grappler remapper: it
// emits the comparison result directly in the input type T instead of
// DT_BOOL, folding the subsequent Cast into the comparison.
REGISTER2(BinaryOp, CPU, "_GreaterEqualWithCast",
          functor::greater_equal_with_cast, float, bfloat16);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER7(BinaryOp, GPU, "GreaterEqual", functor::greater_equal, float,
Eigen::half, double, int64, uint8, int8, int16);
Expand Down
7 changes: 6 additions & 1 deletion tensorflow/core/kernels/cwise_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ limitations under the License.
#include <functional>
#include <type_traits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace Eigen {
namespace internal {
Expand Down Expand Up @@ -1141,6 +1141,11 @@ struct greater : base<T, Eigen::internal::greater<T>, bool> {};
template <typename T>
struct greater_equal : base<T, Eigen::internal::greater_equal<T>, bool> {};

// Functor for the fused _GreaterEqualWithCast op: performs the >= comparison
// and converts the boolean result back to T in a single pass (via
// scalar_cmp_with_cast_op, presumably added alongside this change -- not
// visible here). Unlike greater_equal above, the result type is T, not bool.
template <typename T>
struct greater_equal_with_cast
    : base<T, Eigen::internal::scalar_cmp_with_cast_op<
                  T, T, Eigen::internal::cmp_GE>> {};

template <typename T>
struct equal_to : base<T, Eigen::internal::equal_to<T>, bool> {};

Expand Down
7 changes: 7 additions & 0 deletions tensorflow/core/ops/math_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,13 @@ REGISTER_OP("GreaterEqual").COMPARISON();

#undef COMPARISON

// Internal op produced by the grappler remapper when it fuses a GreaterEqual
// with a Cast back to the comparison input type. Note the output is z: T
// (not bool), since the cast is folded in. The leading underscore marks it
// as not intended for direct use in user graphs.
REGISTER_OP("_GreaterEqualWithCast")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {bfloat16, float}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// --------------------------------------------------------------------------

#define EQUALITY_COMPARISON() \
Expand Down