
Commit 3734d5f

[CUDA][PASS] Legalize tensorcore (apache#7147)
* add pad_to_tensorcore & legalize for dense/bmm/conv2d
* fix pad & slice
* fix comments
* fix comments
* resolve conflict
* resolve conflict
* support only fp16
* add tests/python/relay/test_pass_legalize_tensorcore.py
* add tests for legalize tensorcore
* fix pylint
* fix pylint
* code format
* use_gpu test only; fix conv2d_alter_op
* fix tests params
* revert transform fix
1 parent 44a071a commit 3734d5f
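
The new hooks fire through Relay's Legalize pass. Below is a minimal sketch (not part of the commit) of how they would be exercised, assuming a CUDA build of TVM at this commit; the shapes are illustrative, chosen so that none of the tensorcore divisibility patterns hold:

```python
import tvm
from tvm import relay

# fp16 dense with M=11, K=20, N=30: M misses the 8/16/32 multiples,
# so the CUDA legalization registered below should pad.
x = relay.var("x", shape=(11, 20), dtype="float16")
w = relay.var("w", shape=(30, 20), dtype="float16")
mod = tvm.IRModule.from_expr(relay.nn.dense(x, w))

seq = tvm.transform.Sequential(
    [relay.transform.InferType(), relay.transform.Legalize()]
)
with tvm.target.Target("cuda"):
    # expect pad -> nn.dense on (16, 32) x (32, 32) -> strided_slice
    # back to the original (11, 30) output
    print(seq(mod))
```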

File tree

7 files changed (+582 -0 lines changed)

python/tvm/relay/op/nn/_nn.py

Lines changed: 42 additions & 0 deletions

@@ -52,6 +52,27 @@
 reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)


+@reg.register_legalize("nn.dense")
+def legalize_dense(attrs, inputs, types):
+    """Legalize dense op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current dense
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    return topi.nn.dense_legalize(attrs, inputs, types)
+
+
 # dense
 reg.register_strategy("nn.dense", strategy.dense_strategy)
 reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)

@@ -67,6 +88,27 @@ def compute_fifo_buffer(attrs, inputs, out_type):
 reg.register_pattern("nn.fifo_buffer", OpPattern.OPAQUE)


+@reg.register_legalize("nn.batch_matmul")
+def legalize_batch_matmul(attrs, inputs, types):
+    """Legalize batch_matmul op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current batch_matmul
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    return topi.nn.batch_matmul_legalize(attrs, inputs, types)
+
+
 # batch_matmul
 reg.register_strategy("nn.batch_matmul", strategy.batch_matmul_strategy)
 reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)

python/tvm/topi/cuda/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -55,5 +55,6 @@
 from .conv2d_hwnc_tensorcore import *
 from .correlation import *
 from .sparse import *
+from . import tensorcore_alter_op
 from .argwhere import *
 from .scan import *

python/tvm/topi/cuda/conv2d_alter_op.py

Lines changed: 48 additions & 0 deletions

@@ -24,8 +24,10 @@
 from .. import nn
 from ..utils import get_const_tuple
 from .conv2d_winograd import _infer_tile_size
+from .tensorcore_alter_op import pad_to_tensorcore
 from ..nn import conv2d_legalize

+
 logger = logging.getLogger("topi")


@@ -345,4 +347,50 @@ def _conv2d_legalize(attrs, inputs, arg_types):
         else:
             out = relay.nn.conv2d(data, kernel, **new_attrs)
         return out
+    elif data_dtype in ["float16"]:  # todo: support int8/int4
+        if data_layout == "NHWC" and kernel_layout == "HWIO":
+            batch = data_tensor.shape[0].value
+            in_channel = data_tensor.shape[3].value
+            out_channel = kernel_tensor.shape[3].value
+
+            if (
+                (batch % 8 == 0 and in_channel % 16 == 0 and out_channel % 32 == 0)
+                or (batch % 16 == 0 and in_channel % 16 == 0 and out_channel % 16 == 0)
+                or (batch % 32 == 0 and in_channel % 16 == 0 and out_channel % 8 == 0)
+            ):
+                # no need to pad
+                return None
+
+            (db, di, do), extra_flops = pad_to_tensorcore(batch, in_channel, out_channel)
+
+            if extra_flops > 2:
+                logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops)
+                return None
+
+            logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops)
+
+            # Pad batch size
+            if db != 0:
+                data = relay.nn.pad(data, pad_width=((0, db), (0, 0), (0, 0), (0, 0)))
+
+            # Pad input channel
+            if di != 0:
+                data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, 0), (0, di)))
+                kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, di), (0, 0)))
+
+            # Pad output channel
+            if do != 0:
+                kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, do)))
+
+            if do != 0:
+                new_out_channel = out_channel + do
+                new_attrs["channels"] = new_out_channel
+
+            out = relay.nn.conv2d(data, kernel, **new_attrs)
+
+            if db != 0 or do != 0:
+                original_out_shape = [x.value for x in output_tensor.shape]
+                out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape)
+
+            return out
     return None
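
For intuition, here is a worked example of the decision above, using the helper added in this commit with hypothetical shapes:

```python
# Hypothetical NHWC/HWIO fp16 conv2d with batch=7, in_channel=16, out_channel=32.
# 7 is not a multiple of 8, 16, or 32, so none of the divisibility patterns hold.
from tvm.topi.cuda.tensorcore_alter_op import pad_to_tensorcore

(db, di, do), extra_flops = pad_to_tensorcore(7, 16, 32)
# The (8, 16, 32) candidate wins: pad the batch 7 -> 8, leave channels alone.
assert (db, di, do) == (1, 0, 0)
# Relative extra work is (8 - 7) / 7 ~= 0.14, well under the 2x cutoff, so the
# legalizer pads the data tensor and slices the output back to batch 7.
assert abs(extra_flops - 1 / 7) < 1e-9
```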
python/tvm/topi/cuda/tensorcore_alter_op.py (new file)

Lines changed: 204 additions & 0 deletions

@@ -0,0 +1,204 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument
"""Tensorcore alter op and legalize functions for cuda backend"""

import logging
import math
from tvm import relay

from .. import nn

logger = logging.getLogger("topi")


@nn.batch_matmul_legalize.register("cuda")
def _batch_matmul_legalize(attrs, inputs, arg_types):
    """Legalizes batch_matmul op.

    Parameters
    ----------
    attrs : tvm.ir.Attrs
        Attributes of current batch_matmul
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    arg_types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    # Collect the input tensors.
    x_tensor, y_tensor = arg_types[0], arg_types[1]
    dtype = x_tensor.dtype

    # Collect the output tensor.
    output_tensor = arg_types[2]

    # Collect the input exprs.
    x, y = inputs

    # Pad input and output channels to use tensorcore schedule.
    if dtype in ["float16"]:  # todo: support int8/int4
        B, M, K = x_tensor.shape
        B, N, K = y_tensor.shape
        M = M.value
        K = K.value
        N = N.value

        # The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)
        if (
            (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
            or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
            or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
        ):
            # no need to pad
            return None

        (dm, dk, dn), extra_flops = pad_to_tensorcore(M, K, N)

        if extra_flops > 2:
            logger.info("batch_matmul pad_to_tensorcore skipped, extra_flops %s", extra_flops)
            return None

        logger.info("batch_matmul pad_to_tensorcore, extra_flops %s", extra_flops)
        if dm or dk:
            x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk)))
        else:
            x_ = x
        if dn or dk:
            y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk)))
        else:
            y_ = y
        out_ = relay.nn.batch_matmul(x_, y_)
        if dm or dn:
            original_out_shape = [x.value for x in output_tensor.shape]
            out = relay.strided_slice(out_, begin=[0, 0, 0], end=original_out_shape)
        else:
            out = out_
        return out
    return None


@nn.dense_legalize.register("cuda")
def _dense_legalize(attrs, inputs, arg_types):
    """Legalizes dense op.

    Parameters
    ----------
    attrs : tvm.ir.Attrs
        Attributes of current dense
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    arg_types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    # Collect the input tensors.
    x_tensor, y_tensor = arg_types[0], arg_types[1]
    dtype = x_tensor.dtype

    # Collect the output tensor.
    output_tensor = arg_types[2]

    # Collect the input exprs.
    x, y = inputs

    # Pad input and output channels to use tensorcore schedule.
    if dtype in ["float16"]:  # todo: support int8/int4
        M, K = x_tensor.shape
        N, K = y_tensor.shape
        try:
            M = M.value
            K = K.value
            N = N.value
        except AttributeError:
            # todo: deal with unfixed shape when compiling wdl model
            return None

        # The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)
        if (
            (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
            or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
            or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
        ):
            # no need to pad
            return None

        (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N)

        if extra_flops_ratio > 2:
            logger.info("dense pad_to_tensorcore skipped, extra_flops_ratio %s", extra_flops_ratio)
            return None

        logger.info("dense pad_to_tensorcore, extra_flops_ratio %s", extra_flops_ratio)

        if dm or dk:
            x_ = relay.nn.pad(x, pad_width=((0, dm), (0, dk)))
        else:
            x_ = x
        if dn or dk:
            y_ = relay.nn.pad(y, pad_width=((0, dn), (0, dk)))
        else:
            y_ = y
        out_ = relay.nn.dense(x_, y_)
        if dm or dn:
            original_out_shape = [x.value for x in output_tensor.shape]
            out = relay.strided_slice(out_, begin=[0, 0], end=original_out_shape)
        else:
            out = out_
        return out
    return None


def pad_to_tensorcore(M, K, N):
    """Pad (M, K, N) to the tensorcore candidate costing the fewest extra FLOPs."""
    candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]

    flops = M * K * N
    extra_flops = math.inf
    best_pad = (0, 0, 0)
    for padding in candidates:
        dm, dk, dn = _pad_to(M, K, N, padding)
        e = (M + dm) * (N + dn) * (K + dk) - M * N * K
        if e < extra_flops:
            extra_flops = e
            best_pad = (dm, dk, dn)
    return best_pad, extra_flops / flops


def _pad_to(M, K, N, PADDING):
    dm, dk, dn = 0, 0, 0

    if M % PADDING[0] != 0:
        M_ = ((M + PADDING[0]) // PADDING[0]) * PADDING[0]
        dm = M_ - M
    if K % PADDING[1] != 0:
        K_ = ((K + PADDING[1]) // PADDING[1]) * PADDING[1]
        dk = K_ - K
    if N % PADDING[2] != 0:
        N_ = ((N + PADDING[2]) // PADDING[2]) * PADDING[2]
        dn = N_ - N

    return dm, dk, dn
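
The 2x ratio cutoff used by the callers above is easiest to see on a degenerate shape; an illustrative check, not part of the commit:

```python
from tvm.topi.cuda.tensorcore_alter_op import pad_to_tensorcore

# A batch-1 fp16 dense, (M, K, N) = (1, 16, 16), would need M padded 1 -> 16
# under the best candidate, i.e. 15x extra work, so legalization is skipped.
(dm, dk, dn), ratio = pad_to_tensorcore(1, 16, 16)
assert (dm, dk, dn) == (15, 0, 0)
assert ratio == 15.0  # (16*16*16 - 1*16*16) / (1*16*16) > 2 -> caller returns None
```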

python/tvm/topi/nn/batch_matmul.py

Lines changed: 24 additions & 0 deletions

@@ -16,6 +16,7 @@
 # under the License.
 """Batch matrix multiplication"""
 # pylint: disable=invalid-name
+import tvm
 from tvm import te, auto_scheduler
 from ..utils import get_const_tuple

@@ -77,3 +78,26 @@ def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""):
         output = auto_scheduler.rewrite_compute_body(output, auto_scheduler_rewritten_layout)

     return output
+
+
+@tvm.target.generic_func
+def batch_matmul_legalize(attrs, inputs, types):
+    """Legalizes batch_matmul op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current batch_matmul
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    # not to change by default
+    # pylint: disable=unused-argument
+    return None
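
As with dense, the CUDA override registered against this generic function pads and then slices. A sketch of the resulting shapes for one awkward case (illustrative, not from the commit):

```python
import tvm
from tvm import relay

# (B, M, K) x (B, N, K) = (4, 11, 20) x (4, 30, 20): M, K, N all miss the
# tensorcore multiples, and pad_to_tensorcore picks (dm, dk, dn) = (5, 12, 2).
x = relay.var("x", shape=(4, 11, 20), dtype="float16")
y = relay.var("y", shape=(4, 30, 20), dtype="float16")
mod = tvm.IRModule.from_expr(relay.nn.batch_matmul(x, y))

seq = tvm.transform.Sequential(
    [relay.transform.InferType(), relay.transform.Legalize()]
)
with tvm.target.Target("cuda"):
    # expect pads to (4, 16, 32) x (4, 32, 32), a batch_matmul, then a
    # strided_slice back to the original (4, 11, 30) output
    print(seq(mod))
```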

python/tvm/topi/nn/dense.py

Lines changed: 24 additions & 0 deletions

@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """TVM operator fully connected compute."""
+import tvm
 from tvm import te, auto_scheduler
 from .. import tag

@@ -80,3 +81,26 @@ def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layo
         matmul = auto_scheduler.rewrite_compute_body(matmul, auto_scheduler_rewritten_layout)

     return matmul
+
+
+@tvm.target.generic_func
+def dense_legalize(attrs, inputs, types):
+    """Legalizes dense op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current dense
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    # not to change by default
+    # pylint: disable=unused-argument
+    return None
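
The generic function is the per-target extension point: the body above is the default (leave the op unchanged), and a backend opts in by registering its key, as the CUDA file in this commit does. A hypothetical override for another target would look like:

```python
import tvm
from tvm.topi import nn

# Hypothetical: attach a rule for another target key. Returning None means
# "leave the op as-is"; returning a relay.Expr replaces it, exactly as the
# "cuda" registration in tensorcore_alter_op.py does.
@nn.dense_legalize.register("rocm")
def _dense_legalize_rocm(attrs, inputs, types):
    return None
```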
