[microNPU][ETHOSU] Fix SoftMax legalization parameters (apache#15069)
* [microNPU][ETHOSU] Fix Softmax activation parameters

Fix the activation parameters of the legalized operations to match the values used in Vela.

* fix legalization parameters

* Update test_legalize.py

* Update test_legalize.py
Aleksei-grovety authored Jun 26, 2023
1 parent 904515b commit 28aead9
Showing 2 changed files with 122 additions and 67 deletions.
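
In brief, the rewriter change replaces hard-coded clip bounds with named int8 quantization constants and adds an explicit CLIP activation to the intermediate passes. Below is a minimal, self-contained sketch of that pattern; the `clip_params` helper is hypothetical and only illustrates the keyword set, while the real calls are the `ethosu_binary_elementwise` / `ethosu_pooling` / `ethosu_unary_elementwise` operators in the diff that follows.

```python
# Sketch of the pattern introduced by this commit (illustrative only).

QUANT_MIN = -128  # int8 quantization minimum, now named instead of hard-coded
QUANT_MAX = 127   # int8 quantization maximum


def clip_params(quant_min: int = QUANT_MIN, quant_max: int = QUANT_MAX) -> dict:
    """Activation keywords now passed to the intermediate softmax passes."""
    return {"activation": "CLIP", "clip_min": quant_min, "clip_max": quant_max}


# Before the fix, several passes used literal -128/127 (or the doubled
# -128 * 2 / 127 * 2 variants) and some passes had no activation at all.
# After the fix the same named bounds are used everywhere except PASS 1,
# which keeps its LUT activation with clip_min=-255, clip_max=0 (see diff).
print(clip_params())  # {'activation': 'CLIP', 'clip_min': -128, 'clip_max': 127}
```
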
75 changes: 50 additions & 25 deletions python/tvm/relay/backend/contrib/ethosu/softmax_rewriter.py
@@ -82,6 +82,8 @@ def callback(
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
params = self.params_class(post.op.body)
quant_min = -128
quant_max = 127

ifm = post.args[0]
ifm_dtype = ifm.checked_type.dtype
@@ -121,12 +123,14 @@ def callback(
ifm2_scale=0.0,
ifm2_zero_point=int(params.ifm.q_params.zero_point),
ofm_scale=1.0,
ofm_zero_point=127,
ofm_zero_point=quant_max,
ifm_channels=depth,
ifm2_channels=1,
reversed_operands=False,
ofm_dtype="int32",
activation="LUT",
clip_min=-255,
clip_max=0,
)

# PASS 2 - SHR
@@ -147,8 +151,8 @@ def callback(
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
clip_min=-128,
clip_max=127,
clip_min=quant_min,
clip_max=quant_max,
rounding_mode="NATURAL",
)

@@ -165,6 +169,9 @@ def callback(
ofm_channels=1,
upscale="NONE",
ofm_dtype="int32",
activation="CLIP",
clip_min=quant_min,
clip_max=quant_max,
)

# PASS 4 - CLZ
@@ -177,6 +184,9 @@ def callback(
ofm_scale=0.0,
ofm_zero_point=int(params.ifm.q_params.zero_point),
ofm_channels=1,
activation="CLIP",
clip_min=quant_min,
clip_max=quant_max,
)

# PASS 5 - Sub
@@ -196,6 +206,9 @@ def callback(
ifm2_channels=1,
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
clip_min=quant_min,
clip_max=quant_max,
)

# PASS 6 - Sub
@@ -215,6 +228,9 @@ def callback(
ifm2_channels=1,
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
clip_min=quant_min,
clip_max=quant_max,
)

# PASS 7 - SHL
@@ -229,13 +245,13 @@ def callback(
ifm2_zero_point=0,
ofm_scale=0.0,
ofm_zero_point=int(params.ifm.q_params.zero_point),
ifm_channels=depth,
ifm_channels=1,
ifm2_channels=1,
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
clip_min=-128,
clip_max=127,
clip_min=quant_min,
clip_max=quant_max,
)

# PASS 8 - Sub
@@ -255,6 +271,9 @@ def callback(
ifm2_channels=1,
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
clip_min=quant_min,
clip_max=quant_max,
)

# PASS 9 - SHL
@@ -274,8 +293,8 @@ def callback(
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
clip_min=-128,
clip_max=127,
clip_min=quant_min,
clip_max=quant_max,
)

# PASS 10 - Add
@@ -296,8 +315,8 @@ def callback(
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
clip_min=-128,
clip_max=127,
clip_min=quant_min,
clip_max=quant_max,
use_rescale=True,
rescale_scale=1,
rescale_shift=1,
@@ -316,13 +335,13 @@ def callback(
ifm2_zero_point=0,
ofm_scale=2.0,
ofm_zero_point=0,
ifm_channels=depth,
ifm_channels=1,
ifm2_channels=1,
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
clip_min=-128 * 2,
clip_max=127 * 2,
clip_min=quant_min,
clip_max=quant_max,
)

# PASS 12 - Add
@@ -343,8 +362,8 @@ def callback(
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
clip_min=-128,
clip_max=127,
clip_min=quant_min,
clip_max=quant_max,
)

nr_x = rescale_w_offset
@@ -368,8 +387,8 @@ def callback(
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
clip_min=-128 * 2,
clip_max=127 * 2,
clip_min=quant_min,
clip_max=quant_max,
)

# PASS 14, 19, 24 - Sub
@@ -388,6 +407,9 @@ def callback(
ifm2_channels=1,
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
clip_min=quant_min,
clip_max=quant_max,
)

# PASS 15, 20, 25 - Mul
@@ -407,8 +429,8 @@ def callback(
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
clip_min=-128 * 2,
clip_max=127 * 2,
clip_min=quant_min,
clip_max=quant_max,
)

# PASS 16, 21, 26 - Mul
@@ -428,8 +450,8 @@ def callback(
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
clip_min=-128,
clip_max=127,
clip_min=quant_min,
clip_max=quant_max,
)

# PASS 17, 22, 27 - Add
@@ -448,6 +470,9 @@ def callback(
ifm2_channels=1,
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
clip_min=quant_min,
clip_max=quant_max,
)

# PASS 28 - Mul
@@ -468,8 +493,8 @@ def callback(
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
clip_min=-128,
clip_max=127,
clip_min=quant_min,
clip_max=quant_max,
)

# PASS 29 - Mul
@@ -489,8 +514,8 @@ def callback(
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
clip_min=-128 * 2,
clip_max=127 * 2,
clip_min=quant_min,
clip_max=quant_max,
)

# PASS 30 - SHR
114 changes: 72 additions & 42 deletions tests/python/contrib/test_ethosu/test_legalize.py
@@ -3526,9 +3526,10 @@ def representative_dataset():
assert tuple(func_body.args[1].checked_type.shape) == (256,)


@pytest.mark.parametrize("ifm_shape", [(1, 12), (1, 12, 32)])
def test_tflite_softmax(ifm_shape):
def test_tflite_softmax():
np.random.seed(0)
dtype = "int8"
ifm_shape = (1, 12)

def create_tflite_graph():
@tf.function
@@ -3539,7 +3540,7 @@ def softmax(x):
# Convert the model
def representative_dataset():
for _ in range(100):
data = np.random.rand(*tuple(ifm_shape))
data = np.random.uniform(low=-1, high=2, size=tuple(ifm_shape))
yield [data.astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
@@ -3554,51 +3555,71 @@ def representative_dataset():
def verify(ext_func):
out_op = ext_func.body
ops = []
# List of expected operations and their type if it exists
expected_ops = [
("reshape", None),
("reshape", None),
("contrib.ethosu.pooling", "MAX"),
("contrib.ethosu.binary_elementwise", "SUB"),
("contrib.ethosu.binary_elementwise", "SHR"),
("contrib.ethosu.pooling", "SUM"),
("contrib.ethosu.unary_elementwise", "CLZ"),
("contrib.ethosu.binary_elementwise", "SUB"),
("contrib.ethosu.binary_elementwise", "SHL"),
("contrib.ethosu.binary_elementwise", "SUB"),
("contrib.ethosu.binary_elementwise", "SHL"),
("contrib.ethosu.binary_elementwise", "ADD"),
("contrib.ethosu.binary_elementwise", "MUL"),
("contrib.ethosu.binary_elementwise", "ADD"),
("contrib.ethosu.binary_elementwise", "MUL"),
("contrib.ethosu.binary_elementwise", "SUB"),
("contrib.ethosu.binary_elementwise", "MUL"),
("contrib.ethosu.binary_elementwise", "MUL"),
("contrib.ethosu.binary_elementwise", "ADD"),
("contrib.ethosu.binary_elementwise", "MUL"),
("contrib.ethosu.binary_elementwise", "SUB"),
("contrib.ethosu.binary_elementwise", "MUL"),
("contrib.ethosu.binary_elementwise", "MUL"),
("contrib.ethosu.binary_elementwise", "ADD"),
("contrib.ethosu.binary_elementwise", "MUL"),
("contrib.ethosu.binary_elementwise", "SUB"),
("contrib.ethosu.binary_elementwise", "MUL"),
("contrib.ethosu.binary_elementwise", "MUL"),
("contrib.ethosu.binary_elementwise", "ADD"),
("contrib.ethosu.binary_elementwise", "MUL"),
("contrib.ethosu.binary_elementwise", "MUL"),
("contrib.ethosu.binary_elementwise", "SUB"),
("contrib.ethosu.binary_elementwise", "SHR"),
("reshape", None),
# List of expected operations, their type and activation parameters if it exists
expected_ops_params = [
("reshape", None, [None, None, None, None, None, None]),
("reshape", None, [None, None, None, None, None, None]),
("contrib.ethosu.pooling", "MAX", [0.011756093241274357, -43, None, None, 0.0, -43]),
(
"contrib.ethosu.binary_elementwise",
"SUB",
[0.011756093241274357, -43, 0.0, -43, 1.0, 127],
),
("contrib.ethosu.binary_elementwise", "SHR", [1.0, 0, 0.0, 0, 0.0, -43]),
("contrib.ethosu.pooling", "SUM", [0.0, 0, None, None, 0.0, -43]),
("contrib.ethosu.unary_elementwise", "CLZ", [0.0, 0, None, None, 0.0, -43]),
("contrib.ethosu.binary_elementwise", "SUB", [0.0, 0, 0.0, 0, 0.0, -43]),
("contrib.ethosu.binary_elementwise", "SHL", [0.0, 0, 0.0, 0, 0.0, -43]),
("contrib.ethosu.binary_elementwise", "SUB", [0.0, 0, 0.0, 0, 0.0, -43]),
("contrib.ethosu.binary_elementwise", "SHL", [0.0, 0, 0.0, 0, 0.0, -43]),
("contrib.ethosu.binary_elementwise", "ADD", [0.0, 0, 0.0, 0, 1.0, 0]),
("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]),
("contrib.ethosu.binary_elementwise", "ADD", [2.0, 0, 0.0, 0, 1.0, 0]),
("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]),
("contrib.ethosu.binary_elementwise", "SUB", [2.0, 0, 0.0, 0, 1.0, 0]),
("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]),
("contrib.ethosu.binary_elementwise", "MUL", [2.0, 0, 0.0, 0, 0.0, -43]),
("contrib.ethosu.binary_elementwise", "ADD", [1.0, 0, 0.0, 0, 1.0, 0]),
("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]),
("contrib.ethosu.binary_elementwise", "SUB", [2.0, 0, 0.0, 0, 1.0, 0]),
("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]),
("contrib.ethosu.binary_elementwise", "MUL", [2.0, 0, 0.0, 0, 0.0, -43]),
("contrib.ethosu.binary_elementwise", "ADD", [1.0, 0, 0.0, 0, 1.0, 0]),
("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]),
("contrib.ethosu.binary_elementwise", "SUB", [2.0, 0, 0.0, 0, 1.0, 0]),
("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]),
("contrib.ethosu.binary_elementwise", "MUL", [2.0, 0, 0.0, 0, 0.0, -43]),
("contrib.ethosu.binary_elementwise", "ADD", [1.0, 0, 0.0, 0, 1.0, 0]),
("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 0.0, 0, 1.0, 0]),
("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]),
("contrib.ethosu.binary_elementwise", "SUB", [0.0, 0, 0.0, 0, 0.0, -43]),
("contrib.ethosu.binary_elementwise", "SHR", [2.0, 0, 0.0, 0, 0.00390625, -128]),
("reshape", None, [None, None, None, None, None, None]),
]

def get_attr_value(op, attr_name):
if hasattr(op.attrs, attr_name):
return op.attrs[attr_name]
else:
return None

def get_op_type(op):
if hasattr(op.attrs, "pooling_type"):
return op.attrs.pooling_type
elif hasattr(op.attrs, "operator_type"):
return op.attrs.operator_type
return None

def get_activation_params(op):
activation_params = []
activation_params.append(get_attr_value(op, "ifm_scale"))
activation_params.append(get_attr_value(op, "ifm_zero_point"))
activation_params.append(get_attr_value(op, "ifm2_scale"))
activation_params.append(get_attr_value(op, "ifm2_zero_point"))
activation_params.append(get_attr_value(op, "ofm_scale"))
activation_params.append(get_attr_value(op, "ofm_zero_point"))
return activation_params

def _visit(stmt):
if isinstance(stmt, relay.expr.Call):
ops.append(stmt)
@@ -3616,9 +3637,18 @@ def _visit(stmt):
assert ofm.dtype == dtype

# check operations

ops = [(op.op.name, get_op_type(op)) for op in ops]
assert expected_ops == ops
for op, expected_op_params in zip(ops, expected_ops_params):
activation_params = get_activation_params(op)
expected_op_name, expected_op_type, expected_activation_params = expected_op_params
assert op.op.name == expected_op_name
assert expected_op_type == get_op_type(op)
for activation_param, expected_activation_param in zip(
activation_params, expected_activation_params
):
if isinstance(activation_param, float):
assert math.isclose(expected_activation_param, activation_param, abs_tol=1e-7)
else:
assert expected_activation_param == activation_param

softmax_pattern_table = [
(
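
For reference, the comparison idiom the updated test applies to each operator's quantization parameters, distilled into a standalone sketch. The six-element layout and the first expected entry come from the test above; the "actual" values are invented for illustration and are plain Python values rather than TVM attribute objects.

```python
import math

# Per-operator layout used by the test:
# [ifm_scale, ifm_zero_point, ifm2_scale, ifm2_zero_point, ofm_scale, ofm_zero_point]
expected = [0.011756093241274357, -43, None, None, 0.0, -43]
actual = [0.0117560932, -43, None, None, 0.0, -43]  # e.g. values read back from op.attrs

for got, want in zip(actual, expected):
    if isinstance(got, float):
        # Quantization scales are floats, so compare with an absolute tolerance
        # instead of strict equality.
        assert math.isclose(want, got, abs_tol=1e-7)
    else:
        # Zero points, and None for attributes an operator does not carry,
        # are compared exactly.
        assert want == got

print("all parameters match")
```
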
