diff --git a/apps/microtvm/cmsisnn/requirements.txt b/apps/microtvm/cmsisnn/requirements.txt index 1c99bd49a92e2..896d6a8d723c7 100644 --- a/apps/microtvm/cmsisnn/requirements.txt +++ b/apps/microtvm/cmsisnn/requirements.txt @@ -7,11 +7,11 @@ cloudpickle==2.0.0 \ decorator==5.1.0 \ --hash=sha256:7b12e7c3c6ab203a29e157335e9122cb03de9ab7264b137594103fd4a683b374 \ --hash=sha256:e59913af105b9860aa2c8d3272d9de5a56a4e608db9a2f167a8480b323d529a7 -ethos-u-vela==3.5.0 \ - --hash=sha256:e56c2f62e06439f45d07f2e6f41fd133a46fb7b6a2e0e6d3baf7ec1d947baca1 -flatbuffers==1.12 \ - --hash=sha256:63bb9a722d5e373701913e226135b28a6f6ac200d5cc7b4d919fa38d73b44610 \ - --hash=sha256:9e9ef47fa92625c4721036e7c4124182668dc6021d9e7c73704edd395648deb9 +ethos-u-vela==3.7.0 \ + --hash=sha256:314b761e171d19bf03e141684f9a371af7bf830f739b9b2f90b5f303a7fb1203 +flatbuffers==2.0.7 \ + --hash=sha256:0ae7d69c5b82bf41962ca5fde9cc43033bc9501311d975fd5a25e8a7d29c1245 \ + --hash=sha256:71e135d533be527192819aaab757c5e3d109cb10fbb01e687f6bdb7a61ad39d1 lxml==4.6.3 \ --hash=sha256:079f3ae844f38982d156efce585bc540c16a926d4436712cf4baee0cce487a3d \ --hash=sha256:0fbcf5565ac01dff87cbfc0ff323515c823081c5777a9fc7703ff58388c258c3 \ diff --git a/apps/microtvm/ethosu/requirements.txt b/apps/microtvm/ethosu/requirements.txt index d8a7fa7bd9010..17bc9bb46490a 100644 --- a/apps/microtvm/ethosu/requirements.txt +++ b/apps/microtvm/ethosu/requirements.txt @@ -7,11 +7,11 @@ cloudpickle==2.0.0 \ decorator==5.1.0 \ --hash=sha256:7b12e7c3c6ab203a29e157335e9122cb03de9ab7264b137594103fd4a683b374 \ --hash=sha256:e59913af105b9860aa2c8d3272d9de5a56a4e608db9a2f167a8480b323d529a7 -ethos-u-vela==3.5.0 \ - --hash=sha256:e56c2f62e06439f45d07f2e6f41fd133a46fb7b6a2e0e6d3baf7ec1d947baca1 -flatbuffers==1.12 \ - --hash=sha256:63bb9a722d5e373701913e226135b28a6f6ac200d5cc7b4d919fa38d73b44610 \ - --hash=sha256:9e9ef47fa92625c4721036e7c4124182668dc6021d9e7c73704edd395648deb9 +ethos-u-vela==3.7.0 \ + --hash=sha256:314b761e171d19bf03e141684f9a371af7bf830f739b9b2f90b5f303a7fb1203 +flatbuffers==2.0.7 \ + --hash=sha256:0ae7d69c5b82bf41962ca5fde9cc43033bc9501311d975fd5a25e8a7d29c1245 \ + --hash=sha256:71e135d533be527192819aaab757c5e3d109cb10fbb01e687f6bdb7a61ad39d1 lxml==4.6.3 \ --hash=sha256:079f3ae844f38982d156efce585bc540c16a926d4436712cf4baee0cce487a3d \ --hash=sha256:0fbcf5565ac01dff87cbfc0ff323515c823081c5777a9fc7703ff58388c258c3 \ diff --git a/conda/recipe/bld.bat b/conda/recipe/bld.bat index 6af4a9bacf637..f8988b1357937 100644 --- a/conda/recipe/bld.bat +++ b/conda/recipe/bld.bat @@ -27,6 +27,7 @@ cmake ^ -DUSE_LLVM=ON ^ -DUSE_RPC=ON ^ -DUSE_CPP_RPC=ON ^ + -DUSE_MICRO=ON ^ -DUSE_SORT=ON ^ -DUSE_RANDOM=ON ^ -DUSE_PROFILER=ON ^ diff --git a/conda/recipe/build.sh b/conda/recipe/build.sh index 0131fd65a48e8..3422c4d8f13b8 100755 --- a/conda/recipe/build.sh +++ b/conda/recipe/build.sh @@ -49,6 +49,7 @@ cmake -DCMAKE_INSTALL_PREFIX="${PREFIX}" \ -DCMAKE_BUILD_TYPE=Release \ -DUSE_RPC=ON \ -DUSE_CPP_RPC=OFF \ + -DUSE_MICRO=ON \ -DUSE_SORT=ON \ -DUSE_RANDOM=ON \ -DUSE_PROFILER=ON \ diff --git a/docker/install/ubuntu_install_vela.sh b/docker/install/ubuntu_install_vela.sh index 8d43a4d6e112b..69c461547ad03 100755 --- a/docker/install/ubuntu_install_vela.sh +++ b/docker/install/ubuntu_install_vela.sh @@ -20,4 +20,4 @@ set -e set -u set -o pipefail -pip3 install ethos-u-vela==3.5.0 +pip3 install ethos-u-vela==3.7.0 diff --git a/gallery/how_to/work_with_microtvm/micro_ethosu.py b/gallery/how_to/work_with_microtvm/micro_ethosu.py index e6f47321c8129..ea1e9d754287a 100644 --- 
a/gallery/how_to/work_with_microtvm/micro_ethosu.py +++ b/gallery/how_to/work_with_microtvm/micro_ethosu.py @@ -84,8 +84,8 @@ # attrs==21.2.0 # cloudpickle==2.0.0 # decorator==5.1.0 -# ethos-u-vela==3.5.0 -# flatbuffers==1.12 +# ethos-u-vela==3.7.0 +# flatbuffers==2.0.7 # lxml==4.6.3 # nose==1.3.7 # numpy==1.19.5 diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 0aaa8b3e8aece..d4f537ff31691 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -337,11 +337,17 @@ TVM_DLL Pass CombineContextCall(); TVM_DLL Pass NarrowDataType(int target_bits); /*! - * \brief Legalize bf16 typed Ops. Add a cast to fp32 + * \brief Legalize bf16 compute Ops. Add a cast to fp32 * before Ops, then add a cast back to bf16. * \return The pass. */ -TVM_DLL Pass BF16Legalize(); +TVM_DLL Pass BF16ComputeLegalize(); + +/*! + * \brief Legalize bf16 storage types to u16. + * \return The pass. + */ +TVM_DLL Pass BF16StorageLegalize(); /*! * \brief Rewrite the pointer content type of arguments, diff --git a/include/tvm/topi/elemwise.h b/include/tvm/topi/elemwise.h index 49b50019f04d3..132992c57dc72 100644 --- a/include/tvm/topi/elemwise.h +++ b/include/tvm/topi/elemwise.h @@ -310,11 +310,7 @@ inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast", inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "tensor", std::string tag = kElementWise) { return compute( - x->shape, - [&](const Array& i) { - return tvm::tir::Call(type, tvm::tir::builtin::reinterpret(), {x(i)}); - }, - name, tag); + x->shape, [&](const Array& i) { return reinterpret(type, x(i)); }, name, tag); } /*! diff --git a/python/gen_requirements.py b/python/gen_requirements.py index 09fb57ee94c03..1a55dccd11308 100644 --- a/python/gen_requirements.py +++ b/python/gen_requirements.py @@ -246,7 +246,7 @@ "docutils", "<0.17", ), # Work around https://github.com/readthedocs/sphinx_rtd_theme/issues/1115 - ("ethos-u-vela", "==3.5.0"), + ("ethos-u-vela", "==3.7.0"), ("future", None), ("h5py", "==2.10.0"), ("image", None), diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 6e61e762ee212..c42974593a6b5 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -20,8 +20,12 @@ """ import logging import os.path +import re +import itertools +from copy import deepcopy from typing import Any, Optional, Dict, List, Union, Callable, Sequence from pathlib import Path +from collections import defaultdict import tvm from tvm import autotvm, auto_scheduler @@ -31,6 +35,8 @@ from tvm.ir.memory_pools import WorkspaceMemoryPools from tvm.target import Target from tvm.relay.backend import Executor, Runtime +from tvm.relay.analysis.operations_distribution import analyze_operations_distribution +from tvm.relay.transform.suffixes import tag_suffixes from . import composite_target, frontends, TVMCException from .model import TVMCModel, TVMCPackage @@ -69,6 +75,16 @@ def add_compile_parser(subparsers, _, json_params): default="", help="comma separated list of formats to export the input model, e.g. 'asm,ll,tir,relay'.", ) + parser.add_argument( + "--dump-offloads", + default="", + help="output a mapping of which operations of the initial Relay " + "will be transferred to which backend, indicating the composite " + "that includes those operations, " + "e.g. '--dump-offloads -' to dump to the console, " + "e.g. '--dump-offloads ' to dump to the file. " + "If not presented, no output is done. 
", + ) parser.add_argument( "--model-format", choices=frontends.get_frontend_names(), @@ -171,6 +187,8 @@ def drive_compile(args): dump_code = [x.strip() for x in args.dump_code.split(",")] if args.dump_code else None + dump_offloads = args.dump_offloads if args.dump_offloads else "" + additional_targets = reconstruct_target_args(args) workspace_pools_target, extra_targets = target_from_cli(args.target, additional_targets) transform_args = parse_graph_transform_args(args) @@ -187,6 +205,7 @@ def drive_compile(args): cross_options=args.cross_compiler_options, output_format=args.output_format, dump_code=dump_code, + dump_offloads=dump_offloads, target_host=None, disabled_pass=args.disabled_pass, pass_context_configs=args.pass_config, @@ -213,6 +232,7 @@ def compile_model( cross_options: Optional[str] = None, output_format: str = "so", dump_code: Optional[List[str]] = None, + dump_offloads: str = "", target_host: Optional[str] = None, disabled_pass: Optional[str] = None, pass_context_configs: Optional[List[str]] = None, @@ -259,6 +279,10 @@ def compile_model( dump_code : list[str], optional Dump the generated code for the specified source types, on the requested target. Choose from: ["asm", "ll", "tir", "relay"]. + dump_offloads : str + Dump the information about the partition of input model's layers by external codegen. + Can be '' to not dump at all, '-' to dump to the console + or '' to dump to the specified file. target_host : str, optional The target of the host machine if host-side code needs to be generated. @@ -313,6 +337,13 @@ def compile_model( if "tir" in dump_code: config, dumps = add_tir_to_dumps(config, dumps) + initial_relay = None + if dump_offloads != "": + # add suffixes to the span field for calls in Relay + mod = tag_suffixes(mod) + # remember initial Relay + initial_relay = deepcopy(mod) + tvm_target, extra_targets = target_from_cli(target, additional_target_options) tvm_target, target_host = Target.canon_target_and_host(tvm_target, target_host) @@ -337,6 +368,10 @@ def compile_model( for partition_function, opts in zip(partition_functions, partition_opts): mod = partition_function(mod, params, mod_name=mod_name, **opts) + if initial_relay: + # dump which operations are offloaded to which backend + dump_operation_offloads(mod, initial_relay, dump_offloads) + if tuning_records and os.path.exists(tuning_records): logger.debug("tuning records file provided: %s", tuning_records) @@ -496,3 +531,141 @@ def save_dumps(module_name: str, dumps: Dict[str, str], dump_root: str = "."): dump_name = module_name + "." + dump_format with open(Path(dump_root, dump_name), "w") as f: f.write(dumps[dump_format]) + + +def dump_operation_offloads(mod: tvm.ir.IRModule, initial_mod: tvm.ir.IRModule, dump_path: str): + """This helper function forms a line-by-line output of the initial Relay lines, + indicating which operations are ported to which target, + and indicating the composite that includes those operations; + the 'generic' target refers to operations uploaded to the host, e.g + 'target1 <- target1.qnn_conv2d' + 'target1 <- %0 = qnn.conv2d(%tfl.quantize, %v_param_1, ...' + 'target1 <- %1 = nn.bias_add(%0, %v_param_2, axis=3);' + 'target1 <- %2 = qnn.requantize(%1, meta[relay.Constant]...' + 'target2 <- target2.reshape' + 'target2 <- %3 = reshape(%2, newshape=[1, 1001]);' + 'generic <- %4 = nn.pad(%3, -128f, pad_width=[[0, 0], [1, 1]...' + + Parameters + ---------- + mod : tvm.ir.IRModule + The partitioned IRModule with external global functions. 
+    initial_mod : tvm.ir.IRModule
+        The initial IRModule that gets generated from a relay frontend.
+    dump_path: str
+        Value of the "dump_offloads" compiler attribute.
+        Can be a dash ("-"), a file path, or an empty string for
+        printing to the console, a file, or doing nothing, respectively.
+    """
+    print_to_console = dump_path == "-"
+    save_to_file = all([dump_path != "-", dump_path != ""])
+
+    if print_to_console or save_to_file:
+
+        operations_distribution = analyze_operations_distribution(mod)
+
+        def annotate_f(x):
+            ret = ""
+            if isinstance(x, relay.Call):
+                # if there is no x.span.source_name.name in operations_distribution,
+                # this could mean that the span was not copied during the application of passes
+                # to the Relay, in which case we cannot associate the initial Relay string
+                # with the resulting Relay call
+                source_name = x.span.source_name.name
+                if source_name in operations_distribution:
+                    compiler_name, op_name, func_id = operations_distribution[source_name]
+                    ret = (
+                        f", compiler_name: {compiler_name}, op_name: {op_name}, "
+                        f"func_id: {func_id}"
+                    )
+                else:
+                    ret = ", compiler_name: unknown, op_name: unknown, func_id: unknown"
+            return ret
+
+        initial_relay_astext = initial_mod.astext(show_meta_data=False, annotate=annotate_f).split(
+            "\n"
+        )
+
+        # funcs_list is a list of internal composite/function IDs
+        # generated by analyze_operations_distribution().
+        # funcs_list helps keep the order of lines from the initial Relay.
+        funcs_list = []
+
+        # target_statistic is a mapping of the target name to the
+        # number of initial Relay calls offloaded on the target
+        target_statistic = defaultdict(int)
+
+        # funcs_dict is a mapping of the generated analyze_operations_distribution
+        # internal composite/function IDs to a list, where:
+        # 1st element is
+        # (1a): target name - it could be "generic" or "unknown" or
+        # (1b): specific target name, like "ethos-u" or "cmsis-nn"
+        # 2nd element is
+        # (2a): corresponding initial Relay line for the case (1a) or
+        # (2b): the name of the target composite function in the other case (1b)
+        # 3rd element or subsequent ones are presented only for the case (2b)
+        # and are the initial Relay lines included in the corresponding
+        # target composite function
+        funcs_dict = {}
+
+        # Here we group together initial Relay lines from the same composite
+        counter = itertools.count()
+        for s in initial_relay_astext:
+            result = re.search(
+                r"(compiler_name: )(.*)(, op_name: )(.*)(, func_id: )((.*)(?=;)|(.*))", s
+            )
+            if result:
+                target_name = result.group(2)
+                op_name = result.group(4)
+                func_id = result.group(6)
+                s = re.sub(r", compiler_name: (.*)", "", s).lstrip()
+                target_statistic[target_name] += 1
+
+                # create an identifier for each "unknown" case to keep the line order
+                if func_id == "unknown":
+                    func_id = str(next(counter) * -1)
+
+                if func_id not in funcs_dict:
+                    funcs_list.append(func_id)
+                    funcs_dict[func_id] = [target_name]
+                    if target_name not in ["unknown", "generic"]:
+                        funcs_dict[func_id].append(op_name)
+
+                funcs_dict[func_id].append(s)
+
+        # Here we prepare the output for printing.
+ # The output in most cases keeps the original order of the Relay lines + # but some lines are moved to be in the corresponding composite group + output = [] + total = 0 + output.append("Total number of operators and distribution by targets") + output.append("Total:") + for target, statistic in target_statistic.items(): + total += statistic + output.append(f"{target}: {statistic}") + output[1] += f" {total}" + output[len(target_statistic) + 1] += "\n" + + for func_id in funcs_list: + _list = funcs_dict[func_id] + output.append(f"{_list[0]:10} <- {_list[1]}") + if _list[0] == "unknown": + output.append( + "Warning: The above line means that some pass(es) \ + in Relay partitioning" + ) + output.append("do not copy the span when the call is recreated") + output.append( + "and a line from initial Relay could not be associated \ + with the resulting Relay" + ) + for el in _list[2:]: + output.append(f"{_list[0]:10} <- {el}") + + if print_to_console: + print("\n" + "\n".join(output)) + if save_to_file: + file_path = os.path.abspath(dump_path) + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "w") as f: + f.write("\n".join(output)) diff --git a/python/tvm/relay/analysis/operations_distribution.py b/python/tvm/relay/analysis/operations_distribution.py new file mode 100644 index 0000000000000..fc983c8e7eede --- /dev/null +++ b/python/tvm/relay/analysis/operations_distribution.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utilities that enable analyze Relay and get mappings for +the unique identifier of the Relay line to the tuple of +compiler name, composite name and composite/function identifier.""" +import tvm +from tvm import relay +from tvm.relay.expr_functor import ExprVisitor + + +class AnalyzeOperationsDistribution(ExprVisitor): + """A visitor pass that maintains the dictionary unique_op_ids where + the tuple (compiler name, composite name, composite/function identifier) + corresponds to the unique identifier of the Relay line. + TVMC compiler adds a unique Relay line identifier as a suffix + to the call span field using the tag_suffixes pass + if the --dump-offloads option is specified. + + Attributes + ---------- + unique_op_ids : Dict[str, str, int] + Mapping the unique identifier of the Relay line obtained from + the "span" field of the Call and the tuple of compiler name, + composite name and internal composite/function identifier. + func_name : str + The name of the composite name in the partitioned Relay or + 'generic' in case the Call has not been included in any composite. + func_id : int + Internal(inside unique_op_ids) composite/function identifier. + compiler_name : str + A name of the compiler (e.g. 
'ethos-u' or 'cmsis-nn') or 'generic'
+        in case the Call has not been included in any composite.
+    """
+
+    def __init__(self):
+        self.unique_op_ids = {}
+        self.func_name = ""
+        self.func_id = 1
+        self.compiler_name = ""
+        super().__init__()
+
+    def extract(self, call: relay.Call):
+        self.compiler_name = "generic"
+        self.func_name = "generic"
+        if "Compiler" in call.attrs:
+            self.compiler_name = call.attrs["Compiler"]
+        self.visit(call)
+
+    def visit_call(self, call: relay.Call):
+        if isinstance(call.op, tvm.ir.Op):
+            if call.span:
+                src = call.span.source_name.name
+                self.unique_op_ids[src] = [self.compiler_name, self.func_name, self.func_id]
+                if self.func_name == "generic":
+                    self.func_id += 1
+        if isinstance(call.op, relay.Function):
+            self.func_name = call.op.attrs["Composite"]
+            self.func_id += 1
+        super().visit_call(call)
+
+
+def analyze_operations_distribution(mod):
+    """Traverses the partitioned graph to get the unique identifier
+    of the Relay line from the Call's span field.
+    The result is maintained in the dictionary unique_op_ids where
+    the unique indicator obtained from the op's span corresponds to
+    the tuple (compiler name, composite name, composite/function identifier).
+    With this information we can annotate the textual representation
+    of the initial Relay by indicating into which target composite
+    and function the operators are converted.
+
+    Parameters
+    ----------
+    mod : tvm.ir.IRModule
+        The partitioned Relay graph, usually obtained with the
+        corresponding partition_for_ function
+
+    Returns
+    -------
+    unique_op_ids : Dict[str, str, int]
+        Mapping from the unique identifier of the Relay line to the tuple of
+        compiler name, composite name, internal composite/function
+        identifier.
+    """
+    analyze = AnalyzeOperationsDistribution()
+    for _, func in mod.functions.items():
+        analyze.extract(func)
+    return analyze.unique_op_ids
diff --git a/python/tvm/relay/transform/suffixes.py b/python/tvm/relay/transform/suffixes.py
new file mode 100644
index 0000000000000..e2f7a3c224c1c
--- /dev/null
+++ b/python/tvm/relay/transform/suffixes.py
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"Add suffixes to the relay.Call's span fields"
+from collections import defaultdict
+
+import tvm
+
+from ..expr_functor import ExprMutator
+from .. import expr as _expr
+
+
+class _SuffixTagger(ExprMutator):
+    """A pass that traverses the Relay graph and adds a suffix to each call's span field.
+    This makes the span a unique indicator of a Relay line, so it can be used to
+    obtain the mapping between the Relay that gets generated from a relay frontend
+    and the Relay after partitioning.
+ """ + + def __init__(self): + ExprMutator.__init__(self) + # key: span or source name, value: counter, indexed from 0 + self.lookup = defaultdict(int) + self.suffix = "_PART_" + # a set to record hashes of an expressions which spans have been already rewritten + self.hashes = set() + + def _tag_suffix(self, span): + # To avoid error once we introduce the SequentialSpan in the future + """https://discuss.tvm.apache.org/ + t/pre-rfc-tvm-explorer-infrastructure/13457#pass-source-information-builder-6 + """ + # Don't need this if currently + if isinstance(span, tvm.relay.Span): + ori_name = span.source_name.name + new_name = ori_name + self.suffix + str(self.lookup[ori_name]) + self.lookup[ori_name] += 1 + return tvm.relay.Span( + tvm.relay.SourceName(new_name), + span.line, + span.end_line, + span.column, + span.end_column, + ) + return span + + def visit(self, expr): + if hasattr(expr, "span"): + return super().visit(expr) + return expr + + def visit_call(self, call): + new_args = [self.visit(arg) for arg in call.args] + new_op = self.visit(call.op) + if tvm.ir.structural_hash(call) not in self.hashes: + self.hashes.add(tvm.ir.structural_hash(call)) + expr__ = _expr.CallWithFields( + call, + new_op, + new_args, + call.attrs, + call.type_args, + None, + self._tag_suffix(call.span), + ) + else: + expr__ = _expr.CallWithFields( + call, new_op, new_args, call.attrs, call.type_args, None, call.span + ) + return expr__ + + +def tag_suffixes(mod): + """Traverses the Relay graph to add suffix to the call's span fields. + That making span as an unique indicator of a Relay call and we can use it to + obtain the mapping between the offloaded result and the frontend operators. + + Parameters + ---------- + tvm.ir.IRModule + The IRModule that gets generated from a relay frontend. + + Returns + ------- + tvm.ir.IRModule + The IRModule with call's span fields tagged with suffixes. + """ + tagger = _SuffixTagger() + for global_var, func in mod.functions.items(): + func = tagger.visit(func) + mod.update_func(global_var, func) + return mod diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index a18d698e54266..1df2ac76b5b4b 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -286,59 +286,26 @@ def RemoveStoreUndef(): return _ffi_api.RemoveStoreUndef() # type: ignore -def BF16Legalize(): - """Legalize bf16 typed Ops. - Runs BF16Promote, BF16CastElimination and BF16TypeLowering +def BF16ComputeLegalize(): + """Legalize bf16 compute Ops. Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.BF16Legalize() # type: ignore + return _ffi_api.BF16ComputeLegalize() # type: ignore -def BF16Promote(): - """Promote bf16 to fp32. Add a cast to fp32 - before Ops, then add a cast back to bf16. +def BF16StorageLegalize(): + """Legalize bf16 storage types to u16. Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.BF16Promote() # type: ignore - - -def BF16CastElimination(): - """Eliminate verbose casting between fp32 and bf16 - Checks if the AST has the pattern: - castto32(castto16(some_fp32_op(...))) - The verbose casting is generated by BF16Promote for multiple - bf16 Ops in a row. 
e.g.: - X[i] + Y[i] + T[i] => - bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i]))) - After this pass: - bf16(float32(X[i]) + float32(Y[i]) + float32(T[i])) - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.BF16CastElimination() # type: ignore - - -def BF16TypeLowering(): - """Replace all bf16 type with uint16. Also lower the casting - between fp32 and bf16 - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.BF16TypeLowering() # type: ignore + return _ffi_api.BF16StorageLegalize() # type: ignore def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False): diff --git a/python/tvm/topi/hexagon/tensor_intrin.py b/python/tvm/topi/hexagon/tensor_intrin.py index 3e9fd47b0fc62..24bbacf37cc8c 100644 --- a/python/tvm/topi/hexagon/tensor_intrin.py +++ b/python/tvm/topi/hexagon/tensor_intrin.py @@ -22,44 +22,165 @@ from tvm import te +def get_lanes(dtype: str): + if "x" not in dtype: + return 1 + + _, lanes = dtype.split("x") + return int(lanes) + + +def is_vector_type(dtype: str): + return get_lanes(dtype) != 1 + + +def is_power_of_2(n: int): + return (n & (n - 1) == 0) and n != 0 + + +def _adapt_to_highest_lanes(*args, intrinsic=None, intrinsic_lanes: int = 0): + """Apply provided lowering intrinsic to arguments with longer vector data type. + + This wrapper will do next actions: + * Split each argument into chunks with size equal intrinsic_lanes + * Apply provided intrinsic for each argument chunk + * Concatenate results + + Parameters + ---------- + args: List[PrimExpr] + List of arguments. Each arg expression should have vector type with lanes + equal `intrinsic_lanes * 2**n`. + + intrinsic: callable + Intrinsic implementation to apply. + + intrinsic_lanes: int + Vector length required by intrinsic implementation. + + Returns + ------- + res : PrimExpr + Resulting expression. 
+ """ + + def split_args(args_set): + res_args_set = [] + for args_chunk in args_set: + res_args_chunk_l = [] + res_args_chunk_h = [] + for arg_chunk in args_chunk: + element, lanes = arg_chunk.dtype.split("x") + res_arg_chunk_dtype = f"{element}x{int(lanes) // 2}" + + res_args_chunk_l.append(tvm.tir.op.vectorlow(res_arg_chunk_dtype, arg_chunk)) + res_args_chunk_h.append(tvm.tir.op.vectorhigh(res_arg_chunk_dtype, arg_chunk)) + res_args_set += [res_args_chunk_l, res_args_chunk_h] + + return res_args_set + + def concat_args(res_chunks): + merged_res_chunks = [] + for i in range(0, len(res_chunks), 2): + arg_chunk_l = res_chunks[i] + arg_chunk_h = res_chunks[i + 1] + element, lanes = arg_chunk_l.dtype.split("x") + res_arg_chunk_dtype = f"{element}x{int(lanes) * 2}" + + merged_res_chunks.append( + tvm.tir.op.vectorcombine(res_arg_chunk_dtype, arg_chunk_l, arg_chunk_h) + ) + + return merged_res_chunks + + num_chunks = None + for arg in args: + _, lanes = arg.dtype.split("x") + lanes = int(lanes) + assert lanes % intrinsic_lanes == 0 + if num_chunks is None: + assert is_power_of_2(lanes // intrinsic_lanes) + num_chunks = lanes // intrinsic_lanes + + assert num_chunks == lanes // intrinsic_lanes + + # Split arguments + lowered_args = [args] + while len(lowered_args) != num_chunks: + lowered_args = split_args(lowered_args) + + # Intrinsic application + lowered_res = [] + for l_arg in lowered_args: + res = intrinsic(*l_arg) + lowered_res.append(res) + + # Result concatenation + while len(lowered_res) != 1: + lowered_res = concat_args(lowered_res) + + return lowered_res[0] + + def _q_multiply_shift_hexagon(op): """ Implementation of q_multiply_shift through hexagon intrinsics vmpyewuh and vmpyowh when q == 31. """ - x = op.args[0] - y = op.args[1] - fractional_bits = op.args[2] - shift = op.args[3] - - # Don't use this intrinsic if we don't have a int32x32 vector - # or if we are not multiplying q31 numbers - if x.dtype != "int32x32" or fractional_bits.value != 31: - return op + arg_x = op.args[0] + arg_fractional_bits = op.args[2] - # Case 1, shift is negative - mul_e_1 = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y - ) - mul_o_1 = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vmpyowh.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_1, x, y - ) - fixup = 1 << (-shift - 1) - round_mul = mul_o_1 + fixup - out_negative_shift = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vaslwv.128B", tvm.tir.const(2, "uint32"), round_mul, shift - ) + # Don't use this intrinsic if we are not multiplying q31 numbers + if arg_fractional_bits.value != 31: + return op - # Case 2, shift is positive - x = x * (1 << (shift)) - mul_e_2 = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y - ) - mul_o_2 = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_2, x, y - ) + x_lanes = get_lanes(arg_x.dtype) + if x_lanes % 32 != 0 or not is_power_of_2(x_lanes // 32): + return op - # Select depending on the shift - return tvm.tir.Select(shift < 0, out_negative_shift, mul_o_2) + # pylint: disable=unused-argument + def intrinsic_lowering_32(x, y, fractional_bits, shift): + lowered_dtype = "int32x32" + + # Case 1, shift is negative + mul_e_1 = tvm.tir.call_llvm_intrin( + lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y + ) + mul_o_1 = tvm.tir.call_llvm_intrin( + lowered_dtype, + "llvm.hexagon.V6.vmpyowh.sacc.128B", + 
tvm.tir.const(3, "uint32"), + mul_e_1, + x, + y, + ) + fixup = 1 << (-shift - 1) + round_mul = mul_o_1 + fixup + out_negative_shift = tvm.tir.call_llvm_intrin( + lowered_dtype, + "llvm.hexagon.V6.vaslwv.128B", + tvm.tir.const(2, "uint32"), + round_mul, + shift, + ) + + # Case 2, shift is positive + x = x * (1 << (shift)) + mul_e_2 = tvm.tir.call_llvm_intrin( + lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y + ) + mul_o_2 = tvm.tir.call_llvm_intrin( + lowered_dtype, + "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", + tvm.tir.const(3, "uint32"), + mul_e_2, + x, + y, + ) + + # Select depending on the shift + return tvm.tir.Select(shift < 0, out_negative_shift, mul_o_2) + + return _adapt_to_highest_lanes(*op.args, intrinsic=intrinsic_lowering_32, intrinsic_lanes=32) register_intrin_lowering( @@ -72,65 +193,87 @@ def _q_multiply_shift_per_axis_hexagon(op): Implementation of q_multiply_shift_per_axis through hexagon intrinsics vmpyewuh and vmpyowh when q == 31. """ - x = op.args[0] - y = op.args[1] - left_shift = op.args[2] - right_shift = op.args[3] - fractional_bits = op.args[4] - is_lshift_required = op.args[5] - is_rshift_required = op.args[6] - - # Don't use this intrinsic if we don't have a int32x32 vector - # or if we are not multiplying q31 numbers - if x.dtype != "int32x32" or fractional_bits.value != 31: + arg_x = op.args[0] + arg_fractional_bits = op.args[4] + arg_is_lshift_required = op.args[5] + arg_is_rshift_required = op.args[6] + + # Don't use this intrinsic if we are not multiplying q31 numbers + if arg_fractional_bits.value != 31: + return op + + x_lanes = get_lanes(arg_x.dtype) + if x_lanes % 32 != 0 or not is_power_of_2(x_lanes // 32): return op # Don't use this intrinsic when we need do both: left and right shifts. # For now it is not clear how to implement this case through vector HVX instructions without # accuracy drop. 
- if is_rshift_required.value and is_lshift_required.value: + if arg_is_rshift_required.value and arg_is_lshift_required.value: return op - # Case 1: do the left shift - shifted_x = x << left_shift - mul_e_1 = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), shifted_x, y - ) - left_shift_out = tvm.tir.call_llvm_intrin( - op.dtype, - "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", - tvm.tir.const(3, "uint32"), - mul_e_1, - shifted_x, - y, - ) - - # Case 2: do the right shift - mul_e_2 = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y - ) - mul_o_2 = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vmpyowh.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_2, x, y - ) - fixup = 1 << (right_shift - 1) - round_mul = mul_o_2 + fixup - right_shift_out = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vasrwv.128B", tvm.tir.const(2, "uint32"), round_mul, right_shift - ) - - # Case 3: do neither right nor left shift - mul_e_3 = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y - ) - no_shift_out = tvm.tir.call_llvm_intrin( - op.dtype, "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_3, x, y - ) - - return tvm.tir.Select( - tvm.tir.Not(tvm.tir.Or(is_lshift_required, is_rshift_required)), - no_shift_out, - tvm.tir.Select(is_lshift_required, left_shift_out, right_shift_out), - ) + # pylint: disable=unused-argument + def intrinsic_impl_32( + x, y, left_shift, right_shift, fractional_bits, is_lshift_required, is_rshift_required + ): + lowered_dtype = "int32x32" + + # Case 1: do the left shift + shifted_x = x << left_shift + mul_e_1 = tvm.tir.call_llvm_intrin( + lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), shifted_x, y + ) + left_shift_out = tvm.tir.call_llvm_intrin( + lowered_dtype, + "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", + tvm.tir.const(3, "uint32"), + mul_e_1, + shifted_x, + y, + ) + + # Case 2: do the right shift + mul_e_2 = tvm.tir.call_llvm_intrin( + lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y + ) + mul_o_2 = tvm.tir.call_llvm_intrin( + lowered_dtype, + "llvm.hexagon.V6.vmpyowh.sacc.128B", + tvm.tir.const(3, "uint32"), + mul_e_2, + x, + y, + ) + fixup = 1 << (right_shift - 1) + round_mul = mul_o_2 + fixup + right_shift_out = tvm.tir.call_llvm_intrin( + lowered_dtype, + "llvm.hexagon.V6.vasrwv.128B", + tvm.tir.const(2, "uint32"), + round_mul, + right_shift, + ) + + # Case 3: do neither right nor left shift + mul_e_3 = tvm.tir.call_llvm_intrin( + lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y + ) + no_shift_out = tvm.tir.call_llvm_intrin( + lowered_dtype, + "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", + tvm.tir.const(3, "uint32"), + mul_e_3, + x, + y, + ) + + return tvm.tir.Select( + tvm.tir.Not(tvm.tir.Or(is_lshift_required, is_rshift_required)), + no_shift_out, + tvm.tir.Select(is_lshift_required, left_shift_out, right_shift_out), + ) + + return _adapt_to_highest_lanes(*op.args, intrinsic=intrinsic_impl_32, intrinsic_lanes=32) register_intrin_lowering( diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 3458376848f13..569864a29edb8 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -218,7 +218,7 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::InjectSoftwarePipeline()); pass_list.push_back(tir::transform::LowerOpaqueBlock()); 
pass_list.push_back(tir::transform::FlattenBuffer()); - pass_list.push_back(tir::transform::BF16Legalize()); + pass_list.push_back(tir::transform::BF16ComputeLegalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); @@ -605,6 +605,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) } else { mixed_pass_list.push_back(tir::transform::MakePackedAPI()); } + mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); mixed_pass_list.push_back(tir::transform::SplitHostDevice()); return transform::Sequential(mixed_pass_list); diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc index 2d1507c8994d9..d654e467f1e7e 100644 --- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc +++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc @@ -138,7 +138,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { pass_list.push_back(tir::transform::InjectSoftwarePipeline()); pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); - pass_list.push_back(tir::transform::BF16Legalize()); + pass_list.push_back(tir::transform::BF16ComputeLegalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::InjectVirtualThread()); diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index c3cc0ef60152a..2797ee44735c4 100644 --- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -144,13 +144,40 @@ void RemoveParsedAnn(const Schedule& sch, const BlockRV& block_rv, const ParsedA } } +int CalculateNumRewritableLoops(const Array& loop_srefs, + const std::vector& loop_types) { + int rw_loops_num = 0; + ICHECK_EQ(loop_srefs.size(), loop_types.size()); + for (size_t i = 0; i < loop_srefs.size(); ++i) { + const StmtSRef& loop_sref = loop_srefs[i]; + const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); + if (HasAnnOrBinding(loop)) { + continue; + } + // Cannot vectorize reduce axis + if (loop_types[i] != IterVarType::kDataPar) { + continue; + } + // Cannot fuse with a loop with multiple children + if (!IsSingleStmt(loop->body)) { + continue; + } + // Check if the loop extent is valid + if (GetLoopIntExtent(loop_sref) == nullptr) { + continue; + } + ++rw_loops_num; + } + return rw_loops_num; +} + void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, const Array& loop_rvs, ParsedAnnotation* parsed) { StmtSRef block_sref = sch->GetSRef(block_rv); if (parsed->max_parallel_extent == -1 && parsed->max_vectorize_extent == -1) { return; } - int n_loops = loop_rvs.size(); + const int n_loops = loop_rvs.size(); if (n_loops == 0) { parsed->max_parallel_extent = -1; parsed->max_vectorize_extent = -1; @@ -226,6 +253,10 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, } max_fusible = std::min(max_fusible, fusible); } + + // Calculate how many loops are rewritable, i.e. valid for vectorization and parallelization. 
+ int max_rw_loops = CalculateNumRewritableLoops(loop_srefs, loop_types); + // Calculate the parallelize extent if (parsed->max_parallel_extent != -1) { int max_extent = parsed->max_parallel_extent; @@ -290,10 +321,17 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, num_fusible = -1; } } - // Prefer num_vectorize to num_parallel + if (parsed->num_parallel_loops != -1 && parsed->num_vectorize_loops != -1) { - parsed->num_parallel_loops = std::min(parsed->num_parallel_loops, // - n_loops - parsed->num_vectorize_loops); + if (max_rw_loops == n_loops && max_fusible == n_loops) { + // All loops can be fused, parallelized and vectorized + parsed->num_parallel_loops = n_loops; + parsed->num_vectorize_loops = n_loops; + } else { + // Prefer num_vectorize to num_parallel + parsed->num_parallel_loops = + std::min(parsed->num_parallel_loops, n_loops - parsed->num_vectorize_loops); + } } } @@ -317,6 +355,21 @@ bool FindAnnotatedRootBlock(const Schedule& sch, ParsedAnnotation* parsed, Block return false; } +void RewriteFuseSplitParallelVectorize(const Schedule& sch, Array* loop_rvs, int vec_len) { + size_t n_loops = loop_rvs->size(); + LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->end()}); + Array split = sch->Split(fused, {NullOpt, Integer(vec_len)}); + ICHECK_EQ(split.size(), 2); + const LoopRV& outer = split[0]; + const LoopRV& inner = split[1]; + sch->Parallel(outer); + sch->Vectorize(inner); + for (size_t i = 0; i < n_loops - 1; ++i) { + loop_rvs->Set(i, outer); + } + loop_rvs->Set(n_loops - 1, inner); +} + void RewriteParallel(const Schedule& sch, size_t n, Array* loop_rvs) { ICHECK_LE(n, loop_rvs->size()); LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->begin() + n}); @@ -364,13 +417,19 @@ class RewriteParallelVectorizeUnrollNode : public PostprocNode { } tir::ParsedAnnotation parsed = parsed_root; tir::AdjustParallelVectorize(sch, block_rv, loop_rvs, &parsed); - // Parallel - if (parsed.num_parallel_loops > 0) { - tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs); - } - // Vectorize - if (parsed.num_vectorize_loops > 0) { - tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs); + const int loops_num = loop_rvs.size(); + if (parsed.num_parallel_loops == loops_num && parsed.num_vectorize_loops == loops_num) { + // Fuse, split, vectorize and parallelize + tir::RewriteFuseSplitParallelVectorize(sch, &loop_rvs, parsed.max_vectorize_extent); + } else { + // Parallel + if (parsed.num_parallel_loops > 0) { + tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs); + } + // Vectorize + if (parsed.num_vectorize_loops > 0) { + tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs); + } } // AutoUnroll if (parsed.unroll_explicit != -1 || parsed.unroll_implicit != -1) { diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 0013106b09e80..3240496afe78f 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -169,7 +169,7 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(tir::transform::InjectSoftwarePipeline()); pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); - pass_list.push_back(tir::transform::BF16Legalize()); + pass_list.push_back(tir::transform::BF16ComputeLegalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); // Phase 2 diff --git 
a/src/relay/backend/contrib/cmsisnn/extract_constants.cc b/src/relay/backend/contrib/cmsisnn/extract_constants.cc index c6ed7af9ff031..f82014d5d1f56 100644 --- a/src/relay/backend/contrib/cmsisnn/extract_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/extract_constants.cc @@ -206,6 +206,7 @@ class ExtractConstantsMutator : public MixedModeMutator { final_call = Call(new_func, new_args); } + final_call->span = call->span; return final_call; } diff --git a/src/relay/backend/contrib/cmsisnn/fuse_pads.cc b/src/relay/backend/contrib/cmsisnn/fuse_pads.cc index 71c31c303588f..0ef7091fc2899 100644 --- a/src/relay/backend/contrib/cmsisnn/fuse_pads.cc +++ b/src/relay/backend/contrib/cmsisnn/fuse_pads.cc @@ -138,7 +138,7 @@ class FusePadsMutator : public MixedModeMutator { auto new_conv2d_args = conv2d_call->args; new_conv2d_args.erase(new_conv2d_args.begin()); new_conv2d_args.insert(new_conv2d_args.begin(), new_conv2d_input); - Call ret_call = Call(conv2d_call->op, new_conv2d_args, new_conv2d_attrs, {}); + Call ret_call = Call(conv2d_call->op, new_conv2d_args, new_conv2d_attrs, {}, conv2d_call->span); return std::move(ret_call); } @@ -162,6 +162,7 @@ class FusePadsMutator : public MixedModeMutator { Function new_func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_), func->attrs); ret_call = Call(new_func, post_call->args); + ret_call->span = call->span; } } diff --git a/src/relay/backend/contrib/cmsisnn/generate_constants.cc b/src/relay/backend/contrib/cmsisnn/generate_constants.cc index e08b61c457f9b..3bdbb5d057eb8 100644 --- a/src/relay/backend/contrib/cmsisnn/generate_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/generate_constants.cc @@ -153,16 +153,17 @@ class GenerateConstantsMutator : public MixedModeMutator { // Conv2D arguments: data, weight, input_zp, weight_zp, input_sc, weight_sc Array conv2d_args = {conv2d_call->args[0], conv2d_kernel, conv2d_call->args[2], multiplier_const, conv2d_call->args[4], weight_scale}; - Call ret_call = Call(conv2d_call->op, conv2d_args, new_conv2d_attrs, {}); + Call ret_call = Call(conv2d_call->op, conv2d_args, new_conv2d_attrs, {}, conv2d_call->span); if (bias_add_call) { - ret_call = - Call(bias_add_call->op, {ret_call, bias_add_call->args[1]}, bias_add_call->attrs, {}); + ret_call = Call(bias_add_call->op, {ret_call, bias_add_call->args[1]}, bias_add_call->attrs, + {}, bias_add_call->span); } Array requantize_args = {ret_call, req_inp_scale, shift_const, requantize_call->args[3], requantize_call->args[4]}; - ret_call = Call(requantize_call->op, requantize_args, requantize_call->attrs, {}); + ret_call = Call(requantize_call->op, requantize_args, requantize_call->attrs, {}, + requantize_call->span); if (clip_call) { - ret_call = Call(clip_call->op, {ret_call}, clip_call->attrs, {}); + ret_call = Call(clip_call->op, {ret_call}, clip_call->attrs, {}, clip_call->span); } return std::move(ret_call); } @@ -198,6 +199,7 @@ class GenerateConstantsMutator : public MixedModeMutator { } } + final_call->span = call->span; return final_call; } diff --git a/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc b/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc index 0e2036505b6f0..f64f485bfda29 100644 --- a/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc +++ b/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc @@ -83,6 +83,7 @@ class ScalarToTensorConstantMutator : public MixedModeMutator { FreeTypeVars(new_body, mod_), func->attrs); mod_->Update(global_var, new_func); 
final_call = Call(global_var, call->args); + final_call->span = call->span; } // Substitute scalar constant with tensor constant in the call to composite function. @@ -140,7 +141,7 @@ class ScalarToTensorConstantMutator : public MixedModeMutator { String arg_name = scalar_arg.as()->name_hint(); new_args.Set(i, Var(arg_name, tensor_arg->checked_type_)); } - return Call(call->op, new_args, call->attrs, {}); + return Call(call->op, new_args, call->attrs, {}, call->span); } // Replaces scalar constant with a tensor constant with same shape as that of the neighbouring @@ -187,7 +188,7 @@ class ScalarToTensorConstantMutator : public MixedModeMutator { if (new_args[0].same_as(new_args[1])) { new_args.erase(new_args.begin()); } - return Call(new_func, new_args); + return Call(new_func, new_args, Attrs(), {}, call->span); } private: diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 3f1985b7ddfa5..eb6f9ec004322 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -258,6 +258,7 @@ class AnnotateTargetRewriter : public ExprRewriter { Array compiler_begins = std::get<1>(target_n_args); Call new_call = Call(post_call->op, compiler_begins, post_call->attrs); new_call->checked_type_ = pre->checked_type_; + new_call->span = pre->span; // Update the target map. op_expr_to_target_[new_call] = target; diff --git a/src/target/codegen.cc b/src/target/codegen.cc index f6b694cb7cb3d..24dbfebe5543d 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -46,7 +46,6 @@ runtime::Module Build(IRModule mod, Target target) { .value()) { mod = tir::transform::SkipAssert()(mod); } - auto target_attr_map = tvm::TargetKind::GetAttrMap("TIRToRuntime"); if (target_attr_map.count(target->kind)) { return target_attr_map[target->kind](mod, target); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 365beb5f5dd02..7c32f3cfa1248 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -828,6 +828,10 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) { llvm::Type* target = DTypeToLLVMType(to); if (value->getType() == target) return value; + // TODO(tvm-team): consider add native support + ICHECK(!from.is_bfloat16()) << "BF16 needs to be storaged lowered first"; + ICHECK(!to.is_bfloat16()) << "BF16 needs to be storaged lowered first"; + if (to.is_handle()) { return builder_->CreateBitCast(value, target); } else if (to.is_uint() && to.bits() == 1) { diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 874992578198e..50dcd7402a472 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -325,7 +325,6 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { if (tm->getTargetTriple().isOSDarwin()) { module_->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2); } - std::string verify_errors_storage; llvm::raw_string_ostream verify_errors(verify_errors_storage); LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors)) diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 613b1d0847018..525ee95f4117c 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -382,33 +382,6 @@ std::string CodeGenOpenCL::CastTo(std::string value, DataType target) { return os.str(); } -void 
CodeGenOpenCL::VisitStmt_(const BufferStoreNode* op) { - if (auto call = op->value.as()) { - if (call->op.same_as(builtin::texture2d_load())) { - need_texture_ssa_ = false; - // If storing a texture load into a buffer, don't use an - // intermediate local unless the buffer allocation is a - // single element selected from the texture read. - auto it = allocation_size_.find(op->buffer->data.get()); - if (it != allocation_size_.end() && it->second == 1) { - need_texture_ssa_ = true; - } - } - } - CodeGenC::VisitStmt_(op); - need_texture_ssa_ = true; -} - -void CodeGenOpenCL::VisitExpr_(const CastNode* op, std::ostream& os) { - if (auto call = op->value.as()) { - if (call->op.same_as(builtin::texture2d_load())) { - need_texture_ssa_ = false; - } - } - CodeGenC::VisitExpr_(op, os); - need_texture_ssa_ = true; -} - void CodeGenOpenCL::VisitStmt_(const AllocateNode* op) { allocation_size_.insert({op->buffer_var.get(), op->ConstantAllocationSize() * op->dtype.lanes()}); CodeGenC::VisitStmt_(op); @@ -472,20 +445,15 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(op->args[2], ss); ss << ")))"; - // Only use local SSA if texture is not already being stored - if (need_texture_ssa_) { - std::string rhs = SSAGetID(ss.str(), op->dtype.with_lanes(4)); - if (op->args.back().as()) { - os << rhs; - } else { - os << "(("; - this->PrintType(op->dtype.with_lanes(1), os); - os << "*)&" << rhs << ")["; - this->PrintExpr(op->args.back(), os); - os << "]"; - } + std::string rhs = SSAGetID(ss.str(), op->dtype.with_lanes(4)); + if (op->args.back().as()) { + os << rhs; } else { - os << ss.str(); + os << "(("; + this->PrintType(op->dtype.with_lanes(1), os); + os << "*)&" << rhs << ")["; + this->PrintExpr(op->args.back(), os); + os << "]"; } } else if (op->op.same_as(builtin_call_extern_)) { auto func = Downcast(op->args[0]); diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index 05734b6a54ebe..8b365f85d6e66 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -66,9 +66,7 @@ class CodeGenOpenCL final : public CodeGenC { void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const CastNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) - void VisitStmt_(const BufferStoreNode* op) final; // NOLINT(*) // overload min and max to avoid ambiguous call errors void VisitExpr_(const MinNode* op, std::ostream& os) final; @@ -86,9 +84,6 @@ class CodeGenOpenCL final : public CodeGenC { // Whether to enable sampler or sampler-less texture reads, // where the choice depends on the OpenCL version used. bool enable_compliant_texture_reads_{false}; - // Key to disable use of texture SSA in certain scenarios. For example, - // when loaded value is stored directly to a user declared l-value buffer - bool need_texture_ssa_{true}; // Mapping from buffer to allocation size. // Useful to track when a scalar store of a vectorized texture load is required. 
std::unordered_map allocation_size_; diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc index d158a001b2d83..ce9d5eaaf8385 100644 --- a/src/tir/analysis/estimate_flops.cc +++ b/src/tir/analysis/estimate_flops.cc @@ -208,10 +208,15 @@ double EstimateTIRFlops(const Stmt& stmt) { double EstimateTIRFlops(const IRModule& mod) { FlopEstimator counter; TResult result; - VisitPrimFuncs(mod, [&result, &counter](const PrimFuncNode* f) { - result += counter.VisitStmt(f->body); // + double cached_result = 0; + VisitPrimFuncs(mod, [&result, &counter, &cached_result](const PrimFuncNode* f) { + if (auto cached = f->attrs.GetAttr("estimated_flops")) { + cached_result += cached.value()->value; + } else { + result += counter.VisitStmt(f->body); // + } }); - return PostprocessResults(result); + return PostprocessResults(result) + cached_result; } TVM_REGISTER_GLOBAL("tir.analysis.EstimateTIRFlops").set_body_typed([](ObjectRef obj) -> double { diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 828ab010831f8..4439a9c3d711a 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -324,6 +324,8 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) { // reinterpret PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) { if (value.dtype() == t) return value; + ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes()) + << "Bitcast requires size match " << t << " vs " << value.dtype(); return tir::Call(t, tir::builtin::reinterpret(), {value}, span); } diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 2235af7302142..bb2abc559d2c6 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -1095,8 +1095,17 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array& Array initial_indices; Map var_map; + std::optional index_dtype = std::nullopt; for (size_t i = 0; i < args.size(); ++i) { + if (index_dtype.has_value()) { + ICHECK_EQ(*index_dtype, args[i]->dtype) + << "Buffer index " << args[i] << " has dtype " << args[i]->dtype + << ", but previous index for the same buffer access used index type " << *index_dtype; + } else { + index_dtype = args[i]->dtype; + } + if (args[i]->dtype != initial_indices_orig[i].dtype()) { auto new_idx = Var(initial_indices_orig[i]->name_hint, args[i]->dtype); initial_indices.push_back(new_idx); @@ -1108,8 +1117,13 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array& if (!var_map.empty()) { auto final_indices = index_map->final_indices.Map([&](PrimExpr index) { - return SubstituteWithDataTypeLegalization(index, - [&](const Var& var) { return var_map.Get(var); }); + if (auto* ptr = index.as()) { + ICHECK(index_dtype.has_value()); + return tir::make_const(*index_dtype, ptr->value); + } else { + return SubstituteWithDataTypeLegalization(index, + [&](const Var& var) { return var_map.Get(var); }); + } }); Optional opt_inverse_index_map = Downcast>(index_map->inverse_index_map); diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index 2fc3bd2dca43b..c785b732abc3f 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -184,7 +184,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) == IntImm(DataType::UInt(16), buffer->dtype.lanes())); if (!(buffer->dtype == DataType::Int(1) || buffer->dtype 
== DataType::Int(4) || - buffer->dtype == DataType::UInt(4) || buffer->dtype == DataType::UInt(16))) { + buffer->dtype == DataType::UInt(4))) { auto type_msg = tvm::tir::StringImm(type_err_msg.str()); asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); asserts_.emplace_back(AssertStmt(cond, type_msg, nop)); diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc index 3b89558622e99..99fad558cfe2b 100644 --- a/src/tir/transforms/bf16_legalize.cc +++ b/src/tir/transforms/bf16_legalize.cc @@ -25,173 +25,198 @@ #include #include #include +#include #include #include #include -#include "../../arith/ir_mutator_with_analyzer.h" -#include "../../arith/ir_visitor_with_analyzer.h" - namespace tvm { namespace tir { -using arith::Analyzer; -using arith::IRMutatorWithAnalyzer; - -class BF16PromoteRewriter : public StmtExprMutator { +// NOTE: do not touch buffer on function boundary +// remap internal bf16 buffer to f32 if they meet the following condition +// - constant allocation size +// - do not have raw pointer access to the buffer +// +// populate the buffer_remap and var_remap accordingly. +class BF16ComputeLegalizePlanner : public StmtExprVisitor { public: - BF16PromoteRewriter() {} - - Stmt operator()(Stmt s) { return VisitStmt(s); } - - PrimExpr VisitExpr_(const AddNode* op) final; - PrimExpr VisitExpr_(const SubNode* op) final; - PrimExpr VisitExpr_(const MulNode* op) final; - PrimExpr VisitExpr_(const DivNode* op) final; - PrimExpr VisitExpr_(const MinNode* op) final; - PrimExpr VisitExpr_(const MaxNode* op) final; - PrimExpr VisitExpr_(const LTNode* op) final; - PrimExpr VisitExpr_(const LENode* op) final; - PrimExpr VisitExpr_(const GTNode* op) final; - PrimExpr VisitExpr_(const GENode* op) final; - PrimExpr VisitExpr_(const CallNode* op) final; -}; + BF16ComputeLegalizePlanner( + std::unordered_map* buffer_remap, + std::unordered_map* var_remap) + : buffer_remap_(buffer_remap), var_remap_(var_remap) {} + + // run planning to populate buffer remap and var remap. + void Plan(PrimFunc func) { + this->VisitStmt(func->body); + // if there are opaque var access, then we cannot + // do remap of var and buffer, post-hoc remove these items. + for (Var var : opaque_var_access_) { + auto it = var_remap_->find(var); + if (it != var_remap_->end()) { + var_remap_->erase(it); + } + } + Array drop_buffers; + for (auto kv : *buffer_remap_) { + if (opaque_var_access_.count(kv.first->data)) { + drop_buffers.push_back(kv.first); + } + } + for (Buffer buffer : drop_buffers) { + auto it = buffer_remap_->find(buffer); + ICHECK(it != buffer_remap_->end()); + buffer_remap_->erase(it); + } + } -#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC, NEEDCAST) \ - PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \ - PrimExpr origin_a = this->VisitExpr(op->a); \ - PrimExpr origin_b = this->VisitExpr(op->b); \ - bool a_is_bfloat16 = origin_a->dtype.is_bfloat16(); \ - bool b_is_bfloat16 = origin_b->dtype.is_bfloat16(); \ - bool both_bfloat16 = a_is_bfloat16 && b_is_bfloat16; \ - bool none_bfloat16 = !(a_is_bfloat16 || b_is_bfloat16); \ - if (none_bfloat16) { \ - return GetRef(op); \ - } \ - DataType float32_dtype(kDLFloat, 32, 1); \ - PrimExpr float32_a = a_is_bfloat16 ? Cast(float32_dtype, origin_a) : origin_a; \ - PrimExpr float32_b = b_is_bfloat16 ? Cast(float32_dtype, origin_b) : origin_b; \ - PrimExpr result = FUNC(float32_a, float32_b); \ - DataType bfloat16_dtype(kDLBfloat, 16, 1); \ - bool do_cast = both_bfloat16 && NEEDCAST; \ - return do_cast ? 
Cast(bfloat16_dtype, result) : result; \ - } - -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+, true) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-, true) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*, true) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div, true) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min, true) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max, true) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<, false) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=, false) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>, false) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=, false) - -PrimExpr BF16PromoteRewriter::VisitExpr_(const CallNode* op) { - Array args; - for (auto& arg : op->args) { - PrimExpr x = this->VisitExpr(arg); - if (x.dtype().is_bfloat16()) { - DataType fp32_dtype(kDLFloat, 32, x.dtype().lanes()); - args.push_back(Cast(fp32_dtype, {x}, op->span)); - } else { - args.push_back(x); + void VisitStmt_(const AllocateNode* op) final { + // remap all intermediate constant buffr to fp32 + if (op->dtype.is_bfloat16() && op->ConstantAllocationSize() != 0) { + DataType dtype = DataType::Float(32, op->dtype.lanes()); + Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype))); + (*var_remap_)[op->buffer_var] = buffer_var; } + return StmtExprVisitor::VisitStmt_(op); } - if (op->dtype.is_bfloat16()) { - DataType fp32_dtype(kDLFloat, 32, op->dtype.lanes()); - PrimExpr result_fp32 = Call(fp32_dtype, op->op, args, op->span); - return Cast(op->dtype, {result_fp32}, op->span); - } else { - return Call(op->dtype, op->op, args, op->span); + + void VisitStmt_(const BufferStoreNode* op) final { + StmtExprVisitor::VisitStmt_(op); + this->PopulateBufferRemap(op->buffer); } -} -/* - * Eliminate verbose casting between fp32 and bf16 - * Checks if the AST has the pattern: - * castto32(castto16(some_fp32_op(...))) - * The verbose casting is generated by BF16Promote for multiple - * bf16 Ops in a row. 
e.g.: - * X[i] + Y[i] + T[i] => - * bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i]))) - * After this pass: - * bf16(float32(X[i]) + float32(Y[i]) + float32(T[i])) - */ -class BF16CastEliminationRewriter : public StmtExprMutator { - public: - BF16CastEliminationRewriter() {} + void VisitExpr_(const BufferLoadNode* op) final { + StmtExprVisitor::VisitExpr_(op); + this->PopulateBufferRemap(op->buffer); + } - Stmt operator()(Stmt s) { return VisitStmt(s); } + void VisitStmt_(const DeclBufferNode* op) final { + StmtExprVisitor::VisitStmt_(op); + this->PopulateBufferRemap(op->buffer); + } - PrimExpr VisitExpr_(const CastNode* op) final { - auto op_val = StmtExprMutator::VisitExpr(op->value); - if (op->dtype.is_float() && op->dtype.bits() == 32) { - // if is cast_to_fp32, check if op->value is cast_to_fp16 - // and op->value->value is a float32 - if (auto innercast = op_val.as()) { - if (innercast->dtype.is_bfloat16() && innercast->value->dtype.is_float() && - innercast->value->dtype.bits() == 32) { - return innercast->value; - } - } + void VisitExpr_(const VarNode* op) final { + StmtExprVisitor::VisitExpr_(op); + Var buffer_var = GetRef(op); + if (buffer_var.dtype().is_handle()) { + opaque_var_access_.insert(buffer_var); } - if (op->value.same_as(op_val)) return GetRef(op); - return Cast(op->dtype, op_val); } -}; -union FloatCaster { - uint32_t u32; - float f32; + private: + void PopulateBufferRemap(Buffer buf) { + auto var_it = var_remap_->find(buf->data); + if (var_it == var_remap_->end()) return; + + Buffer new_buffer(var_it->second, DataType::Float(32, buf->dtype.lanes()), buf->shape, + buf->strides, buf->elem_offset, buf->name, buf->data_alignment, + buf->offset_factor, buf->buffer_type, buf->axis_separators, buf->span); + (*buffer_remap_)[buf] = new_buffer; + } + + std::unordered_map* buffer_remap_; + std::unordered_map* var_remap_; + std::unordered_set opaque_var_access_; }; -uint16_t RoundToNearestEven(float src) { - if (std::isnan(src)) { - return UINT16_C(0x7FC0); - } else { - FloatCaster caster; - caster.f32 = src; - uint32_t rounding_bias = ((caster.u32 >> 16) & 1) + UINT32_C(0x7FFF); - return static_cast((caster.u32 + rounding_bias) >> 16); +#define DEFINE_BIOP_EXPR_LEGALIZE(OP, FUNC) \ + PrimExpr VisitExpr_(const OP* op) final { \ + PrimExpr origin_a = PromoteBF16ToF32(this->VisitExpr(op->a)); \ + PrimExpr origin_b = PromoteBF16ToF32(this->VisitExpr(op->b)); \ + \ + if (origin_a.same_as(op->a) && origin_b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + return FUNC(origin_a, origin_b); \ + } \ } -} -/* - * Lower the bf16 type to int16 - * Lower cast between bf16 and fp32 - * Lower bf16 FloatImm to int16 - */ -class BF16LowerRewriter : public StmtExprMutator { +// NOTE: Legalize the BF16 computations +// to floating point computations and only keeps the +// bf16 storage which can further be legalized by BF16StorageLegalizer +// BF16StorageLegalizer will be called at a much later time +// point in the TIR lowering phases. 
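+// A rough sketch of the intended rewrite (illustration only, assuming A, B and C
+// are bf16 buffers that keep bf16 storage, e.g. function parameters):
+//   C[i] = A[i] + B[i]
+// becomes an f32 computation that is rounded back on the store,
+//   C[i] = bf16_round(f32(A[i]) + f32(B[i]))
+// where f32(x) expands to reinterpret<f32>(cast<u32>(reinterpret<u16>(x)) << 16)
+// and bf16_round applies round-to-nearest-even before reinterpreting the upper
+// 16 bits back to bf16.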
+class BF16ComputeLegalizer : public StmtExprMutator { public: - BF16LowerRewriter() {} - - using StmtExprMutator::operator(); + PrimFunc Legalize(PrimFunc func) { + BF16ComputeLegalizePlanner planner(&buffer_remap_, &var_remap_); + planner.Plan(func); + auto* n = func.CopyOnWrite(); + n->body = this->VisitStmt(std::move(n->body)); + return func; + } + protected: PrimExpr VisitExpr_(const CastNode* op) final { - PrimExpr op_val = StmtExprMutator::VisitExpr(op->value); - DataType uint32_dtype(kDLUInt, 32, op_val->dtype.lanes()); - DataType float32_dtype(kDLFloat, 32, op_val->dtype.lanes()); - if (op->value->dtype.is_bfloat16()) { // cast from bf16 - PrimExpr uint32_v = Cast(uint32_dtype, op_val); - PrimExpr float32_v = Call(float32_dtype, builtin::reinterpret(), {uint32_v << 16}); - bool is_to_float32 = op->dtype.is_float() && op->dtype.bits() == 32; - return is_to_float32 ? float32_v : Cast(op->dtype, float32_v); - } else if (op->dtype.is_bfloat16()) { // cast to bf16 - bool is_from_float32 = op->value->dtype.is_float() && op->value->dtype.bits() == 32; - PrimExpr float32_v = is_from_float32 ? op_val : Cast(float32_dtype, op_val); - PrimExpr uint32_v = Call(uint32_dtype, builtin::reinterpret(), {float32_v}); - DataType uint16_dtype(kDLUInt, 16, op_val->dtype.lanes()); - /* the following TIR is equivalent to the C++ code below: - uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); - return static_cast((U32 + rounding_bias) >> 16);*/ - PrimExpr rounding_bias = ((uint32_v >> 16) & 1) + make_const(uint16_dtype, 0x7FFF); - return Cast(uint16_dtype, {(uint32_v + rounding_bias) >> 16}); - } - if (op->value.same_as(op_val)) return GetRef(op); - return Cast(op->dtype, op_val); + auto op_val = PromoteBF16ToF32(this->VisitExpr(op->value)); + + // all casts to BF16 becomes f32 + if (op->dtype.is_bfloat16()) { + return cast(DataType::Float(32, op->dtype.lanes()), op_val); + } + + if (op_val.same_as(op->value)) { + return GetRef(op); + } else { + return cast(op->dtype, op_val); + } + } + + PrimExpr VisitExpr_(const SelectNode* op) final { + PrimExpr condition = this->VisitExpr(op->condition); + PrimExpr true_value = PromoteBF16ToF32(this->VisitExpr(op->true_value)); + PrimExpr false_value = PromoteBF16ToF32(this->VisitExpr(op->false_value)); + if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && + false_value.same_as(op->false_value)) { + return GetRef(op); + } else { + return Select(condition, true_value, false_value); + } + } + + PrimExpr VisitExpr_(const BroadcastNode* op) final { + PrimExpr value = PromoteBF16ToF32(this->VisitExpr(op->value)); + if (value.same_as(op->value)) { + return GetRef(op); + } else { + return Broadcast(value, op->lanes); + } + } + + PrimExpr VisitExpr_(const ShuffleNode* op) final { + auto fexpr = [this](const PrimExpr& e) { return PromoteBF16ToF32(this->VisitExpr(e)); }; + auto vectors = op->vectors.Map(fexpr); + if (vectors.same_as(op->vectors)) { + return GetRef(op); + } else { + return Shuffle(vectors, op->indices); + } + } + + PrimExpr VisitExpr_(const CallNode* op) final { + // presertve reinterpret() behavior. + if (op->op.same_as(builtin::reinterpret())) { + return StmtExprMutator::VisitExpr_(op); + } + // update normal computations to return f32 instead. 
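+    // For example, a call that returns bf16, such as exp(x) on a bf16 value, is
+    // rewritten into an f32-returning call over the promoted argument; only a
+    // later BufferStore narrows the result back to bf16.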
+ auto fmutate = [this](const PrimExpr& e) { return PromoteBF16ToF32(this->VisitExpr(e)); }; + Array args = op->args.Map(fmutate); + if (op->dtype.is_bfloat16()) { + return Call(DataType::Float(32, op->dtype.lanes()), op->op, args); + } + if (args.same_as(op->args)) { + return GetRef(op); + } else { + return Call(op->dtype, op->op, args); + } + } + + PrimExpr VisitExpr_(const FloatImmNode* op) final { + if (op->dtype.is_bfloat16()) { + return FloatImm(DataType::Float(32), op->value); + } + return GetRef(op); } PrimExpr VisitExpr_(const VarNode* op) final { @@ -205,26 +230,73 @@ class BF16LowerRewriter : public StmtExprMutator { } } - Stmt VisitStmt_(const AllocateNode* op) final { - if (op->dtype.is_bfloat16()) { - DataType dtype = DataType::UInt(16, op->dtype.lanes()); - Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype))); - var_remap_[op->buffer_var] = buffer_var; - return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, op->body)); + PrimExpr VisitExpr_(const LetNode* op) final { + PrimExpr value = PromoteBF16ToF32(op->value); + Var var = op->var; + if (value.dtype() != op->value.dtype()) { + var = op->var.copy_with_dtype(op->value.dtype()); + var_remap_[op->var] = var; + } + + PrimExpr body = VisitExpr(op->body); + + if (value.same_as(op->value) && var.same_as(op->var) && body.same_as(op->body)) { + return GetRef(op); } else { - return StmtExprMutator::VisitStmt_(op); + return Let(var, value, body); + } + } + + DEFINE_BIOP_EXPR_LEGALIZE(AddNode, operator+); + DEFINE_BIOP_EXPR_LEGALIZE(SubNode, operator-); + DEFINE_BIOP_EXPR_LEGALIZE(MulNode, operator*); + DEFINE_BIOP_EXPR_LEGALIZE(DivNode, div); + DEFINE_BIOP_EXPR_LEGALIZE(MinNode, min); + DEFINE_BIOP_EXPR_LEGALIZE(MaxNode, max); + DEFINE_BIOP_EXPR_LEGALIZE(LTNode, operator<); // NOLINT(*) + DEFINE_BIOP_EXPR_LEGALIZE(LENode, operator<=); + DEFINE_BIOP_EXPR_LEGALIZE(GTNode, operator>); // NOLINT(*) + DEFINE_BIOP_EXPR_LEGALIZE(GENode, operator>=); + DEFINE_BIOP_EXPR_LEGALIZE(EQNode, operator==); + DEFINE_BIOP_EXPR_LEGALIZE(NENode, operator!=); + + Stmt VisitStmt_(const LetStmtNode* op) final { + PrimExpr value = PromoteBF16ToF32(op->value); + Var var = op->var; + if (value.dtype() != op->value.dtype()) { + var = op->var.copy_with_dtype(op->value.dtype()); + var_remap_[op->var] = var; + } + Stmt body = VisitStmt(op->body); + + if (value.same_as(op->value) && var.same_as(op->var) && body.same_as(op->body)) { + return GetRef(op); + } else { + return LetStmt(var, value, body); } } Stmt VisitStmt_(const BufferStoreNode* op) final { - Stmt ret = StmtExprMutator::VisitStmt_(op); - op = ret.as(); + PrimExpr value = this->VisitExpr(op->value); + auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; + + Array indices = op->indices.Map(fmutate); Buffer new_buf = GetRemappedBuffer(op->buffer); - if (new_buf.same_as(op->buffer)) { - return ret; + + if (value.same_as(op->value) && indices.same_as(op->indices) && new_buf.same_as(op->buffer)) { + return GetRef(op); } else { - return BufferStore(new_buf, op->value, op->indices); + if (new_buf->dtype.is_bfloat16()) { + value = CastF32ToBF16(value); + } + if (value.dtype() != new_buf->dtype) { + // this happens when buffer get rewritten to f32 + // but values remain as bf16 + ICHECK(value.dtype().is_bfloat16()); + value = cast(new_buf->dtype.with_lanes(value.dtype().lanes()), value); + } + return BufferStore(new_buf, value, indices); } } @@ -258,6 +330,35 @@ class BF16LowerRewriter : public StmtExprMutator { } } + Stmt VisitStmt_(const 
DeclBufferNode* op) final { + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + Buffer new_buf = GetRemappedBuffer(op->buffer); + if (new_buf.same_as(op->buffer)) { + return ret; + } else { + return DeclBuffer(new_buf, op->body); + } + } + + Stmt VisitStmt_(const AllocateNode* op) final { + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + auto it = var_remap_.find(op->buffer_var); + if (it != var_remap_.end()) { + Var remapped_var = it->second; + auto* ptr = remapped_var->type_annotation.as(); + ICHECK(ptr); + auto* prim_type = ptr->element_type.as(); + ICHECK(prim_type); + return Allocate(remapped_var, prim_type->dtype, op->extents, op->condition, op->body); + } else { + return ret; + } + } + PrimExpr VisitExpr_(const BufferLoadNode* op) final { PrimExpr ret = StmtExprMutator::VisitExpr_(op); op = ret.as(); @@ -270,48 +371,244 @@ class BF16LowerRewriter : public StmtExprMutator { } } - PrimExpr VisitExpr_(const FloatImmNode* op) final { + private: + /*! + * \brief promote BF16 to F32 and keep other values unchanged. + * \param value The input value. + * \return The converted value. + */ + PrimExpr PromoteBF16ToF32(PrimExpr value) { + if (!value.dtype().is_bfloat16()) return value; + if (const CastNode* cast = value.as()) { + if (cast->value.dtype() == DataType::Float(32)) return cast->value; + } + DataType f32 = DataType::Float(32, value.dtype().lanes()); + DataType u16 = DataType::UInt(16, value.dtype().lanes()); + DataType u32 = DataType::UInt(32, value.dtype().lanes()); + // reinterpret((cast(reinterpret(bf16_value)) << 16)) + return reinterpret(f32, cast(u32, reinterpret(u16, value)) << 16); + } + + /*! + * \brief Cast value to F32 to BF16 and keep other values unchanged. + * \param value The input value + * \return The converted value. + */ + PrimExpr CastF32ToBF16(PrimExpr value) { + if (!value.dtype().is_float()) return value; + ICHECK_EQ(value.dtype().bits(), 32); + DataType bf16 = DataType::BFloat(16, value.dtype().lanes()); + DataType u16 = DataType::UInt(16, value.dtype().lanes()); + DataType u32 = DataType::UInt(32, value.dtype().lanes()); + PrimExpr u32_val = reinterpret(u32, value); + + if (round_to_even_) { + PrimExpr rounding_bias = ((u32_val >> 16) & 1) + make_const(u32, 0x7FFF); + u32_val = u32_val + rounding_bias; + } + // reinterpret((cast(reinterpret(f32_value)) >> 16)) + return reinterpret(bf16, cast(u16, u32_val >> 16)); + } + + Buffer GetRemappedBuffer(Buffer buf) { + auto buf_it = buffer_remap_.find(buf); + if (buf_it != buffer_remap_.end()) { + return buf_it->second; + } + return buf; + } + + bool round_to_even_{true}; + + std::unordered_map buffer_remap_; + std::unordered_map var_remap_; +}; + +/*! + * \brief This Pass legalizes remaining BF16 storages to u16 + * + * This pass needs to happens after BF16ComputeLegalizer and serves + * as a way to support BF16 on platforms that do not have native support. 
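+ *
+ * For example, an Allocate of a bf16 buffer becomes an Allocate of u16 with the
+ * same number of lanes, and a reinterpret<bf16>(u16_value) emitted by the
+ * compute legalizer collapses to the u16 value once the surrounding storage is
+ * u16.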
+ */ +class BF16StorageLegalizer : public StmtExprMutator { + public: + PrimFunc Legalize(PrimFunc func) { + ICHECK_EQ(func->buffer_map.size(), 0) << "This pass must be called after MakePackedAPI"; + auto* n = func.CopyOnWrite(); + n->params = n->params.Map([this](Var var) { return this->RemapVarDef(var); }); + n->body = this->VisitStmt(std::move(n->body)); + return func; + } + + private: + PrimExpr VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + auto itr = var_remap_.find(var); + if (itr != var_remap_.end()) { + return itr->second; + } else { + return std::move(var); + } + } + + Stmt VisitStmt_(const AllocateNode* op) final { if (op->dtype.is_bfloat16()) { - return IntImm(DataType::UInt(16, op->dtype.lanes()), - RoundToNearestEven(static_cast(op->value))); + DataType dtype = DataType::UInt(16, op->dtype.lanes()); + Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype))); + var_remap_[op->buffer_var] = buffer_var; + return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, op->body)); + } else { + return StmtExprMutator::VisitStmt_(op); } - return StmtExprMutator::VisitExpr_(op); } - void AlterBuffers(PrimFuncNode* op) { - Map new_buffer_map; - - for (auto& itr : op->buffer_map) { - auto param_var = itr.first; - auto oldbuf = itr.second; - if (oldbuf->dtype.is_bfloat16()) { - DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes()); - Var buffer_var = Var(oldbuf->data->name_hint, PointerType(PrimType(dtype))); - auto newbuf = Buffer(buffer_var, dtype, oldbuf->shape, oldbuf->strides, oldbuf->elem_offset, - oldbuf->name, oldbuf->data_alignment, oldbuf->offset_factor, - oldbuf->buffer_type); - buffer_remap_[oldbuf] = newbuf; - var_remap_[oldbuf->data] = buffer_var; - new_buffer_map.Set(param_var, newbuf); - } else { - new_buffer_map.Set(param_var, oldbuf); + Stmt VisitStmt_(const DeclBufferNode* op) final { + Buffer buf = GetRemappedBuffer(op->buffer); + // in a rare case the buffer didn't get remapped + // because the original var is not bfloat* + // force remap here + if (buf->dtype.is_bfloat16()) { + buf = Buffer(buf->data, DataType::UInt(16, buf->dtype.lanes()), buf->shape, buf->strides, + buf->elem_offset, buf->name, buf->data_alignment, buf->offset_factor, + buf->buffer_type, buf->axis_separators, buf->span); + buffer_remap_[op->buffer] = buf; + } + Stmt body = VisitStmt(op->body); + if (buf.same_as(op->buffer) && body.same_as(op->body)) { + return GetRef(op); + } else { + return DeclBuffer(buf, body, op->span); + } + } + + PrimExpr VisitExpr_(const LetNode* op) final { + PrimExpr value = VisitExpr(op->value); + Var var = RemapVarDef(op->var); + PrimExpr body = VisitExpr(op->body); + + if (value.same_as(op->value) && var.same_as(op->var) && body.same_as(op->body)) { + return GetRef(op); + } else { + return Let(var, value, body); + } + } + + Stmt VisitStmt_(const LetStmtNode* op) final { + PrimExpr value = VisitExpr(op->value); + Var var = RemapVarDef(op->var); + Stmt body = VisitStmt(op->body); + + if (value.same_as(op->value) && var.same_as(op->var) && body.same_as(op->body)) { + return GetRef(op); + } else { + return LetStmt(var, value, body); + } + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + PrimExpr value = this->ChangeBF16ToU16(VisitExpr(op->value)); + Buffer new_buf = GetRemappedBuffer(op->buffer); + auto indices = op->indices.Map([this](PrimExpr expr) { return this->VisitExpr(expr); }); + if (new_buf.same_as(op->buffer) && indices.same_as(op->indices) && value.same_as(op->value)) { + return GetRef(op); + } else 
{ + if (op->value.dtype().is_bfloat16()) { + ICHECK(new_buf->dtype.is_uint()); + } + return BufferStore(new_buf, value, indices); + } + } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + if (auto* buffer = op->node.as()) { + auto it = buffer_remap_.find(GetRef(buffer)); + if (it != buffer_remap_.end()) { + return AttrStmt(it->second, op->attr_key, op->value, op->body); + } + } else if (auto* var = op->node.as()) { + auto it = var_remap_.find(GetRef(var)); + if (it != var_remap_.end()) { + return AttrStmt(it->second, op->attr_key, op->value, op->body); } } + return ret; + } - if (buffer_remap_.size() != 0) { - op->buffer_map = new_buffer_map; + Stmt VisitStmt_(const BufferRealizeNode* op) final { + LOG(FATAL) << "Do not expect buffer realize"; + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + PrimExpr ret = StmtExprMutator::VisitExpr_(op); + op = ret.as(); + Buffer new_buf = GetRemappedBuffer(op->buffer); + if (new_buf.same_as(op->buffer)) { + return ret; + } else { + return BufferLoad(new_buf, op->indices); } } + PrimExpr VisitExpr_(const CallNode* op) final { + // remap re-interpret so un-necessary reinterpret can be skipped. + if (op->op.same_as(builtin::reinterpret())) { + PrimExpr value = VisitExpr(op->args[0]); + // sometimes the input dtype can change and we can skip. + if (value.dtype() == op->dtype) return value; + if (op->dtype.is_bfloat16()) { + return reinterpret(DataType::UInt(16, op->dtype.lanes()), value); + } + if (op->args[0].same_as(value)) { + return GetRef(op); + } else { + return reinterpret(op->dtype, value); + } + } + return StmtExprMutator::VisitExpr_(op); + } + private: + /*! + * \brief Change BF16 value to U16 value. + * \param value The input value. + * \return The converted value. 
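+   * \note After BF16ComputeLegalize, bf16 values are expected to be produced by
+   *       reinterpret(); any other bf16 value is returned unchanged here.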
+ */ + PrimExpr ChangeBF16ToU16(PrimExpr value) { + if (!value.dtype().is_bfloat16()) return value; + auto* call = value.as(); + if (call && call->op.same_as(builtin::reinterpret())) { + return reinterpret(DataType::UInt(16, value.dtype().lanes()), call->args[0]); + } else { + return value; + } + } + + Var RemapVarDef(Var var) { + // remap the var + if (var.dtype().is_handle()) { + if (auto* ptr_type = var->type_annotation.as()) { + if (auto* elem_type = ptr_type->element_type.as()) { + if (elem_type->dtype.is_bfloat16()) { + Var new_var = Var(var->name_hint, + PointerType(PrimType(DataType::UInt(16, elem_type->dtype.lanes())))); + var_remap_[var] = new_var; + return new_var; + } + } + } + } + return var; + } + Buffer GetRemappedBuffer(Buffer buf) { auto buf_it = buffer_remap_.find(buf); if (buf_it != buffer_remap_.end()) { return buf_it->second; } - Buffer new_buf = buf; - auto var_it = var_remap_.find(buf->data); if (var_it != var_remap_.end()) { DataType dtype = @@ -319,6 +616,8 @@ class BF16LowerRewriter : public StmtExprMutator { new_buf = Buffer(var_it->second, dtype, buf->shape, buf->strides, buf->elem_offset, buf->name, buf->data_alignment, buf->offset_factor, buf->buffer_type, buf->axis_separators, buf->span); + } else { + ICHECK(!buf->dtype.is_bfloat16()) << "Cannot find var remap for " << buf; } buffer_remap_[buf] = new_buf; @@ -332,46 +631,25 @@ class BF16LowerRewriter : public StmtExprMutator { namespace transform { -Pass BF16Promote() { +Pass BF16ComputeLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - auto* n = f.CopyOnWrite(); - n->body = BF16PromoteRewriter()(std::move(n->body)); - return f; + // TODO(tvm-team): skip if the target supports bf16 + return BF16ComputeLegalizer().Legalize(f); }; - return CreatePrimFuncPass(pass_func, 0, "tir.BF16Promote", {}); + return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {}); } -TVM_REGISTER_GLOBAL("tir.transform.BF16Promote").set_body_typed(BF16Promote); +TVM_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16ComputeLegalize); -Pass BF16CastElimination() { +Pass BF16StorageLegalize() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - auto* n = f.CopyOnWrite(); - n->body = BF16CastEliminationRewriter()(std::move(n->body)); - return f; + // TODO(tvm-team): skip if the target supports bf16 + return BF16StorageLegalizer().Legalize(f); }; - return CreatePrimFuncPass(pass_func, 0, "tir.BF16CastElimination", {}); -} - -TVM_REGISTER_GLOBAL("tir.transform.BF16CastElimination").set_body_typed(BF16CastElimination); - -Pass BF16TypeLowering() { - auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - auto* n = f.CopyOnWrite(); - BF16LowerRewriter lowerer; - lowerer.AlterBuffers(n); - n->body = lowerer(std::move(n->body)); - return f; - }; - return CreatePrimFuncPass(pass_func, 0, "tir.BF16TypeLowering", {}); -} - -TVM_REGISTER_GLOBAL("tir.transform.BF16TypeLowering").set_body_typed(BF16TypeLowering); - -Pass BF16Legalize() { - return Sequential({BF16Promote(), BF16CastElimination(), BF16TypeLowering()}, "tir.BF16Legalize"); + return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {}); } -TVM_REGISTER_GLOBAL("tir.transform.BF16Legalize").set_body_typed(BF16Legalize); +TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index c1611a23a05fe..b12d3dd49f054 100644 --- 
a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -286,14 +286,8 @@ PrimFunc MakePackedAPI(PrimFunc&& func) { func_ptr->params = args; Array undefined = UndefinedVars(func_ptr->body, func_ptr->params); - if (undefined.size() != 0) { - std::ostringstream os; - for (Var v : undefined) { - os << " \'" << v->name_hint << "\' "; - } - os << " is not bound to any variables"; - LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str(); - } + ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << global_symbol << " variables " << undefined + << " are used, but are not passed in as API arguments"; func_ptr->buffer_map = Map(); func_ptr->checked_type_ = func_ptr->func_type_annotation(); diff --git a/src/tir/transforms/storage_access.h b/src/tir/transforms/storage_access.h index 8fac0c302a708..119e595f59752 100644 --- a/src/tir/transforms/storage_access.h +++ b/src/tir/transforms/storage_access.h @@ -139,7 +139,6 @@ class StorageAccessVisitor : public StmtExprVisitor { // The involving threads Array env_threads_; }; - } // namespace tir } // namespace tvm #endif // TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index 844d08c66e03d..e6ebec6ac4fa1 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -28,7 +28,6 @@ import os import struct import numpy as np -import tflite.Model import math from enum import IntEnum import tensorflow as tf @@ -311,7 +310,15 @@ def representative_dataset(): converter.inference_output_type = tf.int8 tflite_graph = converter.convert() - tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + # Get TFLite model from buffer + try: + import tflite + + tflite_model = tflite.Model.GetRootAsModel(tflite_graph, 0) + except AttributeError: + import tflite.Model + + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) relay_module, params = relay.frontend.from_tflite(tflite_model) mod = partition_for_ethosu(relay_module, params) diff --git a/tests/python/contrib/test_ethosu/test_pass_operations_distribution.py b/tests/python/contrib/test_ethosu/test_pass_operations_distribution.py new file mode 100644 index 0000000000000..2a9d88e412108 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_pass_operations_distribution.py @@ -0,0 +1,173 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
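+"""Tests for the operations distribution analysis.
+
+After tag_suffixes and partition_for_ethosu, analyze_operations_distribution is
+expected to report, for every operation of the initial Relay module, the target
+and composite it was mapped to."""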
+import pytest +import numpy as np + +from tvm import relay +from tests.python.contrib.test_ethosu.infra import get_tflite_graph +from tvm.relay.op.contrib.ethosu import partition_for_ethosu +from tvm.relay.analysis.operations_distribution import analyze_operations_distribution +from tvm.relay.transform.suffixes import tag_suffixes + + +def test_operations_distribution_ethos(): + + tflite = pytest.importorskip("tflite") + tensorflow = pytest.importorskip("tensorflow") + pytest.importorskip("ethosu.vela") + + import tensorflow as tf + + inp = (224, 224, 9) + input_shape = (1, *inp) + kernel_shape = (3, 3) + padding = (1, 1, 1, 1) + padding_out = (1, 33, 33, 1) + + @tf.function + def simple_net(x): + weight_shape = [kernel_shape[0], kernel_shape[1], input_shape[3], 3] + weights = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + op = tf.nn.conv2d( + x, + filters=weights, + strides=1, + padding="SAME", + data_format="NHWC", + dilations=1, + ) + op = tf.pad( + op, + [[0, 0], [padding[0], padding_out[2]], [padding_out[1], padding[3]], [0, 0]], + "CONSTANT", + ) + op = tf.pad( + op, + [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]], + "CONSTANT", + ) + return tf.pad( + op, + [[0, 0], [padding_out[0], padding[2]], [padding[1], padding_out[3]], [0, 0]], + "CONSTANT", + ) + + _, tflite_graph = get_tflite_graph(simple_net, [input_shape]) + + # Get TFLite model from buffer + try: + import tflite + + tflite_model = tflite.Model.GetRootAsModel(tflite_graph, 0) + except AttributeError: + import tflite.Model + + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, params = relay.frontend.from_tflite(tflite_model) + + mod = tag_suffixes(mod) + mod = partition_for_ethosu(mod, params) + operations_distribution = analyze_operations_distribution(mod) + + expected = { + "Pad_PART_0": ["generic", "generic", 1], + "Conv2D2_PART_2": ["ethos-u", "ethos-u.qnn_conv2d", 3], + "Conv2D2_PART_1": ["ethos-u", "ethos-u.qnn_conv2d", 3], + "Conv2D2_PART_0": ["ethos-u", "ethos-u.qnn_conv2d", 3], + "Identity_PART_0": ["ethos-u", "ethos-u.pad2d", 4], + "Pad_1_PART_0": ["ethos-u", "ethos-u.pad2d", 5], + } + + assert operations_distribution == expected + + +def test_operations_distribution_generic(): + + tflite = pytest.importorskip("tflite") + tensorflow = pytest.importorskip("tensorflow") + pytest.importorskip("ethosu.vela") + + import tensorflow as tf + + inp = (224, 224, 9) + input_shape = (1, *inp) + kernel_shape = (3, 3) + padding = (1, 1, 1, 1) + padding_out = (1, 33, 33, 1) + dilations_out = 32 + + @tf.function + def simple_net(x): + weight_shape = [kernel_shape[0], kernel_shape[1], input_shape[3], 3] + weights = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + op = tf.nn.conv2d( + x, + filters=weights, + strides=1, + padding="SAME", + data_format="NHWC", + dilations=dilations_out, + ) + op = tf.pad( + op, + [[0, 0], [padding[0], padding_out[2]], [padding[1], padding[3]], [0, 0]], + "CONSTANT", + ) + op = tf.pad( + op, + [[0, 0], [padding[0], padding_out[2]], [padding[1], padding[3]], [0, 0]], + "CONSTANT", + ) + return tf.pad( + op, + [[0, 0], [padding[0], padding_out[2]], [padding[1], padding[3]], [0, 0]], + "CONSTANT", + ) + + _, tflite_graph = get_tflite_graph(simple_net, [input_shape]) + + # Get TFLite model from buffer + try: + import tflite + + tflite_model = tflite.Model.GetRootAsModel(tflite_graph, 0) + except AttributeError: + import tflite.Model + + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, 
params = relay.frontend.from_tflite(tflite_model) + + mod = tag_suffixes(mod) + mod = partition_for_ethosu(mod, params) + operations_distribution = analyze_operations_distribution(mod) + + expected = { + "Identity_PART_0": ["generic", "generic", 1], + "Pad_1_PART_0": ["generic", "generic", 2], + "Pad_PART_0": ["generic", "generic", 3], + "Conv2D2_PART_2": ["generic", "generic", 4], + "Conv2D2_PART_1": ["generic", "generic", 5], + "Conv2D2_PART_0": ["generic", "generic", 6], + } + + assert operations_distribution == expected + + +if __name__ == "__main__": + test_operations_distribution() diff --git a/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py b/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py index 5eac35f2d683a..fdfe3ad2b76ef 100644 --- a/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py +++ b/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py @@ -21,6 +21,7 @@ import tvm.testing from tvm import relay +from tvm import te from tvm.relay.backend import Executor from tvm.contrib.hexagon.session import Session from tvm.contrib.hexagon.pytest_plugin import HEXAGON_AOT_LLVM_TARGET @@ -100,7 +101,7 @@ class TestFixedPointMultiply: ) @tvm.testing.requires_hexagon - def test_fixed_point_multiply(self, hexagon_session: Session, multiplier: int, shift: int): + def test_per_tensor(self, hexagon_session: Session, multiplier: int, shift: int): """Fixed point multiply test.""" ishape = (6, 32) a = relay.var("a", relay.TensorType(ishape, "int32")) @@ -169,6 +170,141 @@ def test_per_channel(self, hexagon_session: Session, in_scale_const, out_scale_c tvm.testing.assert_allclose(hexagon_output, expected_output) + vector_size = tvm.testing.parameter(32, 64, 128, 256) + + def test_per_tensor_with_lanes(self, hexagon_session: Session, vector_size): + """Test fixed point multiply with vectorization. + Vectorization size is more than hw vector length""" + ishape = [2, 256, 16] + + def q_mul_shift(shape): + x = te.placeholder(shape, name="X", dtype="int32") + out = te.compute( + shape, + lambda i, j, k: tvm.tir.q_multiply_shift( + x[i, j, k], + tvm.tir.const(1395864320, "int32"), + tvm.tir.const(31, "int32"), + tvm.tir.const(1, "int32"), + ), + name="compute", + ) + return te.create_prim_func([x, out]) + + mod = q_mul_shift(ishape) + + # Schedule with vectorization + sch = tvm.tir.Schedule(mod) + b00 = sch.get_block(name="compute", func_name="main") + fused = sch.fuse(*sch.get_loops(block=b00)) + _, v = sch.split(loop=fused, factors=[None, vector_size]) + sch.vectorize(v) + + with tvm.transform.PassContext(opt_level=3): + hex_lib = tvm.build(sch.mod["main"], target=get_hexagon_target("v68")) + host_lib = tvm.build(mod, target=tvm.target.Target("llvm")) + + asm = hex_lib.get_source("asm") + + # Check that 'vmpye' instruction was generated in asm file. + vmpye_regex = re.compile(r"v\d{1,2}.w = vmpye\(v\d{1,2}.w,v\d{1,2}.uh\)") + assert vmpye_regex.search(asm) is not None + + # Check that 'vmpyo' instruction was generated in asm file. 
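+        # Together with the vmpye check above, this confirms that the vectorized
+        # q_multiply_shift was lowered to HVX vector multiplies rather than
+        # falling back to scalar code.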
+ vmpyo_regex = re.compile(r"v\d{1,2}.w \+= vmpyo\(v\d{1,2}.w,v\d{1,2}.h\):<<1:rnd:sat:shift") + assert vmpyo_regex.search(asm) is not None + + # Verify accuracy + a_np = np.random.randint(-1000, 1000, size=np.prod(ishape)).reshape(ishape).astype("int32") + b_np = np.random.randint(-1000, 1000, size=np.prod(ishape)).reshape(ishape).astype("int32") + hex_args = [ + tvm.runtime.ndarray.array(arg, device=hexagon_session.device, mem_scope="global") + for arg in [a_np, b_np] + ] + host_args = [tvm.runtime.ndarray.array(arg) for arg in [a_np, b_np]] + + hex_rt = hexagon_session.load_module(hex_lib) + hex_rt(*hex_args) + host_lib(*host_args) + + assert np.allclose(hex_args[1].numpy(), host_args[1].numpy()) + + def test_per_channel_with_lanes(self, hexagon_session: Session, vector_size): + """Test fixed point multiply with vectorization. + Vectorization size is more than hw vector length""" + a_shape = [2, 256, 16] + b_shape = [256] + + def q_mul_shift(shape): + shift_shape = [shape[1]] + x = te.placeholder(shape, name="X", dtype="int32") + y = te.placeholder(shift_shape, name="X", dtype="int32") + l_shift = te.placeholder(shift_shape, name="X", dtype="int32") + r_shift = te.placeholder(shift_shape, name="X", dtype="int32") + + out = te.compute( + shape, + lambda i, j, k: tvm.tir.q_multiply_shift_per_axis( + x[i, j, k], + y[j], + l_shift[j], + r_shift[j], + tvm.tir.const(31, "int32"), + tvm.tir.const(1, "bool"), + tvm.tir.const(0, "bool"), + ), + name="compute", + ) + return te.create_prim_func([x, y, l_shift, r_shift, out]) + + mod = q_mul_shift(a_shape) + + # Schedule with vectorization + sch = tvm.tir.Schedule(mod) + b00 = sch.get_block(name="compute", func_name="main") + fused = sch.fuse(*sch.get_loops(block=b00)) + _, v = sch.split(loop=fused, factors=[None, vector_size]) + sch.vectorize(v) + + with tvm.transform.PassContext(opt_level=3): + hex_lib = tvm.build(sch.mod["main"], target=get_hexagon_target("v68")) + host_lib = tvm.build(mod, target=tvm.target.Target("llvm")) + + asm = hex_lib.get_source("asm") + + # Check that 'vmpye' instruction was generated in asm file. + vmpye_regex = re.compile(r"v\d{1,2}.w = vmpye\(v\d{1,2}.w,v\d{1,2}.uh\)") + assert vmpye_regex.search(asm) is not None + + # Check that 'vmpyo' instruction was generated in asm file. 
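+        # The per-channel variant goes through q_multiply_shift_per_axis, so the
+        # same HVX pattern is expected even though the multiplier and shifts vary
+        # per channel.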
+ vmpyo_regex = re.compile(r"v\d{1,2}.w \+= vmpyo\(v\d{1,2}.w,v\d{1,2}.h\):<<1:rnd:sat:shift") + assert vmpyo_regex.search(asm) is not None + + # Verify accuracy + x_np = ( + np.random.randint(-1000, 1000, size=np.prod(a_shape)).reshape(a_shape).astype("int32") + ) + y_np = ( + np.random.randint(-1000, 1000, size=np.prod(b_shape)).reshape(b_shape).astype("int32") + ) + lsh_np = np.random.randint(0, 10, size=np.prod(b_shape)).reshape(b_shape).astype("int32") + rsh_np = np.random.randint(0, 10, size=np.prod(b_shape)).reshape(b_shape).astype("int32") + b_np = ( + np.random.randint(-1000, 1000, size=np.prod(a_shape)).reshape(a_shape).astype("int32") + ) + np_args = [x_np, y_np, lsh_np, rsh_np, b_np] + hex_args = [ + tvm.runtime.ndarray.array(arg, device=hexagon_session.device, mem_scope="global") + for arg in np_args + ] + host_args = [tvm.runtime.ndarray.array(arg) for arg in np_args] + + hex_rt = hexagon_session.load_module(hex_lib) + hex_rt(*hex_args) + host_lib(*host_args) + + assert np.allclose(hex_args[4].numpy(), host_args[4].numpy()) + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 61b1828aad992..f624984481da2 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -16,6 +16,7 @@ # under the License. import os import re +import numpy as np import shutil import tarfile from os import path @@ -29,6 +30,7 @@ import tvm.testing from tvm.relay.op.contrib.ethosn import ethosn_available from tvm.relay.backend import Runtime, Executor +from tvm import relay from tvm.contrib.target.vitis_ai import vitis_ai_available @@ -49,6 +51,355 @@ def test_save_dumps(tmpdir_factory): assert path.exists("{}/{}".format(tmpdir, "fake_module.relay")) +def test_save_dump_offloads_ethosu(tmp_path_factory): + + tflite = pytest.importorskip("tflite") + tensorflow = pytest.importorskip("tensorflow") + pytest.importorskip("ethosu.vela") + + import tensorflow as tf + import tflite.Model + from tvm.driver.tvmc.model import TVMCModel + + inp = (224, 224, 9) + input_shape = (1, *inp) + kernel_shape = (3, 3) + padding = (1, 1, 1, 1) + padding_out = (1, 33, 33, 1) + + @tf.function + def simple_net(x): + weight_shape = [kernel_shape[0], kernel_shape[1], input_shape[3], 3] + weights = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + weight_shape[2] = 3 + weights1 = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + weights2 = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + op = tf.nn.conv2d( + x, + filters=weights, + strides=1, + padding="SAME", + data_format="NHWC", + dilations=1, + ) + op1 = tf.nn.conv2d( + op, + filters=weights1, + strides=1, + padding="SAME", + data_format="NHWC", + dilations=1, + ) + op2 = tf.nn.conv2d( + op, + filters=weights2, + strides=1, + padding="SAME", + data_format="NHWC", + dilations=1, + ) + op = tf.math.add(op1, op2) + op = tf.pad( + op, + [[0, 0], [padding[0], padding_out[1]], [padding_out[2], padding[3]], [0, 0]], + "CONSTANT", + ) + return op + + from tests.python.contrib.test_ethosu.infra import get_tflite_graph + + _, tflite_graph = get_tflite_graph(simple_net, [input_shape]) + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + mod, params = relay.frontend.from_tflite(tflite_model) + + tvmc_model = TVMCModel(mod, params) + + output_dir = tmp_path_factory.mktemp("tmp") + output_file_name = os.path.join(str(output_dir), "list.txt") + + 
tvmc.compiler.compile_model( + tvmc_model, + target="ethos-u,cmsis-nn,c", + runtime=Runtime("crt"), + tuning_records="", + package_path="module.tar", + executor=Executor("aot", {"unpacked-api": 1, "interface-api": "c", "link-params": True}), + cross="", + cross_options="", + output_format="mlf", + dump_offloads=output_file_name, + disabled_pass=[""], + pass_context_configs=[ + "tir.disable_vectorize=1", + "tir.usmp.enable=1", + "tir.usmp.algorithm=hill_climb", + "tir.disable_storage_rewrite=1", + "relay.frontend.fill_span=1", + ], + additional_target_options={ + "c": {"mcpu": "cortex-m55"}, + "cmsis-nn": {"mcpu": "cortex-m55"}, + "ethos-u": { + "accelerator_config": "ethos-u55-256", + }, + }, + ) + + expected = [ + r"Total number of operators and distribution by targets", + r"Total: 11", + r"ethos-u: 10", + r"generic: 1", + r"", + r"ethos-u <- ethos-u.qnn_conv2d", + r'ethos-u <- %0 = qnn.conv2d(%x, %v_param_1, -128, 0, 0.00392157f, meta[relay.Constant][0], padding=[1, 1, 1, 1], channels=3, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32")', + r"ethos-u <- %1 = nn.bias_add(%0, %v_param_2, axis=3)", + r'ethos-u <- %2 = qnn.requantize(%1, meta[relay.Constant][1], 0, 0.11364f, -128, axis=3, out_dtype="int8")', + r"ethos-u <- ethos-u.qnn_conv2d", + r'ethos-u <- %3 = qnn.conv2d(%2, %v_param_3, -128, 0, 0.11364f, meta[relay.Constant][2], padding=[1, 1, 1, 1], channels=3, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32")', + r"ethos-u <- %4 = nn.bias_add(%3, %v_param_4, axis=3)", + r'ethos-u <- %7 = qnn.requantize(%4, meta[relay.Constant][3], 0, 1.56803f, -128, axis=3, out_dtype="int8")', + r"ethos-u <- ethos-u.qnn_conv2d", + r'ethos-u <- %5 = qnn.conv2d(%2, %v_param_5, -128, 0, 0.11364f, meta[relay.Constant][4], padding=[1, 1, 1, 1], channels=3, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32")', + r"ethos-u <- %6 = nn.bias_add(%5, %v_param_6, axis=3)", + r'ethos-u <- %8 = qnn.requantize(%6, meta[relay.Constant][5], 0, 1.20538f, -128, axis=3, out_dtype="int8")', + r"ethos-u <- ethos-u.add", + r"ethos-u <- %9 = qnn.add(%7, %8, 1.56803f, -128, 1.20538f, -128, 2.77341f, -128)", + r"generic <- nn.pad(%9, -128f, pad_width=[[0, 0], [1, 33], [33, 1], [0, 0]])", + ] + + file_path = os.path.abspath(output_file_name) + # check that file file_path was created + assert os.path.exists(file_path) + with open(file_path, "r") as f: + for i, file_string in enumerate(f): + r_output = re.search(r"(.*)\(", file_string.strip(), re.DOTALL) + r_expected = re.search(r"(.*)\(", expected[i], re.DOTALL) + # check that there is the same sequence of operations and composites, + # combined with target names + if r_output and r_expected: + assert r_output.group(0) == r_expected.group(0) + else: + assert r_output == r_expected + + +def test_save_dump_offloads_cmsis(tmp_path_factory): + + tflite = pytest.importorskip("tflite") + tensorflow = pytest.importorskip("tensorflow") + pytest.importorskip("ethosu.vela") + + import tensorflow as tf + from tvm.driver.tvmc.model import TVMCModel + + inp = (224, 224, 9) + input_shape = (1, *inp) + kernel_shape = (3, 3) + padding = (1, 1, 1, 1) + padding_out = (1, 33, 33, 1) + + @tf.function + def simple_net(x): + weight_shape = [kernel_shape[0], kernel_shape[1], input_shape[3], 3] + weights = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + op = tf.nn.conv2d( + x, + filters=weights, + strides=1, + padding="SAME", + data_format="NHWC", + dilations=1, + ) + op = tf.nn.relu(op) + op 
= tf.pad( + op, + [[0, 0], [padding[0], padding_out[2]], [padding_out[1], padding[3]], [0, 0]], + "CONSTANT", + ) + op = tf.pad( + op, + [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]], + "CONSTANT", + ) + return tf.pad( + op, + [[0, 0], [padding_out[0], padding[2]], [padding[1], padding_out[3]], [0, 0]], + "CONSTANT", + ) + + from tests.python.contrib.test_ethosu.infra import get_tflite_graph + + _, tflite_graph = get_tflite_graph(simple_net, [input_shape]) + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + mod, params = relay.frontend.from_tflite(tflite_model) + + tvmc_model = TVMCModel(mod, params) + + output_dir = tmp_path_factory.mktemp("tmp") + output_file_name = os.path.join(str(output_dir), "list.txt") + + tvmc.compiler.compile_model( + tvmc_model, + target="cmsis-nn,c", + runtime=Runtime("crt"), + tuning_records="", + package_path="module.tar", + executor=Executor("aot", {"unpacked-api": 1, "interface-api": "c", "link-params": True}), + cross="", + cross_options="", + output_format="mlf", + dump_offloads=output_file_name, + disabled_pass=[""], + pass_context_configs=[ + "tir.disable_vectorize=1", + "tir.usmp.enable=1", + "tir.usmp.algorithm=hill_climb", + "tir.disable_storage_rewrite=1", + "relay.frontend.fill_span=1", + ], + additional_target_options={ + "c": {"mcpu": "cortex-m55"}, + "cmsis-nn": {"mcpu": "cortex-m55"}, + }, + ) + + expected = [ + r"Total number of operators and distribution by targets", + r"Total: 7", + r"cmsis-nn: 4", + r"generic: 3", + r"", + r"cmsis-nn <- cmsis-nn.qnn_conv2d", + r'cmsis-nn <- %0 = qnn.conv2d(%x, %v_param_1, -128, 0, 0.00392157f, meta[relay.Constant][0], padding=[1, 1, 1, 1], channels=3, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32")', + r"cmsis-nn <- %1 = nn.bias_add(%0, %v_param_2, axis=3)", + r'cmsis-nn <- %2 = qnn.requantize(%1, meta[relay.Constant][1], 0, 0.113405f, -128, axis=3, out_dtype="int8")', + r"cmsis-nn <- %3 = clip(%2, a_min=-128f, a_max=127f)", + r"generic <- %4 = nn.pad(%3, -128f, pad_width=[[0, 0], [1, 33], [33, 1], [0, 0]])", + r"generic <- %5 = nn.pad(%4, -128f, pad_width=[[0, 0], [1, 1], [1, 1], [0, 0]])", + r"generic <- nn.pad(%5, -128f, pad_width=[[0, 0], [1, 1], [1, 1], [0, 0]])", + ] + + file_path = os.path.abspath(output_file_name) + # check that file file_path was created + assert os.path.exists(file_path) + with open(file_path, "r") as f: + for i, file_string in enumerate(f): + r_output = re.search(r"(.*)\(", file_string.strip(), re.DOTALL) + r_expected = re.search(r"(.*)\(", expected[i], re.DOTALL) + # check that there is the same sequence of operations and composites, + # combined with target names + if r_output and r_expected: + assert r_output.group(0) == r_expected.group(0) + else: + assert r_output == r_expected + + +def test_save_dump_offloads_generic(tmp_path_factory): + + tflite = pytest.importorskip("tflite") + tensorflow = pytest.importorskip("tensorflow") + pytest.importorskip("ethosu.vela") + + import tensorflow as tf + from tvm.driver.tvmc.model import TVMCModel + + inp = (224, 224, 9) + input_shape = (1, *inp) + kernel_shape = (3, 3) + padding = (1, 1, 1, 1) + padding_out = (1, 33, 33, 1) + + @tf.function + def simple_net(x): + weight_shape = [kernel_shape[0], kernel_shape[1], input_shape[3], 3] + weights = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + op = tf.nn.conv2d( + x, + filters=weights, + strides=1, + padding="SAME", + data_format="NHWC", + dilations=1, + ) + op = tf.pad( + op, + [[0, 0], 
[padding[0], padding_out[2]], [padding_out[1], padding[3]], [0, 0]], + "CONSTANT", + ) + op = tf.pad( + op, + [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]], + "CONSTANT", + ) + return tf.pad( + op, + [[0, 0], [padding_out[0], padding[2]], [padding[1], padding_out[3]], [0, 0]], + "CONSTANT", + ) + + from tests.python.contrib.test_ethosu.infra import get_tflite_graph + + _, tflite_graph = get_tflite_graph(simple_net, [input_shape]) + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + mod, params = relay.frontend.from_tflite(tflite_model) + + tvmc_model = TVMCModel(mod, params) + + output_dir = tmp_path_factory.mktemp("tmp") + output_file_name = os.path.join(str(output_dir), "list.txt") + + tvmc.compiler.compile_model( + tvmc_model, + target="c", + runtime=Runtime("crt"), + tuning_records="", + package_path="module.tar", + executor=Executor("aot", {"unpacked-api": 1, "interface-api": "c", "link-params": True}), + cross="", + cross_options="", + output_format="mlf", + dump_offloads=output_file_name, + disabled_pass=[""], + pass_context_configs=[ + "tir.disable_vectorize=1", + "tir.usmp.enable=1", + "tir.usmp.algorithm=hill_climb", + "tir.disable_storage_rewrite=1", + "relay.frontend.fill_span=1", + ], + additional_target_options={ + "c": {"mcpu": "cortex-m55"}, + }, + ) + + expected = [ + r"Total number of operators and distribution by targets", + r"Total: 6", + r"generic: 6", + r"", + r'generic <- %0 = qnn.conv2d(%x, %v_param_1, -128, 0, 0.00392156f, meta[relay.Constant][0], padding=[1, 1, 1, 1], channels=3, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32")', + r"generic <- %1 = nn.bias_add(%0, %v_param_2, axis=3)", + r'generic <- %2 = qnn.requantize(%1, meta[relay.Constant][1], 0, 0.103975f, -128, axis=3, out_dtype="int8")', + r"generic <- %3 = nn.pad(%2, -128f, pad_width=[[0, 0], [1, 33], [33, 1], [0, 0]])", + r"generic <- %4 = nn.pad(%3, -128f, pad_width=[[0, 0], [1, 1], [1, 1], [0, 0]])", + r"generic <- nn.pad(%4, -128f, pad_width=[[0, 0], [1, 1], [1, 1], [0, 0]])", + ] + + file_path = os.path.abspath(output_file_name) + # check that file file_path was created + assert os.path.exists(file_path) + with open(file_path, "r") as f: + for i, file_string in enumerate(f): + r_output = re.search(r"(.*)\(", file_string.strip(), re.DOTALL) + r_expected = re.search(r"(.*)\(", expected[i], re.DOTALL) + # check that there is the same sequence of operations and composites, + # combined with target names + if r_output and r_expected: + assert r_output.group(0) == r_expected.group(0) + else: + assert r_output == r_expected + + # End to end tests for compilation diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 116c023caadbc..34dda6cf6f7b8 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -99,6 +99,14 @@ def get_tvm_output_with_vm( freeze_params=freeze_params, convert_config=convert_config, ) + # handle the bfloat16 so we explicitly allocate + # bfloat16 arrays as input + for i, param in enumerate(mod["main"].params): + if param.type_annotation.dtype == "bfloat16": + input_data[i] = tvm.nd.empty(input_data[i].shape, "bfloat16").copyfrom( + input_data[i] + ) + if validate_structural_equal: with tvm.testing.enable_span_filling(): mod_with_span, _ = relay.frontend.from_onnx( diff --git a/tests/python/relay/opencl_texture/test_injection_texture.py b/tests/python/relay/opencl_texture/test_injection_texture.py new file mode 
100644 index 0000000000000..181f0b6ff9098 --- /dev/null +++ b/tests/python/relay/opencl_texture/test_injection_texture.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import re +import pytest +import tvm +import numpy as np +from tvm import relay +from tvm.relay import testing +from tvm.contrib import utils +from utils.adreno_utils import gpu_preprocess, build_run_compare + + +dtype = tvm.testing.parameter("float32") + + +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_layout_transform_to_block_nchw4c(remote, target, dtype): + """Verification of the case NCHW->NCHW4c""" + input_shape = (1, 32, 720, 1280) + A = relay.var("data", shape=input_shape, dtype=dtype) + lt = relay.layout_transform(A, "NCHW", "NCHW4c") + mod = relay.Function([A], lt) + + build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target) + + +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_layout_transform_to_block_nchw(remote, target, dtype): + """Verification of the case NCHW4c->NCHW""" + input_shape = (1, 36, 1, 1, 4) + A = relay.var("data", shape=input_shape, dtype=dtype) + lt = relay.layout_transform(A, "NCHW4c", "NCHW") + mod = relay.Function([A], lt) + + build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target) + + +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_layout_transform_to_block_nhwc4c(remote, target, dtype): + """Verification of the case NHWC->NHWC4c""" + input_shape = (1, 1, 1, 144) + A = relay.var("data", shape=input_shape, dtype=dtype) + lt = relay.layout_transform(A, "NHWC", "NHWC4c") + mod = relay.Function([A], lt) + + build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target) + + +@pytest.mark.skip(reason="Skip because GPU in CI doesn't support FP16") +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl -device=adreno") +def test_layout_transform_to_block_nhwc(remote, target, dtype): + """Verification of the case NHWC4c->NHWC""" + input_shape = (1, 80, 80, 36, 4) + A = relay.var("data", shape=input_shape, dtype=dtype) + mean = relay.mean(A, axis=[1, 2], keepdims=True) + cast = relay.cast(mean, "float16") + lt = relay.layout_transform(cast, "NHWC4c", "NHWC") + mod = relay.Function([A], lt) + + build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target) + + +if __name__ == "__main__": + test_layout_transform_to_block_nhwc(None, "opencl -device=adreno", "float16") diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py index a3b1cc5e01394..0b6f891cca7dd 
100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py @@ -142,6 +142,42 @@ def after_matmul_vectorize( T_matmul_NT[v0, v1] = T_matmul_NT_global[v0, v1] +@T.prim_func +def before_postproc_add( + lhs: T.Buffer((1, 8, 56, 56, 32), "uint8"), + rhs: T.Buffer((1, 8, 56, 56, 32), "uint8"), + add_compute: T.Buffer((1, 8, 56, 56, 32), "uint8"), +) -> None: + with T.block("root"): + T.block_attr({"meta_schedule.parallel":64, "meta_schedule.vectorize":128}) + for n, c0, h, w, c1 in T.grid(1, 8, 56, 56, 32): + with T.block("add_compute"): + v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [n, c0, h, w, c1]) + T.reads(lhs[v0, v1, v2, v3, v4], rhs[v0, v1, v2, v3, v4]) + T.writes(add_compute[v0, v1, v2, v3, v4]) + add_compute[v0, v1, v2, v3, v4] = lhs[v0, v1, v2, v3, v4] + rhs[v0, v1, v2, v3, v4] + + +@T.prim_func +def after_postproc_add( + lhs: T.Buffer((1, 8, 56, 56, 32), "uint8"), + rhs: T.Buffer((1, 8, 56, 56, 32), "uint8"), + add_compute: T.Buffer((1, 8, 56, 56, 32), "uint8"), +) -> None: + with T.block("root"): + for n_c0_h_w_c1_fused_0 in T.parallel(0, 6272): + for n_c0_h_w_c1_fused_1 in T.vectorized(0, 128): + with T.block("add_compute"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(8, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) // 100352) + v2 = T.axis.spatial(56, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) % 100352 // 1792) + v3 = T.axis.spatial(56, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) % 1792 // 32) + v4 = T.axis.spatial(32, (n_c0_h_w_c1_fused_0 * 128 + n_c0_h_w_c1_fused_1) % 32) + T.reads(lhs[v0, v1, v2, v3, v4], rhs[v0, v1, v2, v3, v4]) + T.writes(add_compute[v0, v1, v2, v3, v4]) + add_compute[v0, v1, v2, v3, v4] = lhs[v0, v1, v2, v3, v4] + rhs[v0, v1, v2, v3, v4] + + # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable @@ -161,6 +197,14 @@ def test_vectorize_inner_loop(): tvm.ir.assert_structural_equal(sch.mod["main"], after_matmul_vectorize) +def test_parallel_vectorize_add(): + sch = Schedule(before_postproc_add) + rule = RewriteParallelVectorizeUnroll() + assert rule.apply(sch) + tvm.ir.assert_structural_equal(sch.mod["main"], after_postproc_add) + + if __name__ == "__main__": test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize() test_vectorize_inner_loop() + test_parallel_vectorize_add() diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 44f950c82ad37..3190115aa6b25 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -707,7 +707,7 @@ def np_float2tvm_bf16(arr): """Convert a numpy array of float to a TVM array of bf16""" nparr = np_float2np_bf16(arr) - return tvm.nd.empty(nparr.shape, "uint16").copyfrom(nparr) + return tvm.nd.empty(nparr.shape, "bfloat16").copyfrom(nparr) def np_bf162np_float(arr): @@ -730,9 +730,9 @@ def dotest(do_vectorize): B = te.placeholder((32,), dtype="bfloat16") d = te.compute((32,), lambda x: A[x] + B[x]) sch = te.create_schedule(d.op) - print(tvm.lower(sch, [A, B, d])) if do_vectorize: sch[d].vectorize(d.op.axis[0]) + module = tvm.build(sch, [A, B, d]) npa = np.random.rand(32).astype("float32") npb = np.random.rand(32).astype("float32") @@ -741,7 +741,7 @@ def dotest(do_vectorize): res = np_bf16_cast_and_cast_back(va + vb) a_ = np_float2tvm_bf16(npa) b_ = np_float2tvm_bf16(npb) - c_ = 
tvm.nd.empty((32,), "uint16") + c_ = tvm.nd.empty((32,), "bfloat16") module(a_, b_, c_) tvm.testing.assert_allclose(np_bf162np_float(c_.numpy()), res) diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py index c25b3c2c86eab..bc2d0a84fd9d0 100644 --- a/tests/python/unittest/test_target_codegen_opencl.py +++ b/tests/python/unittest/test_target_codegen_opencl.py @@ -185,8 +185,4 @@ def check_type_casting(ctx, n, dtype): if __name__ == "__main__": - test_opencl_ternary_expression() - test_opencl_inf_nan() - test_opencl_max() - test_opencl_erf() - test_opencl_type_casting() + tvm.testing.main() diff --git a/tests/python/unittest/test_target_texture_codegen_opencl.py b/tests/python/unittest/test_target_texture_codegen_opencl.py index 06876258e5d15..639159c495f02 100644 --- a/tests/python/unittest/test_target_texture_codegen_opencl.py +++ b/tests/python/unittest/test_target_texture_codegen_opencl.py @@ -1397,5 +1397,380 @@ class TestDepthwiseConv2dNCHWcKCRSk(BaseConv2DValidator): test_func = tvm.testing.parameter(depthwise_conv2d_NCHWc_KCRSk_acc32) +def simple_texture_to_scalar_common( + target, input_info, output_info, find_patterns, dtype, cast_type +): + def _compute(): + p0 = te.placeholder(input_info[1], name="p0", dtype=dtype) + p0_comp = te.compute(input_info[1], lambda *i: p0(*i), name="p0_comp") + if len(output_info[1]) == 4 and len(input_info[1]) == 5: + out = te.compute( + output_info[1], + lambda n, c, h, w: p0_comp[n][c // 4][h][w][c % 4].astype(cast_type), + name="out", + ) + elif len(output_info[1]) == 5 and len(input_info[1]) == 5: + out = te.compute( + output_info[1], + lambda n, c, h, w, cb: p0_comp[n][c][h][w][cb].astype(cast_type), + name="out", + ) + else: + raise Exception("Impossible case") + dummy_out = te.compute(output_info[1], lambda *i: out(*i), name="dummy_out") + return p0, dummy_out + + def _schedule(dummy_out): + from tvm.topi.adreno.utils import bind_data_copy + + s = te.create_schedule(dummy_out.op) + out = s[dummy_out].op.input_tensors[0] + p0_comp = s[out].op.input_tensors[0] + s[p0_comp].set_scope(input_info[0]) + bind_data_copy(s[p0_comp]) + s[out].set_scope(output_info[0]) + bind_data_copy(s[out]) + bind_data_copy(s[dummy_out]) + return s + + p0, dummy_out = _compute() + s = _schedule(dummy_out) + + fun = tvm.build(s, [p0, dummy_out], target) + dev = tvm.device(target, 0) + opencl_source = fun.imported_modules[0].get_source() + start_idx = 0 + for pattern in find_patterns: + start_idx = opencl_source.find(pattern, start_idx) + assert start_idx > -1 + + input_np = np.random.uniform(size=[i for i in input_info[1]]).astype(dtype) + input_tvm = tvm.nd.array(input_np, dev) + c = tvm.nd.empty(output_info[1], dtype, dev) + # Doesn't run OpenCL code for FP16 because GPUs in CI don't support FP16 inference + if cast_type == "float32": + fun(input_tvm, c) + # For output len == 5 it makes no sense to check the accuracy + if cast_type == "float32" and len(output_info[1]) == 4: + np_result = input_np.transpose(0, 2, 3, 1, 4) # NCHW4c -> NHWC4c + np_result = np.squeeze(np_result, axis=3) + np_result = np_result.transpose(0, 3, 1, 2) # NHWC -> NCHW + np.testing.assert_allclose(c.asnumpy(), np_result, rtol=1e-2, atol=1e-2) + + +class TestSimpleTextureToScalarFP16: + # (input [scope, shape], output [scope, shape], [find_patterns]) + input_info, output_info, find_patterns = tvm.testing.parameters( + # 1. 
Texture (NCHW4c) -> Cast(FP16) -> Buffer (NCHW) + ( + ["global.texture", (1, 1, 40, 40, 4)], + ["", (1, 4, 40, 40)], + [ + "float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)((((int)get_local_id(0)) % 40), (((((int)get_group_id(0)) & 1) * 20) + (((int)get_local_id(0)) / 40)))));", + "out[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = ((half)((float*)&v_)[(((int)get_group_id(0)) >> 1)]);", + ], + ), + # 2. Buffer (NCHW4c) -> Cast(FP16) -> Buffer (NCHW) + ( + ["", (1, 1, 40, 40, 4)], + ["", (1, 4, 40, 40)], + [ + "out[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = ((half)p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))]);" + ], + ), + # 3. Texture (NCHW4c) -> Cast(FP16) -> Texture (NCHW4c) + ( + ["global.texture", (1, 1, 40, 40, 4)], + ["global.texture", (1, 1, 40, 40, 4)], + [ + "float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)((((((int)get_group_id(0)) * 24) + ((int)get_local_id(0))) % 40), (((((int)get_group_id(0)) * 8) + (((int)get_local_id(0)) >> 3)) / 5))));", + "write_imageh(out, (int2)((((((int)get_group_id(0)) * 24) + ((int)get_local_id(0))) % 40), (((((int)get_group_id(0)) * 8) + (((int)get_local_id(0)) >> 3)) / 5)), (convert_half4(v_)));", + ], + ), + ) + dtype = tvm.testing.parameter("float32") + + @tvm.testing.parametrize_targets("opencl") + def test_simple_texture_to_scalar_fp16( + self, input_info, output_info, find_patterns, dtype, target + ): + simple_texture_to_scalar_common( + target, input_info, output_info, find_patterns, dtype, "float16" + ) + + +class TestSimpleTextureToScalarFP32: + # (input [scope, shape], output [scope, shape], [find_patterns]) + input_info, output_info, find_patterns = tvm.testing.parameters( + # 1. Texture (NCHW4c) -> Buffer (NCHW) + ( + ["global.texture", (1, 1, 40, 40, 4)], + ["", (1, 4, 40, 40)], + [ + "float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)((((int)get_local_id(0)) % 40), (((((int)get_group_id(0)) & 1) * 20) + (((int)get_local_id(0)) / 40)))));", + "out[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = ((float*)&v_)[(((int)get_group_id(0)) >> 1)];", + ], + ), + # 2. 
Buffer (NCHW4c) -> Buffer (NCHW) + ( + ["", (1, 1, 40, 40, 4)], + ["", (1, 4, 40, 40)], + [ + "out[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))];" + ], + ), + ) + dtype = tvm.testing.parameter("float32") + + @tvm.testing.parametrize_targets("opencl") + def test_simple_texture_to_scalar_fp32( + self, input_info, output_info, find_patterns, dtype, target + ): + simple_texture_to_scalar_common( + target, input_info, output_info, find_patterns, dtype, "float32" + ) + + +def texture_to_scalar_reuse_ssa_common( + target, input_info, output_info, find_patterns, dtype, cast_type +): + def _compute(): + p0 = te.placeholder(input_info[1], name="p0", dtype=dtype) + p0_comp = te.compute(input_info[1], lambda *i: p0(*i), name="p0_comp") + if len(output_info[1]) == 4 and len(input_info[1]) == 5: + out = te.compute( + output_info[1], + lambda n, c, h, w: p0_comp[n][c // 4][h][w][c % 4].astype(cast_type), + name="out", + ) + out2 = te.compute( + output_info[1], + lambda n, c, h, w: out[n][c][h][w] + + p0_comp[n][c // 4][h][w][c % 4].astype(cast_type), + name="out", + ) + elif len(output_info[1]) == 5 and len(input_info[1]) == 5: + out = te.compute( + output_info[1], + lambda n, c, h, w, cb: p0_comp[n][c][h][w][cb].astype(cast_type), + name="out", + ) + out2 = te.compute( + output_info[1], + lambda n, c, h, w, cb: out[n][c][h][w][cb] + + p0_comp[n][c][h][w][cb].astype(cast_type), + name="out", + ) + else: + raise Exception("Impossible case") + out_sum = te.compute(output_info[1], lambda *i: out(*i) + out2(*i), name="out_sum") + dummy_out = te.compute(output_info[1], lambda *i: out_sum(*i), name="dummy_out") + return p0, dummy_out + + def _schedule(dummy_out): + from tvm.topi.adreno.utils import bind_data_copy + + s = te.create_schedule(dummy_out.op) + out_sum = s[dummy_out].op.input_tensors[0] + out, out2 = s[out_sum].op.input_tensors + p0_comp = s[out].op.input_tensors[0] + s[p0_comp].set_scope(input_info[0]) + bind_data_copy(s[p0_comp]) + s[out].set_scope(output_info[0]) + s[out2].set_scope(output_info[0]) + s[out2].compute_inline() + s[out].compute_inline() + s[out_sum].set_scope(output_info[0]) + bind_data_copy(s[out_sum]) + bind_data_copy(s[dummy_out]) + return s + + p0, dummy_out = _compute() + s = _schedule(dummy_out) + + fun = tvm.build(s, [p0, dummy_out], target) + dev = tvm.device(target, 0) + opencl_source = fun.imported_modules[0].get_source() + start_idx = 0 + for pattern in find_patterns: + start_idx = opencl_source.find(pattern, start_idx) + assert start_idx > -1 + + input_np = np.random.uniform(size=[i for i in input_info[1]]).astype(dtype) + input_tvm = tvm.nd.array(input_np, dev) + c = tvm.nd.empty(output_info[1], dtype, dev) + # Doesn't run OpenCL code for FP16 because GPUs in CI don't support FP16 inference + if cast_type == "float32": + fun(input_tvm, c) + # For output len == 5 it makes no sense to check the accuracy + if cast_type == "float32" and len(output_info[1]) == 4: + np_result = input_np * 3 + np_result = np_result.transpose(0, 2, 3, 1, 4) # NCHW4c -> NHWC4c + np_result = np.squeeze(np_result, axis=3) + np_result = np_result.transpose(0, 3, 1, 2) # NHWC -> NCHW + np.testing.assert_allclose(c.asnumpy(), np_result, rtol=1e-2, atol=1e-2) + + +class TestTextureToScalarReuseSSAFP16: + # (input [scope, shape], output [scope, shape], [find_patterns]) + input_info, output_info, find_patterns = tvm.testing.parameters( + # 1. 
Texture (NCHW4c) -> Cast(FP16) -> Buffer (NCHW) + ( + ["global.texture", (1, 1, 40, 40, 4)], + ["", (1, 4, 40, 40)], + [ + "float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)((((int)get_local_id(0)) % 40), (((((int)get_group_id(0)) & 1) * 20) + (((int)get_local_id(0)) / 40)))));", + "out_sum[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = (((half)((float*)&v_)[(((int)get_group_id(0)) >> 1)]) + (((half)((float*)&v_)[(((int)get_group_id(0)) >> 1)]) + ((half)((float*)&v_)[(((int)get_group_id(0)) >> 1)])));", + ], + ), + # 2. Buffer (NCHW4c) -> Cast(FP16) -> Buffer (NCHW) + ( + ["", (1, 1, 40, 40, 4)], + ["", (1, 4, 40, 40)], + [ + "out_sum[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = (((half)p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))]) + (((half)p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))]) + ((half)p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))])));" + ], + ), + # 3. Texture (NCHW4c) -> Cast(FP16) -> Texture (NCHW4c) + ( + ["global.texture", (1, 1, 40, 40, 4)], + ["global.texture", (1, 1, 40, 40, 4)], + [ + "float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)((((((int)get_group_id(0)) * 24) + ((int)get_local_id(0))) % 40), (((((int)get_group_id(0)) * 8) + (((int)get_local_id(0)) >> 3)) / 5))));", + "write_imageh(out_sum, (int2)((((((int)get_group_id(0)) * 24) + ((int)get_local_id(0))) % 40), (((((int)get_group_id(0)) * 8) + (((int)get_local_id(0)) >> 3)) / 5)), ((convert_half4(v_)) + ((convert_half4(v_)) + (convert_half4(v_)))));", + ], + ), + ) + dtype = tvm.testing.parameter("float32") + + @tvm.testing.parametrize_targets("opencl") + def test_texture_to_scalar_reuse_ssa_fp16( + self, input_info, output_info, find_patterns, dtype, target + ): + texture_to_scalar_reuse_ssa_common( + target, input_info, output_info, find_patterns, dtype, "float16" + ) + + +class TestTextureToScalarReuseSSAFP32: + # (input [scope, shape], output [scope, shape], [find_patterns]) + input_info, output_info, find_patterns = tvm.testing.parameters( + # 1. Texture (NCHW4c) -> Buffer (NCHW) + ( + ["global.texture", (1, 1, 40, 40, 4)], + ["", (1, 4, 40, 40)], + [ + "float4 v_ = READ_IMAGEF(p0_comp, image_sampler, ((int2)((((int)get_local_id(0)) % 40), (((((int)get_group_id(0)) & 1) * 20) + (((int)get_local_id(0)) / 40)))));", + "out_sum[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = (((float*)&v_)[(((int)get_group_id(0)) >> 1)] + (((float*)&v_)[(((int)get_group_id(0)) >> 1)] + ((float*)&v_)[(((int)get_group_id(0)) >> 1)]));", + ], + ), + # 2. 
Buffer (NCHW4c) -> Buffer (NCHW) + ( + ["", (1, 1, 40, 40, 4)], + ["", (1, 4, 40, 40)], + [ + "out_sum[((((int)get_group_id(0)) * 800) + ((int)get_local_id(0)))] = (p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))] + (p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))] + p0_comp[((((((int)get_group_id(0)) & 1) * 3200) + (((int)get_local_id(0)) * 4)) + (((int)get_group_id(0)) >> 1))]));" + ], + ), + ) + dtype = tvm.testing.parameter("float32") + + @tvm.testing.parametrize_targets("opencl") + def test_texture_to_scalar_reuse_ssa_fp32( + self, input_info, output_info, find_patterns, dtype, target + ): + texture_to_scalar_reuse_ssa_common( + target, input_info, output_info, find_patterns, dtype, "float32" + ) + + +class TestLocalArrayToTexture: + # 1. conv2d(Texture(NCHW4c), Texture(OIHW4o)) -> local_array[4] -> Texture (NCHW4c) + input_shape1, input_shape2, output_shape, find_patterns = tvm.testing.parameters( + ( + (1, 1, 40, 40, 4), + (2, 4, 3, 3, 4), + (1, 2, 38, 38, 4), + [ + "float out_local[4];", + "float4 v_ = READ_IMAGEF(p1_comp, image_sampler, ((int2)((((((int)get_group_id(0)) * 14) + ((int)get_local_id(0))) % 38), ((((((int)get_group_id(0)) * 64) + (((int)get_local_id(0)) >> 1)) % 722) / 19))));", + "float4 v__1 = READ_IMAGEF(p2_comp, image_sampler, ((int2)(rw, ((((((((int)get_group_id(0)) * 32) + (((int)get_local_id(0)) >> 2)) / 361) * 12) + (rcb * 3)) + rh))));", + "out_local[cb_c] = (out_local[cb_c] + (((float*)&v_)[rcb] * ((float*)&v__1)[cb_c]));", + "write_imagef(out, (int2)((((((int)get_group_id(0)) * 14) + ((int)get_local_id(0))) % 38), (((((int)get_group_id(0)) * 64) + (((int)get_local_id(0)) >> 1)) / 19)), vload4(0, out_local + 0));", + ], + ), + ) + dtype = tvm.testing.parameter("float32") + + @tvm.testing.parametrize_targets("opencl") + def test_local_array_to_texture( + self, input_shape1, input_shape2, output_shape, find_patterns, dtype, target + ): + def _compute(): + p1 = te.placeholder(input_shape1, name="p1", dtype=dtype) + p1_comp = te.compute(input_shape1, lambda *i: p1(*i), name="p1_comp") + p2 = te.placeholder(input_shape2, name="p2", dtype=dtype) + p2_comp = te.compute(input_shape2, lambda *i: p2(*i), name="p2_comp") + KH, KW = input_shape2[2], input_shape2[3] + IC, ICB = input_shape1[1], input_shape1[4] + rh = te.reduce_axis((0, KH), name="rh") + rw = te.reduce_axis((0, KW), name="rw") + rc = te.reduce_axis((0, IC), name="rc") + rcb = te.reduce_axis((0, ICB), name="rcb") + out = te.compute( + output_shape, + lambda n, c, h, w, cb: te.sum( + (p1_comp[n, rc, h, w, rcb] * p2_comp[c, rc * ICB + rcb, rh, rw, cb]).astype( + dtype + ), + axis=[rh, rw, rc, rcb], + ), + name="out", + ) + dummy_out = te.compute(output_shape, lambda *i: out(*i), name="dummy_out") + return p1, p2, dummy_out + + def _schedule(dummy_out): + from tvm.topi.adreno.utils import bind_data_copy + + s = te.create_schedule(dummy_out.op) + out = s[dummy_out].op.input_tensors[0] + p1_comp, p2_comp = s[out].op.input_tensors + bind_data_copy(s[p1_comp]) + s[p1_comp].set_scope("global.texture") + bind_data_copy(s[p2_comp]) + s[p2_comp].set_scope("global.texture") + OL = s.cache_write(out, "local") + n, c, h, w, cb = s[out].op.axis + fused = s[out].fuse(n, c, h, w) + bx, tx = s[out].split(fused, 128) + s[out].reorder(bx, tx, cb) + s[out].vectorize(cb) + s[out].set_scope("global.texture") + s[out].bind(bx, te.thread_axis("blockIdx.x")) + s[out].bind(tx, 
te.thread_axis("threadIdx.x")) + s[OL].compute_at(s[out], tx) + bind_data_copy(s[dummy_out]) + return s + + p1, p2, dummy_out = _compute() + s = _schedule(dummy_out) + + fun = tvm.build(s, [p1, p2, dummy_out], target) + dev = tvm.device(target, 0) + opencl_source = fun.imported_modules[0].get_source() + start_idx = 0 + for pattern in find_patterns: + start_idx = opencl_source.find(pattern, start_idx) + assert start_idx > -1 + + input_np1 = np.random.uniform(size=[i for i in input_shape1]).astype(dtype) + input_np2 = np.random.uniform(size=[i for i in input_shape2]).astype(dtype) + input_tvm1 = tvm.nd.array(input_np1, dev) + input_tvm2 = tvm.nd.array(input_np2, dev) + c = tvm.nd.empty(output_shape, dtype, dev) + fun(input_tvm1, input_tvm2, c) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py index 06f6fe31278dd..489db287f3779 100644 --- a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py +++ b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py @@ -77,5 +77,35 @@ def test_flops_with_if(): assert flops == 16 +@T.prim_func +def flops_with_forloop_as_expression(A: T.Buffer(1)): + for i in T.serial(0, 16): + for k in T.serial(0, i): + A[0] = A[0] + 1 + + +@T.prim_func +def flops_override(A: T.Buffer(16, "float32")): + T.func_attr({"estimated_flops": 32}) + for i in range(16): + A[0] = A[0] + 1 + + +def test_estimate_flops_forloop_as_experssion(): + flops = estimate_tir_flops( + IRModule({"main": flops_with_forloop_as_expression.with_attr("estimated_flops", 32)}) + ) + assert flops == 32 + + # test whether the user estimated flop would over ride + flops = estimate_tir_flops(IRModule({"main": flops_override})) + assert flops == 32 + + +def test_exception(): + with pytest.raises(tvm.TVMError): + flops = estimate_tir_flops(IRModule({"main": flops_with_forloop_as_expression})) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tir_schedule_read_write_at.py b/tests/python/unittest/test_tir_schedule_read_write_at.py new file mode 100644 index 0000000000000..dd61a4d62be17 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_read_write_at.py @@ -0,0 +1,221 @@ +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import pytest + +import tvm +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable + +@T.prim_func +def cuda_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=undefined-loop-variable + A = T.match_buffer(a, [2048, 2048], "float32") + B = T.match_buffer(b, [2048, 2048], "float32") + C = T.match_buffer(c, [2048, 2048], "float32") + for by in T.thread_binding(0, 32, thread = "blockIdx.y"): + for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): + for vy in T.thread_binding(0, 2, thread = "vthread.y"): + for vx in T.thread_binding(0, 2, thread = "vthread.x"): + for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): + for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): + for k0 in T.serial(0, 256): + for k1 in T.unroll(0, 8): + for _, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C[vi, vj], A[vi, vk], B[vk, vj]]) + T.writes([C[vi, vj]]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@T.prim_func +def cuda_matmul_read_at_a(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048], dtype="float32") + C = T.match_buffer(c, [2048, 2048], dtype="float32") + A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + for by in T.thread_binding(0, 32, thread="blockIdx.y"): + for bx in T.thread_binding(0, 32, thread="blockIdx.x"): + for vy in T.thread_binding(0, 2, thread="vthread.y"): + for vx in T.thread_binding(0, 2, thread="vthread.x"): + for ty in T.thread_binding(0, 8, thread="threadIdx.y"): + for tx in T.thread_binding(0, 8, thread="threadIdx.x"): + for k0 in T.serial(0, 256): + with T.block("A_shared"): + v0 = T.axis.S(32, by) + v1 = T.axis.S(256, k0) + T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 8): + A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] + for k1 in T.unroll(0, 8): + for v_, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C[vi, vj], A_shared[vi, vk], B[vk, vj]]) + T.writes([C[vi, vj]]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B[vk, vj] + + +@T.prim_func +def cuda_matmul_read_at_ab(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048], dtype="float32") + C = T.match_buffer(c, [2048, 2048], dtype="float32") + A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + for by in T.thread_binding(0, 32, thread="blockIdx.y"): + for bx in T.thread_binding(0, 32, thread="blockIdx.x"): + for vy in T.thread_binding(0, 2, thread="vthread.y"): + for vx in T.thread_binding(0, 2, thread="vthread.x"): + for ty in T.thread_binding(0, 8, 
thread="threadIdx.y"): + for tx in T.thread_binding(0, 8, thread="threadIdx.x"): + for k0 in T.serial(0, 256): + with T.block("A_shared"): + v0 = T.axis.S(32, by) + v1 = T.axis.S(256, k0) + T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 8): + A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] + with T.block("B_shared"): + v0 = T.axis.S(256, k0) + v1 = T.axis.S(32, bx) + T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(8, 64): + B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1] + for k1 in T.unroll(0, 8): + for v_, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C[vi, vj], A_shared[vi, vk], B_shared[vk, vj]]) + T.writes([C[vi, vj]]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj] + +@T.prim_func +def cuda_matmul_write_at_c(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048], dtype="float32") + C = T.match_buffer(c, [2048, 2048], dtype="float32") + A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + C_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + for by in T.thread_binding(0, 32, thread="blockIdx.y"): + for bx in T.thread_binding(0, 32, thread="blockIdx.x"): + for vy in T.thread_binding(0, 2, thread="vthread.y"): + for vx in T.thread_binding(0, 2, thread="vthread.x"): + for ty in T.thread_binding(0, 8, thread="threadIdx.y"): + for tx in T.thread_binding(0, 8, thread="threadIdx.x"): + for k0 in T.serial(0, 256): + with T.block("A_shared"): + v0 = T.axis.S(32, by) + v1 = T.axis.S(256, k0) + T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 8): + A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] + with T.block("B_shared"): + v0 = T.axis.S(256, k0) + v1 = T.axis.S(32, bx) + T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(8, 64): + B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1] + for k1 in T.unroll(0, 8): + for v_, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C_shared[vi, vj], A_shared[vi, vk], B_shared[vk, vj]]) + T.writes([C_shared[vi, vj]]) + with T.init(): + C_shared[vi, vj] = T.float32(0) + C_shared[vi, vj] = C_shared[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj] + with T.block("C_shared"): + v0 = T.axis.S(32, by) + v1 = T.axis.S(32, bx) + T.reads([C_shared[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]]) + T.writes([C[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 64): + C[v0 * 64 + ax0, v1 * 64 + ax1] = C_shared[v0 * 64 + ax0, v1 * 64 + 
ax1] + + +# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable +# fmt: on + + +def test_read_at_global_to_shared_a(): + sch = tir.Schedule(cuda_matmul, debug_mask="all") + block = sch.get_block("C") + # pylint: disable=invalid-name + _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block) + # pylint: enable=invalid-name + sch.read_at(k0, block, 1, "shared") + tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_a) + verify_trace_roundtrip(sch, cuda_matmul) + + +def test_read_at_global_to_shared_ab(): + sch = tir.Schedule(cuda_matmul_read_at_a, debug_mask="all") + block = sch.get_block("C") + # pylint: disable=invalid-name + _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block) + # pylint: enable=invalid-name + sch.read_at(k0, block, 2, "shared") + tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_ab) + verify_trace_roundtrip(sch, cuda_matmul_read_at_a) + + +def test_read_at_local_to_shared_c(): + sch = tir.Schedule(cuda_matmul_read_at_ab, debug_mask="all") + block = sch.get_block("C") + # pylint: disable=invalid-name + _by, _bx, _vy, _vx, _ty, tx, _k0, _k1, _, _i, _j = sch.get_loops(block) + # pylint: enable=invalid-name + sch.write_at(tx, block, 0, "shared") + tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_write_at_c) + verify_trace_roundtrip(sch, cuda_matmul_read_at_ab) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index c9a8f70ef7b3b..8de11d8bd519e 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -1049,5 +1049,41 @@ def func(A: T.Buffer(T.int64(58), "int32")): ) +def test_index_map_dtype_legalize_with_constant(): + """Legalization of inverse containing a constant output + + The index map `lambda i,j: [i, j//8, j % 8]` has an inverse `lambda i,j,k: [i, 8*j+k]`. + """ + + @T.prim_func + def func(A: T.Buffer(T.int64(16), "int32")): + for i in T.grid(T.int64(16)): + with T.block("block"): + vi = T.axis.remap("S", [i]) + A[vi] = 0 + + sch = tir.Schedule(func) + + # Triggering the error requires an IndexMap that introduces padding + func = lambda i: [ + # And a constant to be one of the output indices. + tir.const(0, i.dtype), + (i + 1) // 8, + (i + 1) % 8, + ] + + # Previously, the legalization was only handled by propagating the + # dtype of the indices to the transformed indices. As a result, + # output indices whose value did not depend on the input index + # would be left with the incorrect dtype. + + # Prior to the bugfix, this resulted in the following error is + # raised from the IterVar constructor. + # + # TVMError: Check failed: dom->extent.dtype() == var.dtype() (int64 vs. 
int32) : + # The dtype of the extent of an IterVar (int64) must match its associated Var's dtype (int32) + sch.transform_layout(block="block", buffer="A", index_map=func, pad_value=0) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_bf16_legalize.py b/tests/python/unittest/test_tir_transform_bf16_legalize.py index 1e3c8061e0294..ababfd489af56 100644 --- a/tests/python/unittest/test_tir_transform_bf16_legalize.py +++ b/tests/python/unittest/test_tir_transform_bf16_legalize.py @@ -15,164 +15,105 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import topi -from tvm import te - - -def lower_stmt(sche, params, passfunc): - func = tvm.driver.build_module.schedule_to_module(sche, params, "main", None)["main"] - func = passfunc()(tvm.IRModule.from_expr(func))["main"] - stmt = func.body - return stmt - - -def test_promote(): - def runpass(op, passfunc): - a = te.placeholder((100,), dtype="bfloat16") - b = te.placeholder((100,), dtype="bfloat16") - c = te.compute((100,), lambda i: op(a[i], b[i])) - s = te.create_schedule(c.op) - return lower_stmt(s, [a, b, c], passfunc) - - def get_promoted(op): - a = te.placeholder((100,), dtype="bfloat16") - b = te.placeholder((100,), dtype="bfloat16") - c = te.compute( - (100,), - lambda i: topi.cast(op(topi.cast(a[i], "float"), topi.cast(b[i], "float")), "bfloat16"), - ) - s = te.create_schedule(c.op) - func = tvm.driver.build_module.schedule_to_module(s, [a, b, c], "main", None)["main"] - return func.body - - def test_promoted(op): - stmt = runpass(op, tvm.tir.transform.BF16Promote) - tvm.ir.assert_structural_equal(stmt, get_promoted(op)) - - test_promoted(topi.add) - test_promoted(topi.subtract) - test_promoted(topi.multiply) - test_promoted(topi.divide) - - -def test_eliminate(): - def to32(v): - return topi.cast(v, "float") - - def to16(v): - return topi.cast(v, "bfloat16") - - def get_eliminated(): - a = te.placeholder((100,), dtype="bfloat16") - b = te.placeholder((100,), dtype="bfloat16") - c = te.compute( - (100,), - lambda i: to16( - topi.add( - to32( - to16( - topi.add( - to32(a[i]), - to32(b[i]), - ) - ) - ), - to32( - to16( - topi.add( - to32(a[i]), - to32(b[i]), - ) - ) - ), - ) - ), - ) - s = te.create_schedule(c.op) - stmt = lower_stmt(s, [a, b, c], tvm.tir.transform.BF16CastElimination) - return stmt - - def get_target(): - a = te.placeholder((100,), dtype="bfloat16") - b = te.placeholder((100,), dtype="bfloat16") - c = te.compute( - (100,), - lambda i: to16( - topi.add( - topi.add( - to32(a[i]), - to32(b[i]), - ), - topi.add( - to32(a[i]), - to32(b[i]), - ), - ) - ), - ) - s = te.create_schedule(c.op) - func = tvm.driver.build_module.schedule_to_module(s, [a, b, c], "main", None)["main"] - return func.body - - tvm.ir.assert_structural_equal(get_eliminated(), get_target()) - - -def test_legalize(): - def to32(v): - uint32_v = topi.cast(v, "uint32") - uint32_v = tvm.tir.call_intrin( - "uint32", "tir.shift_left", uint32_v, tvm.tir.const(16, "uint32") - ) - return tvm.tir.call_intrin("float32", "tir.reinterpret", uint32_v) - - def to16(v): - uint32_v = tvm.tir.call_intrin("uint32", "tir.reinterpret", v) - rounding_bias = tvm.tir.call_intrin( - "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32") - ) - rounding_bias = tvm.tir.call_intrin( - "uint32", "tir.bitwise_and", rounding_bias, tvm.tir.const(1, "uint32") - ) - rounding_bias = rounding_bias + tvm.tir.const(0x7FFF, "uint16") - uint32_v = uint32_v + rounding_bias - uint32_v = 
tvm.tir.call_intrin( - "uint32", "tir.shift_right", uint32_v, tvm.tir.const(16, "uint32") - ) - return topi.cast(uint32_v, "uint16") - - def check(fcompute_before, fcompute_after): - a = te.placeholder((100,), dtype="bfloat16", name="A") - b = te.placeholder((100,), dtype="bfloat16", name="B") - c = te.compute((100,), fcompute_before(a, b), name="C") - s = te.create_schedule(c.op) - stmt = lower_stmt(s, [a, b, c], tvm.tir.transform.BF16Legalize) - - a = te.placeholder((100,), dtype="uint16", name="A") - b = te.placeholder((100,), dtype="uint16", name="B") - c = te.compute((100,), fcompute_after(a, b), name="C") - s = te.create_schedule(c.op) - func = tvm.driver.build_module.schedule_to_module(s, [a, b, c], "main", None)["main"] - tvm.ir.assert_structural_equal(stmt, func.body) - - def orig1(a, b): - return lambda i: a[i] + b[i] + a[99 - i] + b[99 - i] - - def after1(a, b): - return lambda i: to16(to32(a[i]) + to32(b[i]) + to32(a[99 - i]) + to32(b[99 - i])) - - def orig2(a, b): - return lambda i: a[i] * b[i] + a[99 - i] * b[99 - i] + a[i] - - def after2(a, b): - return lambda i: to16( - to32(a[i]) * to32(b[i]) + to32(a[99 - i]) * to32(b[99 - i]) + to32(a[i]) - ) - - check(orig1, after1) - check(orig2, after2) +import tvm.script +from tvm.script import tir as T + + +def get_before(): + @tvm.script.ir_module + class Before: + @T.prim_func + def main( + Aptr: T.handle("bfloat16"), Bptr: T.handle("bfloat16"), Dptr: T.handle("bfloat16") + ): + T.func_attr({"global_symbol": "main"}) + A = T.decl_buffer((100,), "bfloat16", data=Aptr) + B = T.decl_buffer((100,), "bfloat16", data=Bptr) + D = T.decl_buffer((100,), "bfloat16", data=Dptr) + C = T.decl_buffer((100,), "bfloat16") + for i in T.grid(100): + C[i] = A[i] + B[i] + D[i] = T.exp(C[i]) + + return Before + + +def u16tof32(v): + uint32_v = v.astype("uint32") + uint32_v = uint32_v << tvm.tir.const(16, "uint32") + return T.reinterpret("float32", uint32_v) + + +def bf16tof32(v): + return u16tof32(T.reinterpret("uint16", v)) + + +def f32tou16(v): + uint32_v = T.reinterpret("uint32", v) + rounding_bias = (uint32_v >> tvm.tir.const(16, "uint32")) & tvm.tir.const(1, "uint32") + rounding_bias += tvm.tir.const(0x7FFF, "uint32") + uint32_v = uint32_v + rounding_bias + return uint32_v >> tvm.tir.const(16, "uint32") + + +def f32tobf16(v): + uint32_v = f32tou16(v) + return T.reinterpret("bfloat16", uint32_v.astype("uint16")) + + +def get_after_compute_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func + def main( + Aptr: T.handle("bfloat16"), Bptr: T.handle("bfloat16"), Dptr: T.handle("bfloat16") + ): + T.func_attr({"global_symbol": "main"}) + A = T.decl_buffer((100,), "bfloat16", data=Aptr) + B = T.decl_buffer((100,), "bfloat16", data=Bptr) + D = T.decl_buffer((100,), "bfloat16", data=Dptr) + C = T.decl_buffer((100,), "float32") + for i in T.grid(100): + C[i] = bf16tof32(A[i]) + bf16tof32(B[i]) + D[i] = f32tobf16(T.exp(C[i])) + + return After + + +def get_after_storage_legalize(): + @tvm.script.ir_module + class After: + @T.prim_func + def main(Aptr: T.handle("uint16"), Bptr: T.handle("uint16"), Dptr: T.handle("uint16")): + T.func_attr({"global_symbol": "main"}) + A = T.decl_buffer((100,), "uint16", data=Aptr) + B = T.decl_buffer((100,), "uint16", data=Bptr) + D = T.decl_buffer((100,), "uint16", data=Dptr) + C = T.decl_buffer((100,), "float32") + for i in T.grid(100): + C[i] = u16tof32(A[i]) + u16tof32(B[i]) + D[i] = f32tou16(T.exp(C[i])) + + return After + + +def test_bf16_compute_legalize(): + before = get_before() + expected = 
get_after_compute_legalize() + # run the transform twice to ensure the pass is idempotent + # and tolerates repeated application + after = tvm.tir.transform.BF16ComputeLegalize()(before) + after = tvm.tir.transform.BF16ComputeLegalize()(after) + + tvm.ir.assert_structural_equal(after, expected) + + +def test_bf16_storage_legalize(): + before = get_after_compute_legalize() + after = tvm.tir.transform.BF16StorageLegalize()(before) + expected = get_after_storage_legalize() + tvm.ir.assert_structural_equal(after, expected) if __name__ == "__main__": - test_promote() - test_eliminate() - test_legalize() + test_bf16_compute_legalize() + test_bf16_storage_legalize()
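
The rewritten bf16 tests feed the output of BF16ComputeLegalize into BF16StorageLegalize, so the two passes are meant to run in that order during lowering. A minimal sketch of that composition (an editor's illustration, not part of this patch; the helper name legalize_bf16 is hypothetical):

import tvm


def legalize_bf16(mod: tvm.IRModule) -> tvm.IRModule:
    # Compute legalization first: bf16 arithmetic is widened to fp32 with
    # casts around each op. Storage legalization second: the remaining bf16
    # buffers and handles are rewritten to their uint16 bit patterns.
    seq = tvm.transform.Sequential(
        [
            tvm.tir.transform.BF16ComputeLegalize(),
            tvm.tir.transform.BF16StorageLegalize(),
        ]
    )
    return seq(mod)

Applied to the module built by get_before() in the test, this should yield a module structurally equal to the one produced by get_after_storage_legalize().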
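
The f32tou16/u16tof32 helpers above encode bf16 round-to-nearest-even with plain integer arithmetic. A NumPy sketch of the same bit manipulation, handy for computing host-side reference values (editor's illustration; the function names are hypothetical):

import numpy as np


def np_fp32_to_bf16_bits(x: np.ndarray) -> np.ndarray:
    # Round-to-nearest-even: add 0x7FFF plus the LSB of the would-be result,
    # then keep the upper 16 bits of the fp32 bit pattern.
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    bias = ((bits >> 16) & np.uint32(1)) + np.uint32(0x7FFF)
    return ((bits + bias) >> 16).astype(np.uint16)


def np_bf16_bits_to_fp32(u16: np.ndarray) -> np.ndarray:
    # Inverse direction: place the 16 bf16 bits in the high half of a uint32
    # and reinterpret as fp32 (the low mantissa bits are zero-filled).
    return (u16.astype(np.uint32) << np.uint32(16)).view(np.float32)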