Add dot product support for quantized convolution. #6445

Merged · 10 commits · Oct 6, 2020
44 changes: 30 additions & 14 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -135,20 +135,29 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
name="conv2d_direct_simd.micro_dev",
)
elif kernel_layout == "HWIO":
is_aarch64 = "aarch64" in str(isa.target)

is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm()
has_dot_prod = topi.arm_cpu.arm_utils.is_dotprod_available()
if has_dot_prod and data.dtype in ["int8", "uint8"]:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_native),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
name="conv2d_NHWC_quantized_native.arm_cpu",
)
if is_aarch64 and data.dtype in ["int8", "uint8"]:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized),
name="conv2d_NHWC_quantized.arm_cpu",
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved),
name="conv2d_NHWC_quantized_interleaved.arm_cpu",
)
if (not is_aarch64) or (data.dtype not in ["int8", "uint8"]):
# TODO(@giuseros)
# This strategy errors out for quantized data types when tuning.
# Let's use this only for non-aarch64 or non-quantized cases
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack),
name="conv2d_nhwc_spatial_pack.arm_cpu",
)

strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack),
name="conv2d_nhwc_spatial_pack.arm_cpu",
)
else:
raise RuntimeError(
"Unsupported kernel layout {} for conv2d NHWC".format(kernel_layout)
@@ -328,11 +337,18 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
     data = inputs[0]
     strategy = _op.OpStrategy()
 
+    interleaved_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved_without_transform
+    native_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_native_without_transform
     if layout == "NHWC" and data.dtype in ["int8", "uint8"]:
         strategy.add_implementation(
-            wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_quantized_without_transform),
-            wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized),
-            name="conv2d_NHWC_quantized_without_transform.arm_cpu",
+            wrap_compute_conv2d_gemm(native_compute),
+            wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
+            name="conv2d_NHWC_quantized_native_without_transform.arm_cpu",
         )
+        strategy.add_implementation(
+            wrap_compute_conv2d_gemm(interleaved_compute),
+            wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved),
+            name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
+        )
     else:
         raise RuntimeError(
100 changes: 100 additions & 0 deletions python/tvm/topi/arm_cpu/arm_utils.py
@@ -0,0 +1,100 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Arm target utility functions"""

import re
import tvm


def get_arch_version(target_mattr):
    """Parse the LLVM target -mattr and return the architecture
    version as a decimal number (e.g., for -mattr=+v8.4a, return 8.4).
    """

    arch_version = 8.0
    m = re.compile(r"\+v(.*)\.(.*)a")
    for attr in target_mattr:
        match_obj = m.match(attr)
        if match_obj:
            major = int(match_obj.group(1))
            minor = int(match_obj.group(2))
            decimal = 10
            if minor >= 10:
                decimal = 100
            arch_version = major + float(minor) / decimal

    return arch_version
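
Since get_arch_version is a pure function of the -mattr list, its behaviour can be pinned down directly (illustrative values, not part of the diff; the attribute spellings follow the +v<major>.<minor>a pattern the regex expects):

from tvm.topi.arm_cpu.arm_utils import get_arch_version

assert get_arch_version(["+neon"]) == 8.0                 # no version attribute: default 8.0
assert get_arch_version(["+v8.2a", "+dotprod"]) == 8.2    # "+v8.2a" matches the pattern
assert get_arch_version(["+v8.4a"]) == 8.4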


def is_dotprod_available():
    """Checks whether the target supports the Arm dot-product instructions
    (optional from Armv8.2-A via +dotprod, mandatory from Armv8.4-A) used
    for fast int8 arithmetic."""
    target = tvm.target.Target.current(allow_none=False)
    arch_version = get_arch_version(target.mattr)
    return arch_version >= 8.4 or ((arch_version in (8.2, 8.3)) and "+dotprod" in target.mattr)


def is_aarch64_arm():
    """ Checks whether we are compiling for an AArch64 target. """
    target = tvm.target.Target.current(allow_none=False)
    return "aarch64" in target.attrs.get("mtriple", "")


def get_tiling_B_interleaved_t(interleave_A):
    """Compute the tiling information for matrix B', where B'
    is the transposed and interleaved version of matrix B in C=A*B.

    The tiling information is chosen to maximize register usage during the
    tile computation.

    Please refer to:
    - https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product
    - Conv2DGemmWeightTransformRel in src/relay/op/nn/convolution.h
    for more information.

    Parameters
    ----------
    interleave_A : bool
        Determines whether A is expected to be interleaved

    Returns
    ----------
    tile_rows_B: the output tile rows of B'
    tile_cols_B: the output tile columns of B'
    """
    if is_dotprod_available():
        # The number of tile rows of B' varies depending on the
        # strategy:
        # * If we are interleaving A, then we select 12 columns from B' (i.e.,
        #   12 rows from B).
        # * If we are not interleaving A, then we select 16 columns from B' (i.e.,
        #   16 rows from B).
        tile_rows_B = 12 if interleave_A else 16

        # The dot-product instructions (sdot/udot) multiply and accumulate
        # groups of 4 (u)int8 elements, so the number of columns in a tile
        # of B' (i.e., the rows of the original matrix B) needs to be 4.
        tile_cols_B = 4
    else:
        # If dot product is not available, A must be interleaved. In this case
        # we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements
        tile_rows_B = 4
        tile_cols_B = 16

    return tile_rows_B, tile_cols_B
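
The three possible tilings, shown under the same illustrative targets as above (a sketch, not part of the diff):

import tvm
from tvm.topi.arm_cpu.arm_utils import get_tiling_B_interleaved_t

with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod"):
    assert get_tiling_B_interleaved_t(interleave_A=True) == (12, 4)   # interleaved A + dot product
    assert get_tiling_B_interleaved_t(interleave_A=False) == (16, 4)  # native A + dot product

with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu"):
    assert get_tiling_B_interleaved_t(interleave_A=True) == (4, 16)   # no dot product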
105 changes: 75 additions & 30 deletions python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -27,10 +27,64 @@
 from ..nn import conv2d_alter_layout
 from ..util import get_const_tuple
 from ..x86.conv2d import _get_default_config as _get_x86_default_config
+from .arm_utils import get_tiling_B_interleaved_t
 
 logger = logging.getLogger("topi")
 
 
+def interleave_transpose_weights(inputs, data, kernel, interleave_A):
+    """Transform the weight matrix by reshaping, interleaving and transposing it
+
+    Parameters
+    ----------
+    inputs : tvm.relay.Expr
+        Grouped input symbols
+    data :
+        Input shape and dtype
+    kernel :
+        Input shape and dtype
+    interleave_A : bool
+        Indicates whether matrix A is expected to be interleaved
+
+    Returns
+    ----------
+    new_kernel : tvm.te.placeholder
+        A placeholder with the new shape
+    new_kernel_expr : tvm.relay.Expr
+        The relay expression of the weights
+    """
+    assert (
+        data.dtype == "int8"
+        and kernel.dtype == "int8"
+        or data.dtype == "uint8"
+        and kernel.dtype == "uint8"
+    )
+
+    KH, KW, IC, OC = get_const_tuple(kernel.shape)
+    K = KH * KW * IC
+    N = OC
+
+    # Get tiling information for the interleaved transposed version of B
+    tile_rows_B, tile_cols_B = get_tiling_B_interleaved_t(interleave_A)
+
+    pad_K = 0
+    pad_N = 0
+
+    if N % tile_rows_B != 0:
+        pad_N = tile_rows_B - (N % tile_rows_B)
+    if K % tile_cols_B != 0:
+        pad_K = tile_cols_B - (K % tile_cols_B)
+
+    N_padded = N + pad_N
+    K_padded = K + pad_K
+    new_kernel_expr = relay.nn.contrib_conv2d_gemm_weight_transform(
+        inputs[1], tile_rows_B, tile_cols_B
+    )
+    new_kernel = te.placeholder(
+        (N_padded // tile_rows_B, K_padded // tile_cols_B, tile_rows_B, tile_cols_B), kernel.dtype
+    )
+    return new_kernel, new_kernel_expr
+
+
 @conv2d_alter_layout.register(["arm_cpu"])
 def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
     target = tvm.target.Target.current(allow_none=False)
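
To make the padding arithmetic in interleave_transpose_weights concrete, here is a worked example with hypothetical shapes: an HWIO kernel of (3, 3, 32, 64) under the interleaved strategy on a dot-product target, where the tile shape of B' is 12x4:

# Hypothetical shapes; the arithmetic mirrors interleave_transpose_weights above.
KH, KW, IC, OC = 3, 3, 32, 64
K, N = KH * KW * IC, OC                    # K = 288, N = 64
tile_rows_B, tile_cols_B = 12, 4           # dot product available, A interleaved
pad_N = (tile_rows_B - N % tile_rows_B) % tile_rows_B   # 64 % 12 = 4  -> pad_N = 8
pad_K = (tile_cols_B - K % tile_cols_B) % tile_cols_B   # 288 % 4 = 0  -> pad_K = 0
# Transformed-weight placeholder shape: (N' // 12, K' // 4, 12, 4)
assert ((N + pad_N) // tile_rows_B, (K + pad_K) // tile_cols_B, tile_rows_B, tile_cols_B) == (6, 72, 12, 4)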
@@ -279,44 +333,35 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         )
         dispatch_ctx.update(target, new_workload, cfg)
         return relay.nn.contrib_depthwise_conv2d_nchwc(*inputs, **new_attrs)
-    if topi_tmpl == "conv2d_NHWC_quantized.arm_cpu":
-        assert (
-            data.dtype == "int8"
-            and kernel.dtype == "int8"
-            or data.dtype == "uint8"
-            and kernel.dtype == "uint8"
-        )
+    if topi_tmpl == "conv2d_NHWC_quantized_interleaved.arm_cpu":
         assert data_layout == "NHWC" and kernel_layout == "HWIO"
-        KH, KW, IC, OC = get_const_tuple(kernel.shape)
-        K = KH * KW * IC
-        N = OC
-
-        tile_rows = 4
-        tile_cols = 16
-        pad_K = 0
-        pad_N = 0
-
-        if N % tile_rows != 0:
-            pad_N = tile_rows - (N % tile_rows)
-        if K % tile_cols != 0:
-            pad_K = tile_cols - (K % tile_cols)
-
-        N_padded = N + pad_N
-        K_padded = K + pad_K
-        kernel_expr = relay.nn.contrib_conv2d_gemm_weight_transform(inputs[1], tile_rows, tile_cols)
-        new_kernel = te.placeholder(
-            (N_padded // tile_rows, K_padded // tile_cols, tile_rows, tile_cols), kernel.dtype
+        KH, KW, _, OC = get_const_tuple(kernel.shape)
+        new_workload_name = "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu"
+        new_kernel, new_kernel_expr = interleave_transpose_weights(
+            inputs, data, kernel, interleave_A=True
         )
-
-        new_workload_name = "conv2d_NHWC_quantized_without_transform.arm_cpu"
         new_workload = autotvm.task.args_to_workload(
             [data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC],
             new_workload_name,
         )
         dispatch_ctx.update(target, new_workload, cfg)
-
         return relay.nn.contrib_conv2d_gemm_without_weight_transform(
-            inputs[0], kernel_expr, **new_attrs
+            inputs[0], new_kernel_expr, **new_attrs
         )
+    if topi_tmpl == "conv2d_NHWC_quantized_native.arm_cpu":
+        assert data_layout == "NHWC" and kernel_layout == "HWIO"
+        KH, KW, _, OC = get_const_tuple(kernel.shape)
+        new_workload_name = "conv2d_NHWC_quantized_native_without_transform.arm_cpu"
+        new_kernel, new_kernel_expr = interleave_transpose_weights(
+            inputs, data, kernel, interleave_A=False
+        )
+        new_workload = autotvm.task.args_to_workload(
+            [data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC],
+            new_workload_name,
+        )
+        dispatch_ctx.update(target, new_workload, cfg)
+        return relay.nn.contrib_conv2d_gemm_without_weight_transform(
+            inputs[0], new_kernel_expr, **new_attrs
+        )
 
     return None
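
At the Relay level, both branches perform the same rewrite; sketched below for the native strategy (hypothetical pseudocode, not from the diff; the tiling constants 16 and 4 come from get_tiling_B_interleaved_t(False)):

# Before:  %y = nn.conv2d(%x, %w)                       # %w in HWIO, int8/uint8
# After:   %w_t = nn.contrib_conv2d_gemm_weight_transform(%w, 16, 4)
#          %y   = nn.contrib_conv2d_gemm_without_weight_transform(%x, %w_t, ...)
#
# The AutoTVM workload is re-registered under
# "conv2d_NHWC_quantized_native_without_transform.arm_cpu", so tuning records
# refer to the convolution with the weight transform already applied (and the
# transform itself can be folded at compile time since %w is a constant).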