Skip to content

Commit

Permalink
Address review comments on Arm(R) Ethos(TM)-U PR 3/6 (#9159)
Browse files Browse the repository at this point in the history
* Address review comments on Arm(R) Ethos(TM)-U PR 3/6

Change-Id: I22961885a503be31f6a72622ae0b5f874cc6f463

* Fix rebasing error

Change-Id: I3e2fde786096ea331fcb366080fa779ec4ea4a5d

* Fix more rebasing problems

Change-Id: I1026e3ccee33a3fdec9ebbf6456bae244ad4f1d5
  • Loading branch information
mbaret authored Oct 12, 2021
1 parent 9f27be6 commit 4f6b478
Show file tree
Hide file tree
Showing 15 changed files with 197 additions and 404 deletions.
20 changes: 10 additions & 10 deletions python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""The integration of Arm(R) Ethos(TM)-U NPU TIR compiler"""
"""The integration of the Arm(R) Ethos(TM)-U NPU TIR compiler."""
import tvm
from tvm import relay
from tvm.relay.expr_functor import ExprMutator
Expand All @@ -29,7 +29,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):
"""Lower a schedule to TIR for the Arm(R) Ethos(TM)-U NPU target.
The resulting TIR module will contain a single function
that comprises of a sequence of tir.extern_calls to NPU
that consists of a sequence of tir.extern_calls to NPU
operations.
Parameters
Expand Down Expand Up @@ -96,20 +96,20 @@ def lower_ethosu(sch, args, const_dict, name="main"):


def lower_to_te(prim_func):
"""Lower a Relay primitive function to a Tensor Expression graph.
"""Lower a Relay primitive function to a Tensor Expression in an unscheduled CachedFunc.
Parameters
----------
prim_func : tvm.relay.Function
The Relay function to lowerethosu_runtime([]).
The Relay function to lower.
Returns
-------
out : TEGraph
The lowered Tensor Expression graph.
out : CachedFunc
The lowered Tensor Expression as part of a CachedFunc.
"""
f = tvm._ffi.get_global_func("relay.backend.contrib.ethosu.LowerToTE")
f = tvm._ffi.get_global_func("relay.backend.LowerToTE")
return f(prim_func)


Expand Down Expand Up @@ -193,7 +193,7 @@ def lower_to_tir(func, cascader=None):
func, consts = extract_constants(func)
mod = tvm.IRModule.from_expr(func)
func = relay.transform.InferType()(mod)["main"]
te_graph = lower_to_te(func)
s = schedule(te_graph, consts, cascader)
mod, consts = lower_ethosu(s, te_graph, consts)
cached_func = lower_to_te(func)
s = schedule(cached_func, consts, cascader)
mod, consts = lower_ethosu(s, cached_func, consts)
return mod, consts
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Extract information from the convolution operators in TIR."""
"""Extract parameters from the convolution operators in TIR."""
import tvm
from ..vela_api import SCALE_BIAS_LENGTH
from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/dma.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Extract information from the DMA operators in TIR."""
"""Extract parameters from the DMA operators in TIR."""
import tvm
from .utils import get_outer_loops, get_base_address, get_strides, get_op_attrs
from .spec import SerialFeatureMap, SerialPadding
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/relay/backend/contrib/ethosu/tir/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler"""
"""The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler."""
import numpy as np # type: ignore

import tvm
Expand Down Expand Up @@ -301,7 +301,7 @@ def EncodeConstants(const_dict):
pointer_to_buffer = {}
rewrite_buffer = {}
rewrite_pointer = {}
accel_type = vela_api.get_target_accel_type() # type: ignore
accel_config = vela_api.get_accelerator_config()

def _align_scale_bias(tir_extern_call, bias):
"""Align the scale_bias to 16 bytes."""
Expand All @@ -316,7 +316,7 @@ def _align_scale_bias(tir_extern_call, bias):

def _encode_weights(tir_extern_call, weights):
"""Encode the weights for a TIR extern call."""
value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_type)
value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_config)
value = np.frombuffer(value_bytes, dtype="uint8")
return value

