Skip to content

Commit

Permalink
[microNPU] Integrate rolling buffers in Arm(R) Ethos(TM)-U
Browse files Browse the repository at this point in the history
Change-Id: Iede5e68981a063f6eb1e118433cc2c92e175af52
  • Loading branch information
Nicola Lancellotti committed Feb 23, 2022
1 parent faa2e6a commit 3050d19
Show file tree
Hide file tree
Showing 21 changed files with 427 additions and 173 deletions.
12 changes: 7 additions & 5 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@
from tvm.relay.backend.contrib.ethosu.op import op_attrs
from tvm.relay.backend.contrib.ethosu import op

# We are currently using the copy_constants scheduler. In the long run,
# this should be a single, intelligent composite scheduler
# that can perform scheduling based on user inputs such as
# scratch memory size.
CASCADER = copy_constants()


class OptimizeLUTs(ExprMutator):
"""A pass to merge an identity operator with a LUT based activation function with
Expand Down Expand Up @@ -332,11 +338,7 @@ def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc:
mod = LUTsOptimizer()(mod)
mod = LayoutOptimizer()(mod)
mod = relay.transform.InferType()(mod)
# We are currently using the copy_constants scheduler. In the long run,
# this should be a single, intelligent composite scheduler
# that can perform scheduling based on user inputs such as
# scratch memory size.
tir_mod, const_dict = lower_to_tir(mod["main"], copy_constants())
tir_mod, const_dict = lower_to_tir(mod["main"], CASCADER)

for param in const_dict.keys():
const_dict[param] = tvm.nd.array(const_dict[param])
Expand Down
24 changes: 11 additions & 13 deletions python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Extract information from the binary_elementwise operators in TIR."""
from typing import Dict, Tuple
from typing import Tuple
import tvm
from .utils import get_outer_loops, get_op_attrs
from .dma import get_ifm_params, get_ofm_params
from .spec import SerialActivation, SerialBinaryElementwise
from .producers_consumers import ProducersConsumers


def ignore_cast(tir_load: tvm.tir.expr.Load) -> tvm.tir.Var:
Expand All @@ -42,22 +43,17 @@ def ignore_cast(tir_load: tvm.tir.expr.Load) -> tvm.tir.Var:


