From ab7d7de26cff38e3e0f17647d9ba31fafd2077b4 Mon Sep 17 00:00:00 2001 From: Alexey Yazev <113356454+Alexey-Yazev@users.noreply.github.com> Date: Fri, 18 Nov 2022 08:55:27 -0800 Subject: [PATCH] [microNPU] Fix Cascader code generation without StorageRewrite (#13365) There were extra memory allocations for buffers when parts of the result buffer were replaced with a buffer for the entire result (in the ReplaceOperators pass). Summed up, the allocated size was larger by a factor of the number of parts. --- python/tvm/relay/backend/contrib/ethosu/tir/compiler.py | 6 +----- python/tvm/relay/backend/contrib/ethosu/tir/passes.py | 5 +++++ .../contrib/test_ethosu/cascader/test_memory_reduction.py | 8 ++++---- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index 4133aff6ef51..2cf45170e4e3 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -94,12 +94,8 @@ def lower_ethosu(sch, args, const_dict, name="main"): mod, const_dict = ethosu_passes.MergeConstants(const_dict)(mod) mod = ethosu_passes.CopyComputeReordering()(mod) - # When striping is enabled and if storage_rewrite is not run - # the striping results in incorrect code generation. This needs - # further investigation. Until such a time that is fixed, disable_storage_rewrite - # user directive will be overridden if striping is enabled. 
disable_storage_rewrite = curr_cfg.get("tir.disable_storage_rewrite", False) - if not disable_storage_rewrite or util.is_striping_enabled(): + if not disable_storage_rewrite: mod = tvm.tir.transform.StorageRewrite()(mod) mod = tvm.tir.transform.RemoveNoOp()(mod) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index e15d126dd969..f313ff720500 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -72,6 +72,7 @@ def ReplaceOperators(): producers_consumers = ProducersConsumers() replace_output_pointer = {} pointer_to_extents = {} + replaced_pointers = [] ReplaceInfo = namedtuple("ReplaceInfo", ["pointer", "reallocate"]) @@ -136,9 +137,13 @@ def _replace_operator(stmt): stmt, producers_consumers ) if replace_pointer is not None: + # Allocate pointer only once + if replace_pointer in replaced_pointers: + is_allocator = False replace_output_pointer[output_pointer] = ReplaceInfo( replace_pointer, is_allocator ) + replaced_pointers.append(replace_pointer) # Make the extern call irb = tvm.tir.ir_builder.create() irb.emit(tvm.tir.call_extern("handle", op_name, *info)) diff --git a/tests/python/contrib/test_ethosu/cascader/test_memory_reduction.py b/tests/python/contrib/test_ethosu/cascader/test_memory_reduction.py index e88282240510..99238fa59337 100644 --- a/tests/python/contrib/test_ethosu/cascader/test_memory_reduction.py +++ b/tests/python/contrib/test_ethosu/cascader/test_memory_reduction.py @@ -171,10 +171,10 @@ def tf_graph(x): @pytest.mark.parametrize( "accel_type, expected_ws_size_without_striping, expected_ws_size_with_striping", [ - ("ethos-u55-256", 180288, 15312), - ("ethos-u55-128", 180288, 15312), - ("ethos-u55-64", 180288, 14544), - ("ethos-u55-32", 180272, 14544), + ("ethos-u55-256", 180288, 15200), + ("ethos-u55-128", 180288, 15200), + ("ethos-u55-64", 180288, 14432), + ("ethos-u55-32", 180272, 14416), ], ) 
def test_depthwise2d_conv2d_pooling(