Address review comments on Arm(R) Ethos(TM)-U PR 3/6 #9159

Merged · 3 commits · Oct 12, 2021
20 changes: 10 additions & 10 deletions python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""The integration of Arm(R) Ethos(TM)-U NPU TIR compiler"""
"""The integration of the Arm(R) Ethos(TM)-U NPU TIR compiler."""
import tvm
from tvm import relay
from tvm.relay.expr_functor import ExprMutator
@@ -29,7 +29,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):
"""Lower a schedule to TIR for the Arm(R) Ethos(TM)-U NPU target.

The resulting TIR module will contain a single function
- that comprises of a sequence of tir.extern_calls to NPU
+ that consists of a sequence of tir.extern_calls to NPU
operations.

Parameters
@@ -96,20 +96,20 @@ def lower_ethosu(sch, args, const_dict, name="main"):


def lower_to_te(prim_func):
"""Lower a Relay primitive function to a Tensor Expression graph.
"""Lower a Relay primitive function to a Tensor Expression in an unscheduled CachedFunc.

Parameters
----------
prim_func : tvm.relay.Function
- The Relay function to lowerethosu_runtime([]).
+ The Relay function to lower.

Returns
-------
- out : TEGraph
- The lowered Tensor Expression graph.
+ out : CachedFunc
+ The lowered Tensor Expression as part of a CachedFunc.

"""
- f = tvm._ffi.get_global_func("relay.backend.contrib.ethosu.LowerToTE")
+ f = tvm._ffi.get_global_func("relay.backend.LowerToTE")
return f(prim_func)


@@ -193,7 +193,7 @@ def lower_to_tir(func, cascader=None):
func, consts = extract_constants(func)
mod = tvm.IRModule.from_expr(func)
func = relay.transform.InferType()(mod)["main"]
- te_graph = lower_to_te(func)
- s = schedule(te_graph, consts, cascader)
- mod, consts = lower_ethosu(s, te_graph, consts)
+ cached_func = lower_to_te(func)
+ s = schedule(cached_func, consts, cascader)
+ mod, consts = lower_ethosu(s, cached_func, consts)
return mod, consts
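
To make the TEGraph-to-CachedFunc rename above concrete, here is a minimal sketch (not part of the PR; the helper name is hypothetical) of how the object returned by lower_to_te is consumed. It relies only on the inputs and outputs fields that the scheduler changes in this PR actually touch.

```python
from tvm import relay
from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_te


def describe_lowered_func(prim_func: relay.Function) -> None:
    """Print the TE tensors exposed by a lowered NPU primitive function."""
    cached_func = lower_to_te(prim_func)  # now a CachedFunc rather than a TEGraph
    for tensor in cached_func.inputs:     # placeholder tensors for the function arguments
        print("input :", tensor.shape, tensor.dtype)
    for tensor in cached_func.outputs:    # compute tensors produced by the function body
        print("output:", tensor.shape, tensor.dtype)
```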
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/convolution.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Extract information from the convolution operators in TIR."""
"""Extract parameters from the convolution operators in TIR."""
import tvm
from ..vela_api import SCALE_BIAS_LENGTH
from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/dma.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Extract information from the DMA operators in TIR."""
"""Extract parameters from the DMA operators in TIR."""
import tvm
from .utils import get_outer_loops, get_base_address, get_strides, get_op_attrs
from .spec import SerialFeatureMap, SerialPadding
6 changes: 3 additions & 3 deletions python/tvm/relay/backend/contrib/ethosu/tir/passes.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler"""
"""The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler."""
import numpy as np # type: ignore

import tvm
@@ -301,7 +301,7 @@ def EncodeConstants(const_dict):
pointer_to_buffer = {}
rewrite_buffer = {}
rewrite_pointer = {}
- accel_type = vela_api.get_target_accel_type() # type: ignore
+ accel_config = vela_api.get_accelerator_config()

def _align_scale_bias(tir_extern_call, bias):
"""Align the scale_bias to 16 bytes."""
@@ -316,7 +316,7 @@ def _align_scale_bias(tir_extern_call, bias):

def _encode_weights(tir_extern_call, weights):
"""Encode the weights for a TIR extern call."""
- value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_type)
+ value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_config)
value = np.frombuffer(value_bytes, dtype="uint8")
return value

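As a hedged sketch of the renamed Vela API calls above (the wrapper function below is illustrative, not part of the PR), the pass queries the accelerator configuration once and the per-call-site weight encoder closes over it:

```python
import numpy as np

from tvm.relay.backend.contrib.ethosu import vela_api


def make_weight_encoder():
    # Renamed in this PR from vela_api.get_target_accel_type().
    accel_config = vela_api.get_accelerator_config()

    def _encode_weights(tir_extern_call, weights):
        # Encode the weights for a TIR extern call against the configured accelerator.
        value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_config)
        return np.frombuffer(value_bytes, dtype="uint8")

    return _encode_weights
```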
36 changes: 18 additions & 18 deletions python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py
@@ -15,17 +15,17 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Different schedulers for Arm(R) Ethos(TM)-U NPU"""
"""Scheduling for Arm(R) Ethos(TM)-U NPU."""
import tvm


- def schedule(te_graph, const_dict, cascader=None):
- """Schedule a TE graph for NPU compilation.
+ def schedule(cached_func, const_dict, cascader=None):
+ """Schedule a CachedFunc for NPU compilation.

Parameters
----------
- te_graph
- The TE graph to schedule.
+ cached_func : CachedFunc
+ The CachedFunc to schedule.

[Review comment from a Contributor on the new `cached_func : CachedFunc` line]
Heads up we're trying to move both the TECompiler and CachedFunc structures to be internal only. I think the idea is all the cross-reference stuff accumulated in CachedFunc would be captured as 'official' attributes on the call or defn. So obviously we'll need to include this in that refactor.
const_dict : dict of int to numpy.ndarray
The constant dictionary.
cascader : callable, optional
@@ -38,10 +38,10 @@ def schedule(te_graph, const_dict, cascader=None):
The completed schedule for the graph.

"""
- s = tvm.te.create_schedule([t.op for t in te_graph.outputs])
+ s = tvm.te.create_schedule([t.op for t in cached_func.outputs])
if cascader:
- cascader(te_graph, const_dict, s)
- inline_no_ops(te_graph, s)
+ cascader(cached_func, const_dict, s)
+ inline_no_ops(cached_func, s)
schedule_pragmas(s)
schedule_cache_reads(s)
return s
@@ -96,7 +96,7 @@ def total_cascader(stripe_size):

"""

- def _cascader(te_graph, const_dict, sch):
+ def _cascader(cached_func, const_dict, sch):
scheduled = set()

def _visit(tensor, stage, ax):
@@ -106,8 +106,8 @@ def _visit(tensor, stage, ax):
for input_tensor in tensor.op.input_tensors:
_visit(input_tensor, stage, ax)

- assert len(te_graph.outputs) == 1
- out = te_graph.outputs[0]
+ assert len(cached_func.outputs) == 1
+ out = cached_func.outputs[0]
oi, _ = tile_nd(sch, out, stripe_size)
for ax in oi:
sch[out].unroll(ax)
@@ -126,22 +126,22 @@ def copy_constants():
The planning function.
"""

- def _planner(te_graph, const_dict, sch):
+ def _planner(cached_func, const_dict, sch):
planned = set() # type: ignore

def _visit(tensor, reader):
if tensor is not planned:
planned.add(tensor)
if isinstance(tensor.op, tvm.te.PlaceholderOp):
- index = list(te_graph.inputs).index(tensor)
+ index = list(cached_func.inputs).index(tensor)
if index in const_dict:
sch.cache_read(tensor, "global", [reader])

elif isinstance(tensor.op, tvm.te.ComputeOp):
for input_tensor in tensor.op.input_tensors:
_visit(input_tensor, tensor)

- for output_tensor in te_graph.outputs:
+ for output_tensor in cached_func.outputs:
_visit(output_tensor, None)

return _planner
@@ -216,16 +216,16 @@ def _detect_cache_read(stage):
stage.pragma(fax, "op", "ethosu_copy")


- def inline_no_ops(te_graph, sch):
+ def inline_no_ops(cached_func, sch):
"""Inline 'no-ops' - operations that in principle do nothing.

Modifies the schedule in-place. For now we inline reshape and
strided slice - more could be added.

Parameters
----------
- te_graph
- The TE graph.
+ cached_func : CachedFunc
+ The cached func.
sch : tvm.te.Schedule
The schedule.

@@ -241,7 +241,7 @@ def _visit(tensor):
for input_tensor in tensor.op.input_tensors:
_visit(input_tensor)

- for out in te_graph.outputs:
+ for out in cached_func.outputs:
_visit(out)


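To show how the updated scheduler entry points fit together, here is a hedged usage sketch (schedule_with_constant_copies is a hypothetical wrapper, not part of the PR): the planner returned by copy_constants(), or a cascader built with total_cascader(...), is passed straight into schedule() along with the CachedFunc.

```python
from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_te
from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants, schedule


def schedule_with_constant_copies(prim_func, const_dict):
    # Lower the Relay primitive function to an unscheduled CachedFunc.
    cached_func = lower_to_te(prim_func)
    # copy_constants() returns a planner with the (cached_func, const_dict, sch) signature.
    cascader = copy_constants()
    # schedule() invokes the planner/cascader when one is provided.
    return schedule(cached_func, const_dict, cascader)
```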
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/transform.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Extract information from the transform operators in TIR."""
"""Extract parameters from the transform operators in TIR."""
import tvm
from .spec import SerialCopy
from .utils import get_base_address, get_op_attrs
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/utils.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Helper utility functions used by the TIR compiler"""
"""Helper utility functions used by the NPU TIR compiler"""
import tvm
from tvm import arith