def get_binary_elementwise_params(
stmt: tvm.tir.AttrStmt,
producers: Dict[tvm.tir.Var, tvm.tir.AttrStmt],
consumers: Dict[tvm.tir.Var, tvm.tir.AttrStmt],
stmt: tvm.tir.AttrStmt, producers_consumers: ProducersConsumers
) -> Tuple[SerialBinaryElementwise, tvm.tir.Var, tvm.tir.Var]:
"""Get the parameters necessary to construct a call_extern for a binary_elementwise.
Parameters
----------
stmt : tvm.tir.AttrStmt
The outermost attribute statement of a binary elementwise loop nest.
producers : Dict[tvm.tir.Var, tvm.tir.AttrStmt]
A dictionary to associate pointers with the loop nest
that produces their values.
consumers : Dict[tvm.tir.Var, tvm.tir.AttrStmt]
A dictionary to associate pointers with the loop nest
that consumes their values.
producers_consumers: ProducersConsumers
It associates pointers with the loop nest that produces
their values and with the loop nest that consumes their values.
Returns
-------
Expand All @@ -84,9 +80,11 @@ def get_binary_elementwise_params(
input_pointer, input_pointer1 = input_pointer1, input_pointer
output_pointer = inner.buffer_var
# Get feature map info
serial_ifm, _ = get_ifm_params(input_pointer, producers)
serial_ifm2, _ = get_ifm_params(input_pointer1, producers)
serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
serial_ifm, _ = get_ifm_params(input_pointer, producers_consumers, stmt)
serial_ifm2, _ = get_ifm_params(input_pointer1, producers_consumers, stmt)
serial_ofm, replace_pointer, is_allocator = get_ofm_params(
output_pointer, producers_consumers, stmt
)
# Get activation info
serial_activation = SerialActivation(
op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"]
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):

mod = tvm.tir.transform.Simplify()(mod)
mod = ethosu_passes.RemoveConcatenates()(mod)
mod = tvm.tir.transform.InjectRollingBuffer()(mod)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.UnrollLoop()(mod)
mod = tvm.tir.transform.Simplify()(mod)
Expand Down
17 changes: 8 additions & 9 deletions python/tvm/relay/backend/contrib/ethosu/tir/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,16 @@
from .spec import SerialKernel, SerialAddressRange, SerialActivation, Serial2DConvolution


def get_conv2d_params(stmt, producers, consumers):
def get_conv2d_params(stmt, producers_consumers):
"""Get the parameters necessary to construct a call_extern for a 2D convolution.
Parameters
----------
stmt : tvm.tir.AttrStmt
The outermost attribute statement of a convolution loop nest.
producers : dict of tvm.tir.Var to tvm.tir.AttrStmt
A dictionary to associate pointers with the loop nest
that produces their values.
consumers : dict of tvm.tir.Var to tvm.tir.AttrStmt
A dictionary to associate pointers with the loop nest
that consumes their values.
producers_consumers: ProducersConsumers
It associates pointers with the loop nest that produces
their values and with the loop nest that consumes their values.
Returns
-------
Expand All @@ -62,8 +59,10 @@ def get_conv2d_params(stmt, producers, consumers):
input_pointer = loads[1].buffer_var
output_pointer = stores[0].buffer_var
# Get feature map info
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers)
serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers_consumers, stmt)
serial_ofm, replace_pointer, is_allocator = get_ofm_params(
output_pointer, producers_consumers, stmt
)
# Get kernel info
serial_kernel = SerialKernel(
width=int(rw.extent),
Expand Down
22 changes: 10 additions & 12 deletions python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Extract information from the depthwise convolution operators in TIR."""
from typing import Dict, Tuple
from typing import Tuple
import tvm
from ..vela_api import SCALE_BIAS_LENGTH
from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores
Expand All @@ -27,25 +27,21 @@
SerialActivation,
Serial2DDepthwise,
)
from .producers_consumers import ProducersConsumers


def get_depthwise_conv2d_params(
stmt: tvm.tir.AttrStmt,
producers: Dict[tvm.tir.Var, tvm.tir.AttrStmt],
consumers: Dict[tvm.tir.Var, tvm.tir.AttrStmt],
stmt: tvm.tir.AttrStmt, producers_consumers: ProducersConsumers
) -> Tuple[Serial2DDepthwise, tvm.tir.Var, tvm.tir.Var]:
"""Get the parameters necessary to construct a call_extern for a depthwise_conv2d.
Parameters
----------
stmt : tvm.tir.AttrStmt
The outermost attribute statement of a depthwise loop nest.
producers : Dict[tvm.tir.Var, tvm.tir.AttrStmt]
A dictionary to associate pointers with the loop nest
that produces their values.
consumers : Dict[tvm.tir.Var, tvm.tir.AttrStmt]
A dictionary to associate pointers with the loop nest
that consumes their values.
producers_consumers: ProducersConsumers
It associates pointers with the loop nest that produces
their values and with the loop nest that consumes their values.
Returns
-------
Expand All @@ -71,8 +67,10 @@ def get_depthwise_conv2d_params(
input_pointer = loads[1].buffer_var
output_pointer = stores[0].buffer_var
# Get feature map info
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers)
serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers_consumers, stmt)
serial_ofm, replace_pointer, is_allocator = get_ofm_params(
output_pointer, producers_consumers, stmt
)
# Get kernel info
serial_kernel = SerialKernel(
width=int(rw.extent),
Expand Down
125 changes: 99 additions & 26 deletions python/tvm/relay/backend/contrib/ethosu/tir/dma.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,18 +192,53 @@ def get_read_params(stmt):
strides = get_strides(inner.value.index, stride_vars)
base_address = get_base_address(inner.value.index)
data_type = inner.buffer_var.type_annotation.element_type.dtype

floor_mod_stmt = None

def _get_buffer_var(stmt):
if isinstance(stmt, tvm.tir.FloorMod):
nonlocal floor_mod_stmt
floor_mod_stmt = stmt

tvm.tir.stmt_functor.post_order_visit(inner.value, _get_buffer_var)
if floor_mod_stmt is None:
tile_height_0, tile_height_1 = h.extent, 0
tile_width_0 = w.extent
tile_address_0 = tvm.tir.Load(data_type, inner.value.buffer_var, base_address)
tile_address_1 = 0
tile_address_2 = 0
else:
var = floor_mod_stmt.a.a
tile_length = floor_mod_stmt.b - floor_mod_stmt.a.b
if var == h.loop_var:
tile_height_0 = tile_length
tile_height_1 = 0
tile_width_0 = w.extent
tile_address_0 = tvm.tir.Load(data_type, inner.value.buffer_var, base_address)
tile_address_1 = 0
tile_address_2 = tvm.tir.Load(data_type, inner.value.buffer_var, 0)
elif var == w.loop_var:
tile_height_0 = h.extent
tile_height_1 = h.extent
tile_width_0 = tile_length
tile_address_0 = tvm.tir.Load(data_type, inner.value.buffer_var, base_address)
tile_address_1 = tvm.tir.Load(data_type, inner.value.buffer_var, 0)
tile_address_2 = 0
else:
assert False

return (
SerialFeatureMap(
data_type=data_type,
height=h.extent,
width=w.extent,
channels=c.extent,
tile_height_0=h.extent,
tile_height_1=0,
tile_width_0=w.extent,
tile_address_0=tvm.tir.Load(data_type, inner.value.buffer_var, base_address),
tile_address_1=0,
tile_address_2=0,
tile_height_0=tile_height_0,
tile_height_1=tile_height_1,
tile_width_0=tile_width_0,
tile_address_0=tile_address_0,
tile_address_1=tile_address_1,
tile_address_2=tile_address_2,
tile_address_3=0,
scale=attrs["scale"],
zero_point=attrs["zero_point"],
Expand Down Expand Up @@ -268,16 +303,16 @@ def get_write_params(stmt):
)


def get_ifm_params(pointer, producers):
def get_ifm_params(pointer, producers_consumers, stmt):
"""Get the parameters associated with the DMA capabilities for an IFM.
Parameters
----------
pointer : tvm.tir.Var
The pointer that the IFM DMA pipeline produces.
producers : dict of tvm.tir.Var to tvm.tir.AttrStmt
A dictionary to associate pointers with the loop nest
that produces their values.
producers_consumers: ProducersConsumers
It associates pointers with the loop nest that produces
their values and with the loop nest that consumes their values.
Returns
-------
Expand All @@ -287,31 +322,69 @@ def get_ifm_params(pointer, producers):
The serializable padding.
"""
pad = producers[pointer]
pad = producers_consumers.get_producer(pointer, stmt)
serial_padding, input_pointer, _ = get_pad_params(pad)
upscale = producers[input_pointer]
upscale = producers_consumers.get_producer(input_pointer, pad)
input_pointer, _ = get_upscale_params(upscale)
convert_to_nhwc = producers[input_pointer]
convert_to_nhwc = producers_consumers.get_producer(input_pointer, upscale)
in_channels, input_pointer, _ = get_convert_to_nhwc_params(convert_to_nhwc)
read = producers[input_pointer]
read = producers_consumers.get_producer(input_pointer, convert_to_nhwc)
serial_ifm, _, _ = get_read_params(read)
serial_ifm.channels = in_channels

floor_mod_stmt = None
for_stmt = None

def _get_buffer_var(stmt):
nonlocal for_stmt
nonlocal floor_mod_stmt
if isinstance(stmt, tvm.tir.For):
for_stmt = stmt
if isinstance(stmt, tvm.tir.FloorMod):
floor_mod_stmt = stmt

tvm.tir.stmt_functor.post_order_visit(stmt, _get_buffer_var)

if floor_mod_stmt is not None:
layout = get_op_attrs(read)[0]["layout"]
channels = serial_ifm.channels
if for_stmt.body.loop_var == floor_mod_stmt.a.a.a:
height_a = floor_mod_stmt.b - floor_mod_stmt.a.b
height_b = serial_ifm.height
serial_ifm.height = height_a + height_b
serial_ifm.tile_height_0 = serial_ifm.height
address = serial_ifm.tile_address_0
offset = (
height_a * (channels // 16 + 1) * serial_ifm.width * 16
if layout == "NHCWB16"
else height_a * serial_ifm.width * channels
)
serial_ifm.tile_address_0 = tvm.tir.Load(
address.dtype, address.buffer_var, address.index - offset
)
else:
width_a = floor_mod_stmt.b - floor_mod_stmt.a.b
width_b = serial_ifm.width
serial_ifm.width = width_a + width_b
serial_ifm.tile_width_0 = serial_ifm.width
address = serial_ifm.tile_address_0
offset = width_a * 16 if layout == "NHCWB16" else width_a * channels
serial_ifm.tile_address_0 = tvm.tir.Load(
address.dtype, address.buffer_var, address.index - offset
)
return serial_ifm, serial_padding


def get_ofm_params(pointer, consumers, producers):
def get_ofm_params(pointer, producers_consumers, stmt):
"""Get the parameters associated with the DMA capabilities for an OFM.
Parameters
----------
pointer : tvm.tir.Var
The pointer that the OFM DMA pipeline consumes.
consumers : dict of tvm.tir.Var to tvm.tir.AttrStmt
A dictionary to associate pointers with the loop nest
that consumes their values.
producers : dict of tvm.tir.Var to tvm.tir.AttrStmt
A dictionary to associate pointers with the loop nest
that produces their values.
producers_consumers: ProducersConsumers
It associates pointers with the loop nest that produces
their values and with the loop nest that consumes their values.
Returns
-------
Expand All @@ -323,14 +396,14 @@ def get_ofm_params(pointer, consumers, producers):
Whether this operator allocates its output.
"""
convert_to_nhcwb16 = consumers[pointer]
convert_to_nhcwb16 = producers_consumers.get_consumer(pointer, stmt)
out_channels, _, output_pointer = get_convert_to_nhcwb16_params(convert_to_nhcwb16)
write = consumers[output_pointer]
write = producers_consumers.get_consumer(output_pointer, convert_to_nhcwb16)
serial_ofm, _, output_pointer = get_write_params(write)
is_allocator = True
if output_pointer not in producers:
is_allocator = False
elif producers[output_pointer] != write:

producer = producers_consumers.get_producer(output_pointer, write)
if producer is None or producer != write:
is_allocator = False
serial_ofm.channels = out_channels
return serial_ofm, output_pointer, is_allocator
Loading

0 comments on commit 3050d19

Please sign in to comment.