Adds support for FusedBatchNormV3 in converter
PiperOrigin-RevId: 262588769
Sachin Joglekar authored and tensorflower-gardener committed Aug 9, 2019
1 parent 0cdf225 commit 7f70878
Showing 3 changed files with 101 additions and 11 deletions.
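For context, the rewrite pattern this commit extends decomposes an inference-mode (`is_training = false`) batch norm into primitive arithmetic by folding the statistics into a single per-channel multiplier. A minimal NumPy sketch of the identity the rewrite relies on (toy shapes and random data, illustrative only, not converter code):

```python
import numpy as np

# Toy NHWC input and per-channel parameters (shapes are illustrative).
x = np.random.randn(2, 4, 4, 8).astype(np.float32)
scale, offset, mean, variance = (np.random.rand(8).astype(np.float32)
                                 for _ in range(4))
epsilon = 1e-3

# Reference formula: (x - mean) * scale / sqrt(variance + epsilon) + offset
reference = (x - mean) * scale / np.sqrt(variance + epsilon) + offset

# Refactored form emitted by the converter:
#   multiplier = scale * rsqrt(variance + epsilon)
#   y = x * multiplier + (offset - mean * multiplier)
multiplier = scale / np.sqrt(variance + epsilon)
rewritten = x * multiplier + (offset - mean * multiplier)

assert np.allclose(reference, rewritten, atol=1e-5)
```

Because `scale`, `mean`, and `variance` are typically constants in a frozen graph, the multiplier and bias term can be constant-folded downstream, leaving one multiply and one add per element.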
53 changes: 43 additions & 10 deletions tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -63,25 +63,58 @@ func @fusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>)
  return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32>

// CHECK-LABEL: fusedBatchNorm
-// CHECK:%cst = constant dense<1.000000e-03> : tensor<f32>
+// CHECK: %[[CONSTANT:.*]] = constant dense<1.000000e-03>
// variance + epsilon
-// CHECK: %0 = "tf.Add"(%arg4, %cst) : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
+// CHECK: %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]])
// rsqrt(variance + epsilon)
-// CHECK: %1 = "tf.Rsqrt"(%0) : (tensor<8xf32>) -> tensor<8xf32>
+// CHECK: %[[RSQRT:.*]] = "tf.Rsqrt"(%[[ADD1]])
// scale * rsqrt(variance + epsilon)
-// CHECK: %2 = "tf.Mul"(%arg1, %1) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
+// CHECK: %[[MUL1:.*]] = "tf.Mul"(%[[ARG1:.*]], %[[RSQRT]])
// x * scale * rsqrt(variance + epsilon)
-// CHECK: %3 = "tf.Mul"(%arg0, %2) : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
+// CHECK: %[[MUL2:.*]] = "tf.Mul"(%[[ARG0:.*]], %[[MUL1]])
// mean * scale * rsqrt(variance + epsilon)
-// CHECK: %4 = "tf.Mul"(%arg3, %2) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
+// CHECK: %[[MUL3:.*]] = "tf.Mul"(%[[ARG3:.*]], %[[MUL1]])
// offset - mean * scale * rsqrt(variance + epsilon)
-// CHECK: %5 = "tf.Sub"(%arg2, %4) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
+// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[ARG2:.*]], %[[MUL3]])
// x * scale * rsqrt(variance + epsilon) +
// offset - mean * scale * rsqrt(variance + epsilon)
-// CHECK: %6 = "tf.Add"(%3, %5) : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
+// CHECK: %[[ADD2:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]])

-// CHECK: %7:5 = "tf.FusedBatchNorm"(%6, %arg1, %arg2, %arg3, %arg4)
-// CHECK: %8:5 = "tf.FusedBatchNorm"(%7#0, %arg1, %arg2, %arg3, %arg4)
+// CHECK: %[[BATCHNORM1:.*]]:5 = "tf.FusedBatchNorm"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
+// CHECK: {{.*}} = "tf.FusedBatchNorm"(%[[BATCHNORM1]]#0, %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
}

+func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) {
+^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>):
+  // OK
+  %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+  // Unsupported training
+  %1:6 = "tf.FusedBatchNormV3"(%0#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+  // Use other output
+  %2:6 = "tf.FusedBatchNormV3"(%1#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+
+  return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32>
+
+// CHECK-LABEL: fusedBatchNormV3
+// CHECK: %[[CONSTANT:.*]] = constant dense<1.000000e-03>
+// variance + epsilon
+// CHECK: %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]])
+// rsqrt(variance + epsilon)
+// CHECK: %[[RSQRT:.*]] = "tf.Rsqrt"(%[[ADD1]])
+// scale * rsqrt(variance + epsilon)
+// CHECK: %[[MUL1:.*]] = "tf.Mul"(%[[ARG1:.*]], %[[RSQRT]])
+// x * scale * rsqrt(variance + epsilon)
+// CHECK: %[[MUL2:.*]] = "tf.Mul"(%[[ARG0:.*]], %[[MUL1]])
+// mean * scale * rsqrt(variance + epsilon)
+// CHECK: %[[MUL3:.*]] = "tf.Mul"(%[[ARG3:.*]], %[[MUL1]])
+// offset - mean * scale * rsqrt(variance + epsilon)
+// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[ARG2:.*]], %[[MUL3]])
+// x * scale * rsqrt(variance + epsilon) +
+// offset - mean * scale * rsqrt(variance + epsilon)
+// CHECK: %[[ADD2:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]])
+
+// CHECK: %[[BATCHNORM1:.*]]:6 = "tf.FusedBatchNormV3"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
+// CHECK: %[[BATCHNORM2:.*]]:6 = "tf.FusedBatchNormV3"(%[[BATCHNORM1]]#0, %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
+}

// CHECK-LABEL: fakeQuantForActivation
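Mapping the FileCheck captures in both tests to the math (γ = scale, β = offset, μ = mean, σ² = variance): `[[ADD1]]` computes σ² + ε, `[[RSQRT]]` its reciprocal square root, `[[MUL1]]` the shared multiplier m, `[[MUL2]]` computes x · m, `[[MUL3]]` computes μ · m, `[[SUB]]` computes β − μ · m, and `[[ADD2]]` the final result. Written out:

```latex
y \;=\; \frac{(x - \mu)\,\gamma}{\sqrt{\sigma^{2} + \epsilon}} + \beta
  \;=\; x \cdot m + \left(\beta - \mu \cdot m\right),
\qquad m \;=\; \gamma \cdot \operatorname{rsqrt}\left(\sigma^{2} + \epsilon\right)
```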
26 changes: 25 additions & 1 deletion tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td
@@ -21,7 +21,7 @@ def FalseBoolAttr : AttrConstraint<CPred<"!$_self.getValue()">>;
def HasNoUse: Constraint<
    CPred<"$0->use_begin() == $0->use_end()">, "has no use">;

-// Converts tf.FusedBatchNorm into a sequence of more primitive arithmetic
+// Converts tf.FusedBatchNorm & tf.FusedBatchNormV3 into a sequence of more primitive arithmetic
// operations. Specifically, performs the following calculation:
//
// (x - mean) * scale / sqrt(variance + epsilon) + offset
@@ -53,6 +53,30 @@ def : Pattern<
  [(HasNoUse $root__1), (HasNoUse $root__2),
   (HasNoUse $root__3), (HasNoUse $root__4)]>;

+def : Pattern<
+    (TF_FusedBatchNormV3Op:$root
+        $x, $scale, $offset, $mean, $variance,
+        F32Attr:$epsilon, $data_format, FalseBoolAttr:$is_training),
+    [(TF_AddOp
+        (TF_MulOp
+            $x,
+            (TF_MulOp:$multiplier
+                $scale,
+                (TF_RsqrtOp
+                    (TF_AddOp $variance,
+                              (TF_ConstOp $epsilon))))),
+        (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))),
+     // We already guaranteed that the last five results have no use so it does
+     // not matter what value we provide here for replacement.
+     /*batch_mean=*/(replaceWithValue $x),
+     /*batch_variance=*/(replaceWithValue $x),
+     /*reserve_space_1=*/(replaceWithValue $x),
+     /*reserve_space_2=*/(replaceWithValue $x),
+     /*reserve_space_3=*/(replaceWithValue $x)],
+    [(HasNoUse $root__1), (HasNoUse $root__2),
+     (HasNoUse $root__3), (HasNoUse $root__4),
+     (HasNoUse $root__5)]>;

// TODO(jpienaar): Move to opbase something more general.
def TFi32ElementsAttr : Attr<CPred<"$_self.isa<DenseIntElementsAttr>()">,
    "scalar int attribute"> {
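The five `HasNoUse` constraints are what make the dummy `replaceWithValue` results safe: the pattern only fires when nothing consumes `batch_mean`, `batch_variance`, or the reserve spaces, which is the normal state of a frozen inference graph. A quick way to see that precondition hold in practice (a sketch assuming a TF 2.x install; shapes and values are illustrative):

```python
import numpy as np
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():  # build symbolically so consumers can be inspected
    x = tf.constant(np.random.randn(1, 4, 4, 8).astype(np.float32))
    scale, offset, mean, variance = (tf.constant(np.ones(8, np.float32))
                                     for _ in range(4))
    outs = tf.raw_ops.FusedBatchNormV3(
        x=x, scale=scale, offset=offset, mean=mean, variance=variance,
        epsilon=1e-3, is_training=False)
    y = tf.identity(outs.y)  # a typical inference graph consumes only y

bn_op = outs.y.op
# Results 1..5 (batch_mean, batch_variance, reserve_space_1..3) are dead,
# which is exactly what the HasNoUse constraints check before rewriting.
assert all(not t.consumers() for t in bn_op.outputs[1:])
```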
33 changes: 33 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -1045,6 +1045,39 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
  }];
}

+def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect]> {
+  let summary = "Batch normalization.";
+
+  let description = [{
+Note that the size of 4D Tensors is defined by either "NHWC" or "NCHW".
+The size of 1D Tensors matches the dimension C of the 4D Tensors.
+  }];
+
+  let arguments = (ins
+    TensorOf<[BF16, F16, F32]>:$x,
+    F32Tensor:$scale,
+    F32Tensor:$offset,
+    F32Tensor:$mean,
+    F32Tensor:$variance,
+
+    DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon,
+    DefaultValuedAttr<TF_ConvnetDataFormatAttr, "NHWC">:$data_format,
+    DefaultValuedAttr<BoolAttr, "true">:$is_training
+  );
+
+  let results = (outs
+    TensorOf<[BF16, F16, F32]>:$y,
+    F32Tensor:$batch_mean,
+    F32Tensor:$batch_variance,
+    F32Tensor:$reserve_space_1,
+    F32Tensor:$reserve_space_2,
+    F32Tensor:$reserve_space_3
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
+}

def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> {
  let summary = "Gather slices from `params` according to `indices`.";

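As a sanity check on the definition above, the op can be exercised through `tf.raw_ops` (a hedged sketch, again assuming TF 2.x; shapes are illustrative):

```python
import numpy as np
import tensorflow as tf

x = tf.constant(np.random.randn(1, 4, 4, 8).astype(np.float32))
scale, offset, mean, variance = (tf.constant(np.ones(8, np.float32))
                                 for _ in range(4))

outs = tf.raw_ops.FusedBatchNormV3(
    x=x, scale=scale, offset=offset, mean=mean, variance=variance,
    epsilon=1e-3, data_format="NHWC", is_training=False)

# Six results, mirroring the `results` list in the TableGen definition;
# reserve_space_3 is what distinguishes V3 from FusedBatchNorm.
print([t.shape for t in outs])  # y keeps x's shape; the rest are 1-D
print(outs.y.dtype)             # f32 here; BF16/F16 are also allowed for x
```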