Skip to content

Commit

Permalink
[TIR][Transform] Clear buffer_map during MakeUnpackedAPI (apache#12891)
Browse files Browse the repository at this point in the history
* [TIR][Transform] Clear buffer_map during MakeUnpackedAPI

This mimics the behavior in `MakePackedAPI`, and is assumed to be the
case for some codegens.

* Remove read of buffer_map  in ethosu.tir_to_cs_translator

This previously relied on `MakeUnpackedAPI` preserving the
`PrimFunc::buffer_map`, even after it had been used for lowering.  It
now reads from the `BufferLoad` and `BufferStore` nodes to determine
buffer shapes.

* Added more documentation for MakePackedAPI/MakeUnpackedAPI
  • Loading branch information
Lunderberg authored Sep 27, 2022
1 parent a07a46e commit bec9f16
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 16 deletions.
30 changes: 30 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,36 @@ def get_outer_loops(stmt, layout):
return None


def collect_buffer_map(stmt):
"""Collect a map of Var -> Buffer
Generate a map from a buffer's backing `tir.Var` to the
`tir.Buffer` object that uses it. If multiple such buffers exist,
return the first occurrence.
Parameters
----------
stmt : tvm.tir.Stmt
The statement to get the BufferLoads from.
Returns
-------
buffer_map : Dict[Var, Buffer]
The map from buffer var to the buffers that use it.
"""
buffer_map = {}

def _visit(node):
if isinstance(node, (tvm.tir.BufferLoad, tvm.tir.BufferStore)):
buf = node.buffer
if buf.data not in buffer_map:
buffer_map[buf.data] = buf

tvm.tir.stmt_functor.post_order_visit(stmt, _visit)

return buffer_map


def get_loads(stmt):
"""Get the BufferLoad statements.
Expand Down
37 changes: 26 additions & 11 deletions python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from tvm.relay.backend.contrib.ethosu import util
from tvm.relay.backend.contrib.ethosu import vela_api
from tvm.relay.backend.contrib.ethosu.tir import spec
from tvm.relay.backend.contrib.ethosu.tir import utils as tir_utils


class BufferType(Enum):
Expand Down Expand Up @@ -254,26 +255,40 @@ def extract_param_base_addresses(mod, buffer_info, scratch_region_map) -> List[u
assert len(mod.functions.items()) == 1
primfunc = mod.functions.items()[0][1]

buffer_map = tir_utils.collect_buffer_map(primfunc.body)

base_addresses = list()
idx = 0

for param in primfunc.params:
# constants are pooled together and handled specially
# this will change after tir.allocate_const.
# For now, we are skipping generating buffer addresses here
if buffer_info[param].btype == BufferType.constant:
continue
buffer = primfunc.buffer_map[param]
dtype = buffer.dtype
element_size_bytes = np.iinfo(dtype).bits // 8
size_bytes = element_size_bytes * np.prod(list(buffer.shape))
base_addresses.append(
util.BaseAddress(
param.name.replace("-", "_"),
idx,
_get_region(buffer_info[param].btype, param, scratch_region_map),
size_bytes,

if param in buffer_map:
buffer = buffer_map[param]
dtype = buffer.dtype
element_size_bytes = np.iinfo(dtype).bits // 8
size_bytes = element_size_bytes * np.prod(list(buffer.shape))
base_addresses.append(
util.BaseAddress(
param.name.replace("-", "_"),
idx,
_get_region(buffer_info[param].btype, param, scratch_region_map),
size_bytes,
)
)
else:
base_addresses.append(
util.BaseAddress(
param.name.replace("-", "_"),
idx,
_get_region(buffer_info[param].btype, param, scratch_region_map),
0,
)
)
)
idx += 1

return base_addresses
Expand Down
30 changes: 30 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,26 @@ def LowerCustomDatatypes():
def MakePackedAPI():
"""Transform the PrimFuncs in the module to a packed func API.
Prior to this pass, the PrimFunc may have Buffer arguments defined
in the `PrimFuncNode::buffer_map`. This pass consumes the
`buffer_map`, using it to generate `TVMArgs` and `TVMRetValue*`
arguments that implement the `PackedFunc` API.
For static shapes, the `BufferNode::shape`, `BufferNode::strides`,
and `BufferNode::elem_offset` member variables are used to
generate runtime checks on the corresponding member variables in
the user-provided `DLTensor*` or `tvm.nd.array` argument. (e.g. A
PrimFunc that accepts a buffer of shape `[16,32]` validates that
the `DLTensor::shape` array is `[16,32]`.)
For dynamic Buffers, in which one or more of these `BufferNode` member
variables use `tir.Var` that are not defined by other PrimFunc
parameters, these are instead used to define the variables based on
the corresponding `DLTensor` members. (e.g. A PrimFunc that accepts a
buffer of shape `[tir.Var("n"), tir.Var("m")]`, when passed a
`DLTensor` of shape `[16,32]`, will define `n = 16` and `n=32`, based
on the argument's shape.
Returns
-------
fpass : tvm.transform.Pass
Expand All @@ -401,6 +421,16 @@ def MakePackedAPI():
def MakeUnpackedAPI():
"""Transform the PrimFuncs in the module to a C API compatible with internal calls.
Prior to this pass, the PrimFunc may have Buffer arguments defined in
the `PrimFuncNode::buffer_map`. This pass consumes the `buffer_map`,
using it to generate `T*` arguments (e.g. `float32*`) that can be
directly called by a C API.
For static shapes, no runtime validation is performed to confirm that
the argument buffer's shape matches the expected shape. For dynamic
shapes, `MakeUnpackedAPI` requires that the dynamic parameters be
passed as separate `tir.Var` parameters.
Returns
-------
fpass : tvm.transform.Pass
Expand Down
7 changes: 2 additions & 5 deletions src/tir/transforms/make_unpacked_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,13 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) {

// Collect variables and buffers to map between
Array<Var> args;
Map<Var, Buffer> new_buffer_map;

for (const Var& param : func->params) {
// Ideally all func params should have Buffers defined in the buffer_map
// We should look to insert buffer_maps for all PrimFuncs that are returned
// to the core compiler.
if (func->buffer_map.find(param) != func->buffer_map.end()) {
args.push_back(func->buffer_map[param]->data);
// Rewiring the buffer_var to map to Buffers for low-level passes
// retain information about the buffer.
new_buffer_map.Set(func->buffer_map[param]->data, func->buffer_map[param]);
} else {
args.push_back(param);
}
Expand All @@ -82,7 +79,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc&& func) {
func_ptr->body = MergeNest(device_init, func_ptr->body);
func_ptr->params = args;
func_ptr->ret_type = PrimType(DataType::Int(32));
func_ptr->buffer_map = new_buffer_map;
func_ptr->buffer_map = Map<Var, Buffer>();

// return the function.
return std::move(func);
Expand Down

0 comments on commit bec9f16

Please sign in to comment.