[microNPU] Fix bug with re-reading in EncodeConstants (apache#9646)
When a striping strategy that leads to weights
being re-read was used, the logic in EncodeConstants
failed. This adds a test for that case and fixes the
pass so that it handles re-reading correctly.

Change-Id: I6f54cdb7be69428e49c3b4208271cd3e6c192e5d
mbaret authored and ylc committed Jan 13, 2022
1 parent fdbd3d9 commit 1ac86ac
Showing 2 changed files with 73 additions and 3 deletions.
python/tvm/relay/backend/contrib/ethosu/tir/passes.py (10 changes: 7 additions & 3 deletions)
@@ -331,11 +331,15 @@ def _encode_weights(tir_extern_call, weights):
     def _new_buffer(old_buffer, new_value):
         """Create a new buffer and add the old buffer and its pointer to the
         rewriting maps."""
-        new_buffer = tvm.tir.decl_buffer((len(new_value),), str(new_value.dtype))
-        pointer_to_buffer[new_buffer.data] = new_buffer
+        if old_buffer in rewrite_buffer:
+            new_buffer = rewrite_buffer[old_buffer]
+        else:
+            new_buffer = tvm.tir.decl_buffer((len(new_value),), str(new_value.dtype))
+            pointer_to_buffer[new_buffer.data] = new_buffer
+            buffer_to_const[new_buffer] = new_value
+
         rewrite_buffer[old_buffer] = new_buffer
         rewrite_pointer[old_buffer.data] = new_buffer.data
-        buffer_to_const[new_buffer] = new_value
 
     def _visit_encode_pre(stmt):
         if isinstance(stmt, tvm.tir.Call):
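To see why the guard matters, here is a minimal standalone sketch of the pattern (plain Python with stand-in objects, not the pass itself; the dictionary names mirror those in EncodeConstants). When a striped schedule re-reads the weights, _new_buffer is called twice for the same old buffer; before the fix each call minted a fresh buffer, so the rewrite maps stopped agreeing on a single replacement and a duplicate constant entry was recorded.

# Minimal standalone sketch of the memoization guard added above.
# Plain Python stand-ins replace real TIR buffers; the dictionary
# names mirror those used in EncodeConstants.

rewrite_buffer = {}   # old buffer -> rewritten buffer
buffer_to_const = {}  # rewritten buffer -> encoded constant value


def new_buffer(old_buffer, new_value):
    """Return the rewritten buffer for old_buffer, creating it only once."""
    if old_buffer in rewrite_buffer:
        # Re-read visit: reuse the buffer created on the first encounter.
        buffer = rewrite_buffer[old_buffer]
    else:
        buffer = object()  # stand-in for tvm.tir.decl_buffer(...)
        buffer_to_const[buffer] = new_value
    rewrite_buffer[old_buffer] = buffer
    return buffer


weights = "weights"  # stand-in for the weight buffer in the TIR
# The two visits caused by the re-read must resolve to one buffer ...
assert new_buffer(weights, b"encoded") is new_buffer(weights, b"encoded")
# ... and to exactly one recorded constant, not one per visit.
assert len(buffer_to_const) == 1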
tests/python/contrib/test_ethosu/test_encode_constants.py (66 changes: 66 additions & 0 deletions)
@@ -108,6 +108,72 @@ def _get_func():
assert reference_const_sizes == test_const_sizes


# fmt: off
@tvm.script.ir_module
class RereadWeights:
    @T.prim_func
    def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle) -> None:
        # function attr dict
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        placeholder_3 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8")
        buffer = T.match_buffer(placeholder_1, [304], dtype="uint8")
        buffer_1 = T.match_buffer(placeholder_2, [80], dtype="uint8")
        ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8")
        # body
        placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin":True})
        placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True})
        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle"))
        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle"))
        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder_3.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write_1.data, 64), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
    __tvm_meta__ = None
# fmt: on
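Note how the expected TIR above captures the re-read pattern: each of the two stripes issues its own pair of ethosu_copy calls, but both pairs copy from the same source buffers (buffer, buffer_1) into the same allocations (placeholder_global, placeholder_d_global), so EncodeConstants must rewrite the weight and bias buffers once and reuse them rather than creating duplicates.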


def test_re_read_weights():
    def _cascader(cached_func, const_dict, sch):
        weights = cached_func.inputs[1]
        bias = cached_func.inputs[2]
        out = cached_func.outputs[0]
        conv_compute = Convolution2DCompute.from_output(out)
        co = conv_compute.split(sch, 2, 8)
        cache_weights = sch.cache_read(weights, "global", [conv_compute.conv2d])
        cache_bias = sch.cache_read(bias, "global", [conv_compute.conv2d])
        sch[cache_weights].compute_at(sch[out], co)
        sch[cache_bias].compute_at(sch[out], co)

    def _get_func():
        ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype="int8")
        conv = make_ethosu_conv2d(
            ifm,
            32,
            8,
            (1, 1),
            (0, 0),
            (1, 1),
            (1, 1),
        )
        func = relay.Function(relay.analysis.free_vars(conv), conv)
        func = run_opt_pass(func, relay.transform.InferType())
        return func

    func = _get_func()
    mod, consts = lower_to_tir(func, cascader=_cascader)
    script = mod.script(show_meta=True)
    test_mod = tvm.script.from_source(script)
    reference_mod = RereadWeights
    tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)

    reference_const_sizes = {1: 304, 2: 80}
    test_const_sizes = {}
    for key, value in consts.items():
        test_const_sizes[key] = len(value)

    assert reference_const_sizes == test_const_sizes
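For reference, the new test should be runnable in isolation (assuming a TVM build with microNPU support and the test's dependencies available) with:

pytest tests/python/contrib/test_ethosu/test_encode_constants.py::test_re_read_weights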


# fmt: off
@tvm.script.ir_module
class DirectReadOnly:
