Fix tests and address comments
Change-Id: I57c9a4ca77b82b6b79d648376be374b7f155a297
Nicola Lancellotti committed Feb 25, 2022
Commit 066026b (1 parent: 76f32e3)
Showing 7 changed files with 39 additions and 38 deletions.
python/tvm/relay/backend/contrib/ethosu/codegen.py (4 changes: 2 additions & 2 deletions)
@@ -35,7 +35,7 @@
# this should be a single intelligent and a composite scheduler
# that can perform scheduling based on user inputs such as
# scratch memory size.
-CASCADER = copy_constants()
+SCHEDULER = copy_constants()


class OptimizeLUTs(ExprMutator):
@@ -338,7 +338,7 @@ def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc:
mod = LUTsOptimizer()(mod)
mod = LayoutOptimizer()(mod)
mod = relay.transform.InferType()(mod)
-tir_mod, const_dict = lower_to_tir(mod["main"], CASCADER)
+tir_mod, const_dict = lower_to_tir(mod["main"], SCHEDULER)

for param in const_dict.keys():
const_dict[param] = tvm.nd.array(const_dict[param])
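Not part of the diff: a minimal sketch of how a test can swap its own scheduling function into this module-level hook, following the callback signature and the `codegen.SCHEDULER = ...` assignment that appear in the test diffs below; the function body here is purely illustrative.

from tvm.relay.backend.contrib.ethosu import codegen

def _my_scheduler(cached_func, const_dict, sch):
    # Illustrative placeholder: a real scheduler splits the output and stages
    # constants with cache_read/compute_at, as the test planners below do.
    pass

# Override the default copy_constants() scheduler before lowering runs.
codegen.SCHEDULER = _my_scheduler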
python/tvm/relay/backend/contrib/ethosu/tir/dma.py (45 changes: 22 additions & 23 deletions)
@@ -193,39 +193,38 @@ def get_read_params(stmt):
    base_address = get_base_address(inner.value.index)
    data_type = inner.buffer_var.type_annotation.element_type.dtype

-    floor_mod_stmt = None
+    floor_mod_var = None
+    tile_length = None

    def _get_buffer_var(stmt):
        if isinstance(stmt, tvm.tir.FloorMod):
-            nonlocal floor_mod_stmt
-            floor_mod_stmt = stmt
+            nonlocal floor_mod_var
+            floor_mod_var = stmt.a.a
+            nonlocal tile_length
+            tile_length = stmt.b - stmt.a.b

    tvm.tir.stmt_functor.post_order_visit(inner.value, _get_buffer_var)
-    if floor_mod_stmt is None:
-        tile_height_0, tile_height_1 = h.extent, 0
+
+    if floor_mod_var == h.loop_var:
+        tile_height_0 = tile_length
+        tile_height_1 = 0
        tile_width_0 = w.extent
        tile_address_0 = tvm.tir.Load(data_type, inner.value.buffer_var, base_address)
        tile_address_1 = 0
+        tile_address_2 = tvm.tir.Load(data_type, inner.value.buffer_var, 0)
+    elif floor_mod_var == w.loop_var:
+        tile_height_0 = h.extent
+        tile_height_1 = h.extent
+        tile_width_0 = tile_length
+        tile_address_0 = tvm.tir.Load(data_type, inner.value.buffer_var, base_address)
+        tile_address_1 = tvm.tir.Load(data_type, inner.value.buffer_var, 0)
        tile_address_2 = 0
    else:
-        var = floor_mod_stmt.a.a
-        tile_length = floor_mod_stmt.b - floor_mod_stmt.a.b
-        if var == h.loop_var:
-            tile_height_0 = tile_length
-            tile_height_1 = 0
-            tile_width_0 = w.extent
-            tile_address_0 = tvm.tir.Load(data_type, inner.value.buffer_var, base_address)
-            tile_address_1 = 0
-            tile_address_2 = tvm.tir.Load(data_type, inner.value.buffer_var, 0)
-        elif var == w.loop_var:
-            tile_height_0 = h.extent
-            tile_height_1 = h.extent
-            tile_width_0 = tile_length
-            tile_address_0 = tvm.tir.Load(data_type, inner.value.buffer_var, base_address)
-            tile_address_1 = tvm.tir.Load(data_type, inner.value.buffer_var, 0)
-            tile_address_2 = 0
-        else:
-            assert False
+        tile_height_0, tile_height_1 = h.extent, 0
+        tile_width_0 = w.extent
+        tile_address_0 = tvm.tir.Load(data_type, inner.value.buffer_var, base_address)
+        tile_address_1 = 0
+        tile_address_2 = 0

    return (
        SerialFeatureMap(
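For orientation (not from the diff itself): the rewritten visitor reads the loop variable and the tile length directly out of a FloorMod(loop_var + offset, length) term in the load index, instead of stashing the whole FloorMod node. Below is a minimal sketch of that extraction on a hand-built expression, using only the public tvm.tir API; the variable names and constants are made up for illustration.

import tvm

h = tvm.tir.Var("h", "int32")
expr = tvm.tir.FloorMod(h + 4, 16)  # a rolling-buffer style index term: (h + 4) % 16

found = {}

def _visit(stmt):
    if isinstance(stmt, tvm.tir.FloorMod):
        found["var"] = stmt.a.a                    # the loop variable inside the modulo
        found["tile_length"] = stmt.b - stmt.a.b   # modulus minus the offset, 16 - 4 here

tvm.tir.stmt_functor.post_order_visit(expr, _visit)
# found["var"] is h; found["tile_length"] is the tile length fed into the NPU tile parameters above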
@@ -417,6 +417,8 @@ def replace_npu_fm_with_address(npu_fm):
np.iinfo(np.dtype(npu_fm.tiles.addresses[0])).bits // 8
)
npu_fm.tiles.addresses[0] = address + int(index)
+npu_fm.tiles.addresses[1] = address
+npu_fm.tiles.addresses[2] = address
npu_fm.region = region
return npu_fm

tests/python/contrib/test_ethosu/test_encode_constants.py (16 changes: 8 additions & 8 deletions)
@@ -71,8 +71,8 @@ def _planner(cached_func, const_dict, sch):
out = cached_func.outputs[0]
conv_compute = OperatorCompute.from_output(out)
co = conv_compute.split(sch, 3, 2)
-cache_weights = sch.cache_read(weights, "global", [conv_compute.conv2d])
-cache_bias = sch.cache_read(bias, "global", [conv_compute.conv2d])
+cache_weights = sch.cache_read(weights, "global", [conv_compute.op])
+cache_bias = sch.cache_read(bias, "global", [conv_compute.op])
sch[cache_weights].compute_at(sch[out], co)
sch[cache_bias].compute_at(sch[out], co)

@@ -117,10 +117,10 @@ def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[
placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True})
T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
-T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle"))
T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle"))
-T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write.data, 64), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
+T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write.data, 64), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle"))
__tvm_meta__ = None
# fmt: on

@@ -132,8 +132,8 @@ def _cascader(cached_func, const_dict, sch):
out = cached_func.outputs[0]
conv_compute = OperatorCompute.from_output(out)
co = conv_compute.split(sch, 2, 8)
-cache_weights = sch.cache_read(weights, "global", [conv_compute.conv2d])
-cache_bias = sch.cache_read(bias, "global", [conv_compute.conv2d])
+cache_weights = sch.cache_read(weights, "global", [conv_compute.op])
+cache_bias = sch.cache_read(bias, "global", [conv_compute.op])
sch[cache_weights].compute_at(sch[out], co)
sch[cache_bias].compute_at(sch[out], co)

@@ -266,8 +266,8 @@ def _planner(cached_func, const_dict, sch):
out = cached_func.outputs[0]
conv_compute = OperatorCompute.from_output(out)
co = conv_compute.split(sch, 3, 2)
-cache_weight = sch.cache_read(weight, "global", [conv_compute.conv2d])
-cache_scale_bias = sch.cache_read(scale_bias, "global", [conv_compute.conv2d])
+cache_weight = sch.cache_read(weight, "global", [conv_compute.op])
+cache_scale_bias = sch.cache_read(scale_bias, "global", [conv_compute.op])
sch[cache_weight].compute_at(sch[out], co)
sch[cache_scale_bias].compute_at(sch[out], co)

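The functional change in these planners and cascaders is the accessor rename on OperatorCompute; an illustrative before/after of one staging call, reusing the names from the hunk above:

# before: the compute op was reached through an operator-specific attribute
cache_weight = sch.cache_read(weight, "global", [conv_compute.conv2d])
# after: the generic accessor is used instead
cache_weight = sch.cache_read(weight, "global", [conv_compute.op])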
tests/python/contrib/test_ethosu/test_lower_to_te.py (2 changes: 1 addition & 1 deletion)
@@ -52,7 +52,7 @@ def test_ethosu_conv2d():
assert len(lowered.outputs) == 1
assert len(lowered.inputs) == 4
conv2d_compute = OperatorCompute.from_output(lowered.outputs[0])
-assert conv2d_compute.conv2d.name == "ethosu_conv2d"
+assert conv2d_compute.op.name == "ethosu_conv2d"
input_shapes = set()
for inp in lowered.inputs:
input_shapes.add(tuple([x.value for x in inp.shape]))
tests/python/contrib/test_ethosu/test_replace_copy.py (4 changes: 2 additions & 2 deletions)
@@ -102,8 +102,8 @@ def _cascader(cached_func, const_dict, sch):
out = cached_func.outputs[0]
conv_compute = OperatorCompute.from_output(out)
co = conv_compute.split(sch, 3, 10)
-cache_weight = sch.cache_read(weight, "global", [conv_compute.conv2d])
-cache_scale_bias = sch.cache_read(scale_bias, "global", [conv_compute.conv2d])
+cache_weight = sch.cache_read(weight, "global", [conv_compute.op])
+cache_scale_bias = sch.cache_read(scale_bias, "global", [conv_compute.op])
sch[cache_weight].compute_at(sch[out], co)
sch[cache_scale_bias].compute_at(sch[out], co)

tests/python/contrib/test_ethosu/test_rolling_buffer.py (4 changes: 2 additions & 2 deletions)
@@ -55,7 +55,7 @@ def _cascader(cached_func, const_dict, sch):
pool_a_compute.compute_at(sch, stage=sch[pool_b_out], axis=outer)
pool_a_compute.rolling_buffer(sch)

-codegen.CASCADER = _cascader
+codegen.SCHEDULER = _cascader
_compare_tvm_with_tflite(tf_model, [ifm_shape], accel_type)


@@ -95,7 +95,7 @@ def _cascader(cached_func, const_dict, sch):
pool_a_compute.compute_at(sch, stage=sch[pool_b_out], axis=outer)
pool_a_compute.rolling_buffer(sch)

-codegen.CASCADER = _cascader
+codegen.SCHEDULER = _cascader
_compare_tvm_with_tflite(tf_model, [ifm_shape], accel_type)


