Skip to content

Commit e5f278e

Browse files
committed
[Relay][Legalize][ARM_CPU] Handling NHWC layout for arm_cpu.
1 parent 3ac27fc commit e5f278e

File tree

4 files changed

+100
-2
lines changed

4 files changed

+100
-2
lines changed

python/tvm/relay/op/nn/_nn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,11 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
204204
from ... import op
205205
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)
206206

207-
# A placeholder to have at least one invocation of register legalize to register FTVMLegalize.
208207
@reg.register_legalize("nn.conv2d")
209208
def legalize_conv2d(attrs, inputs, arg_dtypes):
210-
return None
209+
"""Legalize conv2d"""
210+
from ... import op
211+
return topi.nn.conv2d_legalize(attrs, inputs, arg_dtypes, op)
211212

212213
reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
213214

tests/python/relay/test_pass_legalize.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Test legalize pass"""
18+
import numpy as np
1819
import tvm
1920

2021
from tvm import relay
22+
from tvm.contrib import graph_runtime
2123
from tvm.relay.op import register_legalize
2224
from tvm.relay import transform, analysis
2325

@@ -123,8 +125,52 @@ def expected():
123125

124126
assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
125127

128+
def test_legalize_arm_layout_functional():
129+
"""Test if the legalized conversion yields same result as original"""
130+
def get_output(func, data_val, parameters):
131+
with relay.build_config(opt_level=0):
132+
graph, lib, params = relay.build(func, target='llvm', params=parameters)
133+
m = graph_runtime.create(graph, lib, tvm.cpu())
134+
m.set_input("data", data_val)
135+
m.set_input(**params)
136+
m.run()
137+
out = m.get_output(0, tvm.nd.empty((1, 224, 224, 32), 'float32')).asnumpy()
138+
return out
139+
140+
def before():
141+
n, ic, ih, iw, oc, kh, kw = 1, 16, 224, 224, 32, 3, 3
142+
data = relay.var("data", relay.TensorType((n, ih, iw, ic), 'float32'))
143+
kernel = relay.var("kernel", relay.TensorType((kh, kw, ic, oc), 'float32'))
144+
y = relay.nn.conv2d(data, kernel,
145+
kernel_size=(kh, kw),
146+
channels=oc,
147+
padding=(1, 1),
148+
dilation=(1, 1),
149+
data_layout='NHWC',
150+
kernel_layout='HWIO',
151+
out_dtype='float32')
152+
func = relay.Function([data, kernel], y)
153+
return func
154+
155+
@register_legalize("nn.conv2d", level=101)
156+
def legalize_conv2d(attrs, inputs, arg_types):
157+
from topi.arm_cpu.conv2d import _conv2d_legalize
158+
return _conv2d_legalize(attrs, inputs, arg_types, tvm.relay.op)
159+
160+
a = before()
161+
b = run_opt_pass(a, transform.Legalize())
162+
assert b.astext().count('transpose') == 3
163+
164+
wdata = np.random.rand(3, 3, 16, 32) * 10
165+
parameters = {"kernel": tvm.nd.array(wdata.astype('float32'))}
166+
data_val = np.random.rand(1, 224, 224, 16).astype('float32')
167+
ref_out = get_output(a, data_val, parameters)
168+
legalized_out = get_output(b, data_val, parameters)
169+
np.testing.assert_allclose(ref_out, legalized_out, rtol=0.01)
170+
126171

127172
if __name__ == "__main__":
128173
test_legalize()
129174
test_legalize_none()
130175
test_legalize_multi_input()
176+
test_legalize_arm_layout_functional()

topi/python/topi/arm_cpu/conv2d.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
conv2d_winograd_without_weight_transform, \
3232
conv2d_winograd_nnpack_without_weight_transform, \
3333
depthwise_conv2d_nchw
34+
from ..nn import conv2d_legalize
3435
from ..nn.util import get_const_int, get_pad_tuple
3536
from ..nn.winograd_util import winograd_transform_matrices
3637

@@ -783,3 +784,31 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
783784
# currently we only have contrib_spatial_pack and direct template
784785
# add more schedule templates.
785786
return None
787+
788+
@conv2d_legalize.register("arm_cpu")
789+
def _conv2d_legalize(attrs, inputs, arg_types, F):
790+
if F.__name__ != 'tvm.relay.op':
791+
return None
792+
if attrs['data_layout'] == 'NHWC':
793+
data, kernel = inputs
794+
if attrs['kernel_layout'] == 'HWIO':
795+
# Handle HWIO layout. This is common in TF graph.
796+
kernel = F.transpose(kernel, axes=(3, 2, 0, 1))
797+
elif attrs['kernel_layout'] == 'HWOI':
798+
# Handle HWOI layout. This is common in TF depthwise conv2d graph.
799+
kernel = F.transpose(kernel, axes=(2, 3, 0, 1))
800+
elif attrs['kernel_layout'] != 'OIHW':
801+
return None
802+
803+
# Set new attrs for the tranposed conv.
804+
new_attrs = {k: attrs[k] for k in attrs.keys()}
805+
new_attrs['data_layout'] = 'NCHW'
806+
new_attrs['kernel_layout'] = 'OIHW'
807+
808+
# Convert from NHWC to NCHW.
809+
data = F.transpose(data, axes=(0, 3, 1, 2))
810+
conv = F.nn.conv2d(data, kernel, **new_attrs)
811+
# Convert back to original NHWC layout.
812+
out = F.transpose(conv, axes=(0, 2, 3, 1))
813+
return out
814+
return None

topi/python/topi/nn/conv2d.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,28 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
7171
raise ValueError("not support this layout {} yet".format(layout))
7272

7373

74+
@tvm.target.generic_func
75+
def conv2d_legalize(attrs, inputs, arg_dtypes, F):
76+
"""Legalizes Conv2D op.
77+
Parameters
78+
----------
79+
attrs : nnvm.top.AttrDict or tvm.attrs.Attrs
80+
Attributes of current convolution
81+
inputs : list of tvm.relay.Expr
82+
The args of the Relay expr to be legalized.
83+
arg_dtypes : list of types
84+
List of types of input arguments
85+
F: symbol
86+
The context, can be either nnvm.sym or relay.op
87+
Note
88+
----
89+
Unlike other TOPI functions, this function operates on both graph level and operator level,
90+
so we have to pass 'F' to make it support our two versions of graph IR, NNVM and Relay.
91+
"""
92+
# not to change by default
93+
return None
94+
95+
7496
@tvm.target.generic_func
7597
def conv2d_alter_layout(attrs, inputs, tinfos, F):
7698
"""Change Conv2D layout.

0 commit comments

Comments
 (0)