
Commit 19f105f

anijain2305 authored and kevinthesun committed
[TOPI][x86] Legalize - Support int8xint8 convolution to use VNNI instructions. (apache#4196)
1 parent 1df6c67 commit 19f105f
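
For context on the title: the VNNI instruction in question, vpdpbusd, accumulates a dot product of unsigned 8-bit data with signed 8-bit weights into a 32-bit lane, so a pure int8 x int8 convolution cannot feed it directly; the legalization in this commit shifts the data to uint8 first. A rough NumPy model of a single lane, purely illustrative and not part of the commit:

# Rough model of one 32-bit lane of AVX512-VNNI's vpdpbusd: four unsigned
# 8-bit data values times four signed 8-bit weight values, summed into an
# int32 accumulator. Illustrative only; not part of this commit.
import numpy as np

def vpdpbusd_lane(acc, data_u8, weight_s8):
    # data_u8: 4 uint8 values, weight_s8: 4 int8 values, acc: int32 scalar.
    products = data_u8.astype(np.int32) * weight_s8.astype(np.int32)
    return acc + int(products.sum())

acc = 0
data = np.array([1, 2, 3, 4], dtype=np.uint8)       # unsigned operand
weight = np.array([-1, 5, -7, 2], dtype=np.int8)     # signed operand
print(vpdpbusd_lane(acc, data, weight))              # -> -4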

File tree

2 files changed: +97 −40 lines changed


tests/python/relay/test_op_level2.py

Lines changed: 37 additions & 28 deletions
@@ -546,9 +546,11 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):

         n, h, w, ch, cw = 1, 64, 64, 3, 3
         if data_layout == 'NCHW':
-            x = relay.var("x", relay.TensorType((n, ic, h, w), input_dtype))
+            data_shape = (n, ic, h, w)
+            x = relay.var("x", relay.TensorType(data_shape, input_dtype))
         elif data_layout == 'NHWC':
-            x = relay.var("x", relay.TensorType((n, h, w, ic), input_dtype))
+            data_shape = (n, h, w, ic)
+            x = relay.var("x", relay.TensorType(data_shape, input_dtype))
         else:
             raise ValueError('Not supported')

@@ -559,20 +561,22 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):
         else:
             raise ValueError('Not supported')

-        w = relay.var("w", relay.TensorType(kernel_shape, weight_dtype))
-        y = relay.nn.conv2d(x, w,
+        weight = relay.var("weight", relay.TensorType(kernel_shape, weight_dtype))
+        y = relay.nn.conv2d(x, weight,
                             kernel_size=(ch, cw),
                             channels=oc,
                             padding=(1, 1),
                             dilation=(1, 1),
                             data_layout=data_layout,
                             kernel_layout=kernel_layout,
                             out_dtype=output_dtype)
-        func = relay.Function([x, w], y)
+        func = relay.Function([x, weight], y)
         wdata = np.random.rand(*kernel_shape) * 10
-        parameters = {"w": tvm.nd.array(wdata.astype(weight_dtype))}
+        parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))}
+
         with relay.build_config(opt_level=3):
             graph, lib, params = relay.build(func, target, params=parameters)
+
         assembly = lib.get_source("asm")
         return assembly
@@ -589,58 +593,63 @@ def _has_fast_int8_instructions(asm, target):
     llvm_version = tvm.codegen.llvm_version_major()
     for target in targets:
         if llvm_version >= 8:
-            fast_int8_dtypes = ('uint8', 'int8', 'int32')
+            dtypes = ('uint8', 'int8', 'int32')
             # Sweep the input channels to check int8 robustness
             # Input channels should be a multiple of 4 internally.
             for ic in [1, 4, 6]:
-                asm = _compile(ic=ic, oc=32, target=target, data_layout="NCHW",
+                asm = _compile(ic=ic, oc=16, target=target, data_layout="NCHW",
                                kernel_layout='OIHW',
-                               dtypes=fast_int8_dtypes)
+                               dtypes=dtypes)
                 assert _has_fast_int8_instructions(asm, target)

             for ic in [1, 4, 6]:
-                asm = _compile(ic=ic, oc=32, target=target, data_layout="NHWC",
+                asm = _compile(ic=ic, oc=16, target=target, data_layout="NHWC",
                                kernel_layout='HWIO',
-                               dtypes=fast_int8_dtypes)
+                               dtypes=dtypes)
                 assert _has_fast_int8_instructions(asm, target)

-
             # Sweep the output channels to check int8 robustness
             # Output channels should be a multiple of 16 internally.
             for oc in [4, 16, 20]:
-                asm = _compile(ic=16, oc=oc, target=target, data_layout="NCHW",
+                asm = _compile(ic=8, oc=oc, target=target, data_layout="NCHW",
                                kernel_layout='OIHW',
-                               dtypes=fast_int8_dtypes)
+                               dtypes=dtypes)
                 assert _has_fast_int8_instructions(asm, target)

             for oc in [4, 16, 20]:
-                asm = _compile(ic=16, oc=oc, target=target, data_layout="NHWC",
+                asm = _compile(ic=8, oc=oc, target=target, data_layout="NHWC",
                                kernel_layout='HWIO',
-                               dtypes=fast_int8_dtypes)
+                               dtypes=dtypes)
                 assert _has_fast_int8_instructions(asm, target)

             # Check that both non-divisible oc and ic work
             asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
-                           dtypes=fast_int8_dtypes)
+                           dtypes=dtypes)
             assert _has_fast_int8_instructions(asm, target)

             asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
-                           dtypes=fast_int8_dtypes)
+                           dtypes=dtypes)
             assert _has_fast_int8_instructions(asm, target)

-            # Ensure that code is generated when datatypes are not HW supported.
-            dtypes = ('int8', 'int8', 'int32')
-            asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
+    # Check that int8 x int8 goes through legalization so that fast instructions can be picked up.
+    for target in targets:
+        if llvm_version >= 8:
+            dtypes = (('int8', 'int8', 'int32'))
+            # Check that both non-divisible oc and ic work
+            asm = _compile(ic=17, oc=29, target=target, data_layout="NCHW", kernel_layout='OIHW',
                            dtypes=dtypes)
-            # Check that intrinisic is not present in the assembly.
-            assert not _has_fast_int8_instructions(asm, target)
+            assert _has_fast_int8_instructions(asm, target)

-            # Ensure that code is generated when datatypes are not HW supported.
-            dtypes = ('uint8', 'uint8', 'int32')
-            asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
+            asm = _compile(ic=17, oc=29, target=target, data_layout="NHWC", kernel_layout='HWIO',
                            dtypes=dtypes)
-            # Check that intrinisic is not present in the assembly.
-            assert not _has_fast_int8_instructions(asm, target)
+            assert _has_fast_int8_instructions(asm, target)
+
+            # Ensure that code is generated when datatypes are not HW supported.
+            dtypes = ('uint8', 'uint8', 'int32')
+            asm = _compile(ic=16, oc=32, target=target, data_layout="NHWC", kernel_layout='HWIO',
+                           dtypes=dtypes)
+            # Check that intrinisic is not present in the assembly.
+            assert not _has_fast_int8_instructions(asm, target)

     # Check that a vectorized instruction is generated for older Intel
     # generations, because we default to NCHWc layout.

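The hunks above call two helpers defined earlier in the same test and not shown in this diff. A sketch of what they plausibly look like, reconstructed from the call sites; the target strings and the mnemonics checked are assumptions:

# Sketch of the helpers the hunks above rely on (defined earlier in
# test_conv2d_int8_intrinsics, not part of this diff). The target list and
# the mnemonics checked here are assumptions reconstructed from the calls.
targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]

def _has_fast_int8_instructions(asm, target):
    if 'skylake-avx512' in target:
        return "pmaddubs" in asm      # 8-bit multiply-add on Skylake-AVX512
    elif 'cascadelake' in target:
        return "vpdpbusd" in asm      # AVX512-VNNI dot product on Cascade Lake
    assert False, "Target should be Skylake or Cascade Lake"
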
topi/python/topi/x86/conv2d_alter_op.py

Lines changed: 60 additions & 12 deletions
@@ -198,24 +198,72 @@ def _conv2d_legalize(attrs, inputs, arg_types):
         The legalized expr
     """

+    # Dilation not supported yet. Return None if dilation is not (1, 1)
+    dilation = attrs.get_int_tuple("dilation")
+    if not (dilation[0] == 1 and dilation[1] == 1):
+        return None
+
     # Collect the input tensors.
     data_tensor, kernel_tensor = arg_types[0], arg_types[1]
+    data_dtype = data_tensor.dtype
+    kernel_dtype = kernel_tensor.dtype

     # Collect the output tensor.
     output_tensor = arg_types[2]

+    # Collect the input exprs.
+    data, kernel = inputs
+
+    # Get the conv attrs
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
+
+    is_int8_inputs = False
+    # If both the inputs are int8, we can add 128 to make the input dtype uint8, and then adjust the
+    # output. This will help picking up Intel VNNI instructions.
+    # Original --> C = A (conv) B
+    # A and B are int8
+    #   C = (A + 128 - 128) (conv) B
+    #   C = (A' conv B) - 128 (conv) B
+    # where A' = A + 128
+    # and 128 (conv) B is basically a reduce on CRS axis for weights.
+    if data_tensor.dtype == 'int8' and kernel_tensor.dtype == 'int8':
+        is_int8_inputs = True
+        padding = attrs.get_int_tuple("padding")
+
+        if attrs['data_layout'] == 'NHWC' and attrs['kernel_layout'] == 'HWIO':
+            adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(0, 1, 2))
+            pad_width = ((0, 0), (padding[0], padding[0]), (padding[1], padding[1]), (0, 0))
+        elif attrs['data_layout'] == 'NCHW' and attrs['kernel_layout'] == 'OIHW':
+            pad_width = ((0, 0), (0, 0), (padding[0], padding[0]), (padding[1], padding[1]))
+            adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(1, 2, 3))
+            adjust_shift = relay.expand_dims(adjust_shift, axis=1, num_newaxis=2)
+        else:
+            return None
+
+        data = relay.cast(data, 'int32')
+        data = relay.add(data, relay.const(128, 'int32'))
+        data = relay.cast(data, 'uint8')
+
+        # Do external padding as pad value has to be 128.
+        if not (padding[0] == 0 and padding[1] == 0):
+            data = relay.nn.pad(data, pad_width=pad_width, pad_value=128)
+        new_attrs['padding'] = (0, 0)
+
+        # The data type is now shifted to uint8
+        data_dtype = 'uint8'
+
+        # Multiply 128 to adjust shift.
+        adjust_shift = relay.multiply(adjust_shift, relay.const(128, 'int32'))
+
     # Legalize if the datatypes are suitable for fast Int8 instructions. Int8 instructions require
     # input channel to be a multiple of 4 and output channels to be a multiple of 16. For input
     # channels, we pad both the inputs and weights input channels. For output channels, we pad the
     # weight and stride_slice the output.
-    if _is_int8_hw_support(data_tensor.dtype, kernel_tensor.dtype):
+    if _is_int8_hw_support(data_dtype, kernel_dtype):
         # Flags to remember if the expr is modified
         ic_modified = False
         oc_modified = False

-        # Collect the input exprs.
-        data, kernel = inputs
-
         # Find the value of input and output channel.
         in_channel = -1
         out_channel = -1
@@ -256,16 +304,16 @@ def _conv2d_legalize(attrs, inputs, arg_types):
         else:
             return None

-        if not (ic_modified or oc_modified):
-            return None
-
-        if ic_modified and not oc_modified:
-            return relay.nn.conv2d(data, kernel, **attrs)
-
         if oc_modified:
-            new_attrs = {k: attrs[k] for k in attrs.keys()}
             new_attrs['channels'] = new_out_channel
             out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
             original_out_shape = [x.value for x in output_tensor.shape]
-            return relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape)
+            out = relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape)
+        else:
+            out = relay.nn.conv2d(data, kernel, **new_attrs)
+
+        if is_int8_inputs:
+            out = relay.subtract(out, adjust_shift)
+
+        return out
     return None

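The comment block in the legalization above derives C = A conv B = (A + 128) conv B − 128 conv B, where 128 conv B is a reduce of the weights over the C, R, S axes times 128. A quick NumPy check of that identity with a naive stride-1, no-padding NCHW convolution; all names and shapes here are illustrative, not part of the commit:

# Verify: A conv B == (A + 128) conv B - 128 * sum(B over C, R, S).
import numpy as np

def conv_nchw(a, w):
    # Naive NCHW direct convolution, stride 1, no padding, wide accumulation.
    n, c, h, wd = a.shape
    oc, _, kh, kw = w.shape
    out = np.zeros((n, oc, h - kh + 1, wd - kw + 1), dtype=np.int64)
    for i in range(out.shape[2]):
        for j in range(out.shape[3]):
            patch = a[:, :, i:i + kh, j:j + kw].astype(np.int64)
            out[:, :, i, j] = np.tensordot(patch, w.astype(np.int64),
                                           axes=([1, 2, 3], [1, 2, 3]))
    return out

rng = np.random.default_rng(0)
a = rng.integers(-128, 128, size=(1, 3, 8, 8), dtype=np.int64)   # int8-range data
w = rng.integers(-128, 128, size=(4, 3, 3, 3), dtype=np.int64)   # int8-range kernel

lhs = conv_nchw(a, w)                                   # original int8 x int8 conv
shifted = conv_nchw(a + 128, w)                         # uint8 x int8 conv after the shift
adjust = 128 * w.sum(axis=(1, 2, 3)).reshape(1, -1, 1, 1)   # 128 * reduce over C, R, S
assert np.array_equal(lhs, shifted - adjust)

When the convolution has padding, the pass above pads the data externally with pad_value=128 and zeroes the conv's own padding attribute, so the same identity holds at the borders.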