Expand Down
36 changes: 18 additions & 18 deletions python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Different schedulers for Arm(R) Ethos(TM)-U NPU"""
"""Scheduling for Arm(R) Ethos(TM)-U NPU."""
import tvm


def schedule(te_graph, const_dict, cascader=None):
"""Schedule a TE graph for NPU compilation.
def schedule(cached_func, const_dict, cascader=None):
"""Schedule a CachedFunc for NPU compilation.
Parameters
----------
te_graph
The TE graph to schedule.
cached_func : CachedFunc
The CachedFunc to schedule.
const_dict : dict of int to numpy.ndarray
The constant dictionary.
cascader : callable, optional
Expand All @@ -38,10 +38,10 @@ def schedule(te_graph, const_dict, cascader=None):
The completed schedule for the graph.
"""
s = tvm.te.create_schedule([t.op for t in te_graph.outputs])
s = tvm.te.create_schedule([t.op for t in cached_func.outputs])
if cascader:
cascader(te_graph, const_dict, s)
inline_no_ops(te_graph, s)
cascader(cached_func, const_dict, s)
inline_no_ops(cached_func, s)
schedule_pragmas(s)
schedule_cache_reads(s)
return s
Expand Down Expand Up @@ -96,7 +96,7 @@ def total_cascader(stripe_size):
"""

def _cascader(te_graph, const_dict, sch):
def _cascader(cached_func, const_dict, sch):
scheduled = set()

def _visit(tensor, stage, ax):
Expand All @@ -106,8 +106,8 @@ def _visit(tensor, stage, ax):
for input_tensor in tensor.op.input_tensors:
_visit(input_tensor, stage, ax)

assert len(te_graph.outputs) == 1
out = te_graph.outputs[0]
assert len(cached_func.outputs) == 1
out = cached_func.outputs[0]
oi, _ = tile_nd(sch, out, stripe_size)
for ax in oi:
sch[out].unroll(ax)
Expand All @@ -126,22 +126,22 @@ def copy_constants():
The planning function.
"""

def _planner(te_graph, const_dict, sch):
def _planner(cached_func, const_dict, sch):
planned = set() # type: ignore

def _visit(tensor, reader):
if tensor is not planned:
planned.add(tensor)
if isinstance(tensor.op, tvm.te.PlaceholderOp):
index = list(te_graph.inputs).index(tensor)
index = list(cached_func.inputs).index(tensor)
if index in const_dict:
sch.cache_read(tensor, "global", [reader])

elif isinstance(tensor.op, tvm.te.ComputeOp):
for input_tensor in tensor.op.input_tensors:
_visit(input_tensor, tensor)

for output_tensor in te_graph.outputs:
for output_tensor in cached_func.outputs:
_visit(output_tensor, None)

return _planner
Expand Down Expand Up @@ -216,16 +216,16 @@ def _detect_cache_read(stage):
stage.pragma(fax, "op", "ethosu_copy")


def inline_no_ops(te_graph, sch):
def inline_no_ops(cached_func, sch):
"""Inline 'no-ops' - operations that in principle do nothing.
Modifies the schedule in-place. For now we inline reshape and
strided slice - more could be added.
Parameters
----------
te_graph
The TE graph.
cached_func : CachedFunc
The cached func.
sch : tvm.te.Schedule
The schedule.
Expand All @@ -241,7 +241,7 @@ def _visit(tensor):
for input_tensor in tensor.op.input_tensors:
_visit(input_tensor)

for out in te_graph.outputs:
for out in cached_func.outputs:
_visit(out)


Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Extract information from the transform operators in TIR."""
"""Extract parameters from the transform operators in TIR."""
import tvm
from .spec import SerialCopy
from .utils import get_base_address, get_op_attrs
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Helper utility functions used by the TIR compiler"""
"""Helper utility functions used by the NPU TIR compiler"""
import tvm
from tvm import arith

Expand Down
Loading

0 comments on commit 4f6b478

Please sign in to comment.