diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
index 41a6832c5953..0a6dcd146991 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
@@ -331,11 +331,15 @@ def _encode_weights(tir_extern_call, weights):
 
     def _new_buffer(old_buffer, new_value):
         """Create a new buffer and add the old buffer and its pointer to the
         rewriting maps."""
-        new_buffer = tvm.tir.decl_buffer((len(new_value),), str(new_value.dtype))
-        pointer_to_buffer[new_buffer.data] = new_buffer
+        if old_buffer in rewrite_buffer:
+            new_buffer = rewrite_buffer[old_buffer]
+        else:
+            new_buffer = tvm.tir.decl_buffer((len(new_value),), str(new_value.dtype))
+            pointer_to_buffer[new_buffer.data] = new_buffer
+            buffer_to_const[new_buffer] = new_value
+        rewrite_buffer[old_buffer] = new_buffer
         rewrite_pointer[old_buffer.data] = new_buffer.data
-        buffer_to_const[new_buffer] = new_value
 
     def _visit_encode_pre(stmt):
         if isinstance(stmt, tvm.tir.Call):
diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py
index de8a7f922390..7f5eeb1121af 100644
--- a/tests/python/contrib/test_ethosu/test_encode_constants.py
+++ b/tests/python/contrib/test_ethosu/test_encode_constants.py
@@ -108,6 +108,72 @@ def _get_func():
     assert reference_const_sizes == test_const_sizes
 
 
+# fmt: off
+@tvm.script.ir_module
+class RereadWeights:
+    @T.prim_func
+    def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
+        placeholder_3 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8")
+        buffer = T.match_buffer(placeholder_1, [304], dtype="uint8")
+        buffer_1 = T.match_buffer(placeholder_2, [80], dtype="uint8")
+        ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8")
+        # body
+        placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin":True})
+        placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True})
+        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
+        T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder_3.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write_1.data, 64), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+    __tvm_meta__ = None
+# fmt: on
+
+
+def test_re_read_weights():
+    def _cascader(cached_func, const_dict, sch):
+        weights = cached_func.inputs[1]
+        bias = cached_func.inputs[2]
+        out = cached_func.outputs[0]
+        conv_compute = Convolution2DCompute.from_output(out)
+        co = conv_compute.split(sch, 2, 8)
+        cache_weights = sch.cache_read(weights, "global", [conv_compute.conv2d])
+        cache_bias = sch.cache_read(bias, "global", [conv_compute.conv2d])
+        sch[cache_weights].compute_at(sch[out], co)
+        sch[cache_bias].compute_at(sch[out], co)
+
+    def _get_func():
+        ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype="int8")
+        conv = make_ethosu_conv2d(
+            ifm,
+            32,
+            8,
+            (1, 1),
+            (0, 0),
+            (1, 1),
+            (1, 1),
+        )
+        func = relay.Function(relay.analysis.free_vars(conv), conv)
+        func = run_opt_pass(func, relay.transform.InferType())
+        return func
+
+    func = _get_func()
+    mod, consts = lower_to_tir(func, cascader=_cascader)
+    script = mod.script(show_meta=True)
+    test_mod = tvm.script.from_source(script)
+    reference_mod = RereadWeights
+    tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True)
+
+    reference_const_sizes = {1: 304, 2: 80}
+    test_const_sizes = {}
+    for key, value in consts.items():
+        test_const_sizes[key] = len(value)
+
+    assert reference_const_sizes == test_const_sizes
+
+
 # fmt: off
 @tvm.script.ir_module
 class DirectReadOnly: