Arm backend batchnorm2d fixes and tests (#3369)
Summary:
Batchnorm2d did not handle weight and bias

Pull Request resolved: #3369

Reviewed By: cccclai

Differential Revision: D56820624

Pulled By: digantdesai

fbshipit-source-id: 17f53f9ea7fcb2bb2d398d0625cea278bc144112
zingo authored and facebook-github-bot committed May 7, 2024
1 parent 69d2a84 commit a82db7c
Showing 3 changed files with 821 additions and 40 deletions.
132 changes: 116 additions & 16 deletions backends/arm/operators/op_batch_norm.py
@@ -42,60 +42,160 @@ def define_node(
permute_memory_to_nhwc: bool,
) -> None:
# Decompose batch norm into sequence
(activations, weights, bias, running_mean, running_var, momentum, epsilon) = (
inputs
)

input_dtype = activations.dtype

assert (
0.1 == momentum.number
), "Expected 0.1 momentum, not currently encoded into TOSA"

# %output = (%x - %E[x]) / SQRT( %Var[x] + %epsilon ) * %gamma + %beta
# e.g.
# %output = (%activations - %running_mean) / SQRT( %running_var + %epsilon_const ) * %weights + %bias
# ->
# %op1 = tosa.SUB(%activations, %running_mean)
# %op2 = tosa.ADD(%running_var, %epsilon_const)
# %op3 = tosa.RSQRT(%op2)
# %op4 = tosa.MUL(%op1, %op3)
# %op5 = tosa.MUL(%op4, %weights)
# %output = tosa.ADD(%op5, %bias)

# Reshape mean to match rank of activations
mean_reshaped = promote_shape(
tosa_graph,
running_mean,
self.augment_shape_rank(running_mean, permute_memory_to_nhwc),
input_dtype,
)
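# promote_shape / augment_shape_rank expand the 1-D per-channel tensor to the
# rank of the activations; channel placement depends on permute_memory_to_nhwc.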

# %op1 = tosa.SUB(%activations, %running_mean)
op1 = tosa_graph.addIntermediate(output.shape, input_dtype)
tosa_graph.addOperator(
TosaOp.Op().SUB,
[activations.name, mean_reshaped.name],
[op1.name],
)
# %op2 = tosa.ADD(%running_var, %epsilon_const)
epsilon_const = tosa_graph.addConst([1], input_dtype, [epsilon.number])
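# epsilon is emitted as a single-element, rank-1 constant so it broadcasts over running_var.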
op2 = tosa_graph.addIntermediate(running_var.shape, input_dtype)
tosa_graph.addOperator(
TosaOp.Op().ADD,
[running_var.name, epsilon_const.name],
[op2.name],
)
# %op3 = tosa.RSQRT(%op2)
op3 = tosa_graph.addIntermediate(running_var.shape, input_dtype)
tosa_graph.addOperator(TosaOp.Op().RSQRT, [op2.name], [op3.name])
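# RSQRT yields 1 / sqrt(%op2), i.e. the reciprocal of the standard deviation.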

# Reshape variance to match rank of activations
op3_reshaped = promote_shape(
tosa_graph,
op3,
self.augment_shape_rank(running_var, permute_memory_to_nhwc),
input_dtype,
)

# Handle missing weights and bias
if not weights.name and not bias.name:
# Multiply shifted activations with reciprocal variance
# %output = tosa.MUL(%op1, %op3), i.e. %output = (%activations - %running_mean) / SQRT( %running_var + %epsilon_const )
attr_mul = ts.TosaSerializerAttribute()
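# TOSA MUL carries a shift attribute (only meaningful for integer data); 0 means no shift.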
attr_mul.MulAttribute(0)
tosa_graph.addOperator(
TosaOp.Op().MUL, [op1.name, op3_reshaped.name], [output.name], attr_mul
)
return
else:
# Multiply shifted activations with reciprocal variance
# %op4 = tosa.MUL(%op1, %op3)
op4 = tosa_graph.addIntermediate(output.shape, input_dtype)
attr_mul = ts.TosaSerializerAttribute()
attr_mul.MulAttribute(0)
tosa_graph.addOperator(
TosaOp.Op().MUL, [op1.name, op3_reshaped.name], [op4.name], attr_mul
)

# Now we have %op4 = (%activations - %running_mean) / SQRT( %running_var + %epsilon_const )

if weights.name and not bias.name:
# Handle only weights but no bias

# Reshape weights to match rank of activations
weights_reshaped = promote_shape(
tosa_graph,
weights,
self.augment_shape_rank(weights, permute_memory_to_nhwc),
input_dtype,
)

# %output = tosa.MUL(%op4, %weights)
attr_mul = ts.TosaSerializerAttribute()
attr_mul.MulAttribute(0)
tosa_graph.addOperator(
TosaOp.Op().MUL,
[op4.name, weights_reshaped.name],
[output.name],
attr_mul,
)
return

if not weights.name and bias.name:
# Handle only bias but no weights

# Reshape bias to match rank of activations
bias_reshaped = promote_shape(
tosa_graph,
bias,
self.augment_shape_rank(bias, permute_memory_to_nhwc),
input_dtype,
)

# %output = tosa.ADD(%op4, %bias)
tosa_graph.addOperator(
TosaOp.Op().ADD,
[op4.name, bias_reshaped.name],
[output.name],
)
return

# We have both weights and bias

# Reshape weights to match rank of activations
weights_reshaped = promote_shape(
tosa_graph,
weights,
self.augment_shape_rank(weights, permute_memory_to_nhwc),
input_dtype,
)

# %op5 = tosa.MUL(%op4, %weights)
op5 = tosa_graph.addIntermediate(output.shape, input_dtype)
attr_mul = ts.TosaSerializerAttribute()
attr_mul.MulAttribute(0)
tosa_graph.addOperator(
TosaOp.Op().MUL,
[op4.name, weights_reshaped.name],
[op5.name],
attr_mul,
)

# Reshape bias to match rank of activations
bias_reshaped = promote_shape(
tosa_graph,
bias,
self.augment_shape_rank(bias, permute_memory_to_nhwc),
input_dtype,
)

# %output = tosa.ADD(%op5, %bias)
tosa_graph.addOperator(
TosaOp.Op().ADD,
[op5.name, bias_reshaped.name],
[output.name],
)
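
For reference, here is a minimal NumPy sketch (not part of this commit) of the same decomposition the lowering emits, checked against the closed-form batch-norm formula; the function name, shapes, and random data are illustrative only and assume NCHW activations with per-channel stats.

import numpy as np

def batch_norm_decomposed(x, mean, var, weight=None, bias=None, eps=1e-5):
    # Per-channel (C,) tensors are reshaped to (1, C, 1, 1), mirroring promote_shape.
    shape = (1, x.shape[1], 1, 1)
    op1 = x - mean.reshape(shape)            # tosa.SUB
    op2 = var + eps                          # tosa.ADD (epsilon_const)
    op3 = 1.0 / np.sqrt(op2)                 # tosa.RSQRT
    out = op1 * op3.reshape(shape)           # tosa.MUL
    if weight is not None:
        out = out * weight.reshape(shape)    # tosa.MUL (weights)
    if bias is not None:
        out = out + bias.reshape(shape)      # tosa.ADD (bias)
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 4)).astype(np.float32)
mean = rng.standard_normal(3).astype(np.float32)
var = rng.random(3).astype(np.float32) + 0.5
w = rng.standard_normal(3).astype(np.float32)
b = rng.standard_normal(3).astype(np.float32)

shape = (1, 3, 1, 1)
expected = (x - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + 1e-5) * w.reshape(shape) + b.reshape(shape)
assert np.allclose(batch_norm_decomposed(x, mean, var, w, b), expected, atol=1e-5)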