
Commit 02cd271

add cuda conv2d strategy
1 parent 43fd6f1 commit 02cd271

File tree

19 files changed: 762 additions, 688 deletions

python/tvm/relay/op/_transform.py

Lines changed: 3 additions & 1 deletion
@@ -85,7 +85,9 @@ def compute_argwhere(attrs, inputs, output_type):
 
 _reg.register_schedule("argwhere", strategy.schedule_argwhere)
 
-############################### shape func #################################
+#####################
+#  Shape functions  #
+#####################
 
 @script
 def _arange_shape_func(start, stop, step):

python/tvm/relay/op/nn/_nn.py

Lines changed: 11 additions & 160 deletions
@@ -87,100 +87,8 @@ def compute_sparse_transpose(attrs, inputs, out_type):
 
 
 # conv2d
-def _find_conv2d_op(op):
-    """Find the op with conv2d in its tag by traversing."""
-    if 'conv2d' in op.tag:
-        return op
-    for tensor in op.input_tensors:
-        op_ = _find_conv2d_op(tensor.op)
-        if op_ is not None:
-            return op_
-    return None
-
-# @reg.register_compute("nn.conv2d")
-# def compute_conv2d(attrs, inputs, out_type, target):
-#     """Compute definition of conv2d"""
-#     padding = get_const_tuple(attrs.padding)
-#     strides = get_const_tuple(attrs.strides)
-#     dilation = get_const_tuple(attrs.dilation)
-#     groups = attrs.groups
-#     layout = attrs.data_layout
-#     kernel_layout = attrs.kernel_layout
-#     out_dtype = attrs.out_dtype
-#     out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
-#                  else out_dtype)
-#
-#     assert layout in ["NCHW", "NHWC", "NCHW4c", "HWCN"]
-#     (dilation_h, dilation_w) = dilation
-#     if dilation_h < 1 or dilation_w < 1:
-#         raise ValueError("dilation should be positive value")
-#
-#     def _get_out_depth():
-#         weight_shape = get_const_tuple(inputs[1].shape)
-#         # NHWC layout
-#         if kernel_layout.startswith("HW"):
-#             return weight_shape[2] * weight_shape[3]
-#         # NCHW layout.
-#         # in ARM CPU contrib_spatial_pack schedule, we will prepack weight layout
-#         if len(weight_shape) == 4:
-#             return weight_shape[0] * weight_shape[1]
-#         else:
-#             assert len(weight_shape) == 5
-#             C, M, _, _, VC = weight_shape
-#             return C * VC * M
-#
-#     if groups == 1:
-#         out = topi.nn.conv2d(
-#             inputs[0], inputs[1], strides, padding,
-#             dilation, layout, out_dtype)
-#     elif layout == "NCHW" and _get_out_depth() == groups:
-#         out = topi.nn.depthwise_conv2d_nchw(
-#             inputs[0], inputs[1], strides, padding, dilation, out_dtype)
-#     elif layout == "NHWC" and kernel_layout == "HWOI" and _get_out_depth() == groups:
-#         out = topi.nn.depthwise_conv2d_nhwc(
-#             inputs[0], inputs[1], strides, padding, dilation, out_dtype)
-#     elif layout in ['NCHW', 'NCHW4c']:
-#         out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups,
-#                                         out_dtype)
-#     else:
-#         raise ValueError("not support arbitrary group number for now")
-#     return [out]
-
-
-# @reg.register_schedule("nn.conv2d")
-# def schedule_conv2d(attrs, outs, target):
-#     """Schedule definition of conv2d"""
-#     groups = attrs.groups
-#     layout = attrs.data_layout
-#     kernel_layout = attrs.kernel_layout
-#
-#     with target:
-#         if groups == 1 and layout == "NCHW":
-#             return topi.generic.schedule_conv2d_nchw(outs)
-#         elif groups == 1 and layout == "NCHW4c":
-#             return topi.generic.schedule_conv2d_nchw(outs)
-#         elif groups == 1 and layout == "NHWC":
-#             return topi.generic.schedule_conv2d_nhwc(outs)
-#         elif groups == 1 and layout == "HWCN":
-#             return topi.generic.schedule_conv2d_hwcn(outs)
-#         elif groups != 1:
-#             # collect in_channels to distinguish depthwise and group conv2d
-#             op = _find_conv2d_op(outs[0].op)
-#             assert op is not None
-#
-#             is_depthwise = 'depthwise' in op.tag
-#             if is_depthwise:
-#                 if layout == "NCHW":
-#                     # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d.
-#                     return topi.generic.schedule_depthwise_conv2d_nchw(outs)
-#                 if layout == "NHWC" and kernel_layout == "HWOI":
-#                     return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
-#             else:
-#                 if layout in ["NCHW", "NCHW4c"]:
-#                     return topi.generic.schedule_group_conv2d_nchw(outs)
-#         raise ValueError("No compatible schedule")
-
 reg.register_strategy("nn.conv2d", strategy.conv2d_strategy)
+reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 @reg.register_alter_op_layout("nn.conv2d")
 def alter_op_layout_conv2d(attrs, inputs, tinfos, out_type):
@@ -207,7 +115,6 @@ def legalize_conv2d(attrs, inputs, types):
     """
     return topi.nn.conv2d_legalize(attrs, inputs, types)
 
-
 @reg.register_convert_op_layout("nn.conv2d")
 def convert_conv2d(attrs, inputs, tinfos, desired_layout):
     """Convert Layout pass registration for conv2d op.
@@ -248,8 +155,6 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout):
         return relay.nn.conv2d(data, weight, **new_attrs)
     return None
 
-reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
-
 
 # conv2d_transpose
 reg.register_strategy("nn.conv2d_transpose", strategy.conv2d_transpose_strategy)
@@ -421,36 +326,9 @@ def compute_mirror_pad(attrs, inputs, out_dtype, target):
 reg.register_strategy_broadcast("nn.mirror_pad")
 
 
-# winograd related operators
-@reg.register_compute("nn.contrib_conv2d_winograd_without_weight_transform")
-def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, out_dtype):
-    """Compute definition of conv2d_winograd_without_weight_transform"""
-    # pylint: disable=assignment-from-no-return
-    padding = attrs.get_int_tuple("padding")
-    strides = attrs.get_int_tuple("strides")
-    dilation = attrs.get_int_tuple("dilation")
-    groups = attrs.get_int("groups")
-    data_layout = attrs.get_str("data_layout")
-    out_dtype = attrs.get_str("out_dtype")
-    tile_size = attrs.get_int("tile_size")
-    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
-    assert dilation == (1, 1), "Do not support dilate now"
-    assert groups == 1, "Do not supoort arbitrary group number"
-
-    out = topi.nn.conv2d_winograd_without_weight_transform(
-        inputs[0], inputs[1], strides, padding, dilation, data_layout,
-        out_dtype, tile_size)
-
-    return [out]
-
-
-# @reg.register_schedule("nn.contrib_conv2d_winograd_without_weight_transform")
-# def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target):
-#     """Schedule definition of conv2d_winograd_without_weight_transform"""
-#     with target:
-#         return topi.generic.schedule_conv2d_winograd_without_weight_transform(outs)
-
-
+# conv2d_winograd related operators
+reg.register_strategy("nn.contrib_conv2d_winograd_without_weight_transform",
+                      strategy.conv2d_winograd_without_weight_transfrom_strategy)
 reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform",
                      OpPattern.OUT_ELEMWISE_FUSABLE)
 
@@ -462,14 +340,8 @@ def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype):
         inputs[0], attrs.get_int('tile_size'))
     return [out]
 
-
-# @reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform")
-# def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target):
-#     """Schedule definition of contrib_conv2d_winograd_weight_transform"""
-#     with target:
-#         return topi.generic.schedule_conv2d_winograd_weight_transform(outs)
-
-
+reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform",
+                      strategy.schedule_conv2d_winograd_weight_transform)
 reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform",
                      OpPattern.OUT_ELEMWISE_FUSABLE)
 
@@ -535,31 +407,8 @@ def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_d
                      OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # depthwise_conv2d_NCHWc
-@reg.register_compute("nn.contrib_depthwise_conv2d_NCHWc")
-def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target):
-    """Compute definition of depthwise conv2d NCHWc"""
-    # pylint: disable=assignment-from-no-return
-    padding = attrs.get_int_tuple("padding")
-    strides = attrs.get_int_tuple("strides")
-    dilation = attrs.get_int_tuple("dilation")
-    data_layout = attrs.get_str("data_layout")
-    out_layout = attrs.get_str("out_layout")
-    out_dtype = attrs.get_str("out_dtype")
-    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
-
-    out = topi.nn.depthwise_conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation,
-                                          data_layout, out_layout, out_dtype)
-    return [out]
-
-
-# @reg.register_schedule("nn.contrib_depthwise_conv2d_NCHWc")
-# def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target):
-#     """Schedule definition of contrib_conv2d_NCHWc"""
-#     with target:
-#         return topi.generic.schedule_depthwise_conv2d_NCHWc(outs)
-
-
-reg.register_strategy("nn.contrib_depthwise_conv2d_NCHWc", strategy.depthwise_conv2d_NCHWc_strategy)
+reg.register_strategy("nn.contrib_depthwise_conv2d_NCHWc",
+                      strategy.depthwise_conv2d_NCHWc_strategy)
 reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc",
                      OpPattern.OUT_ELEMWISE_FUSABLE)
 
@@ -658,7 +507,9 @@ def compute_space_to_depth(attrs, inputs, out_dtype):
 reg.register_pattern("nn.space_to_depth", OpPattern.INJECTIVE)
 
 
-############################### shape func #################################
+#####################
+#  Shape functions  #
+#####################
 
 @script
 def _conv2d_NCHWc_shape_func(dshape, kshape, strides, padding, dilation, oc_bn):
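
The pattern across _nn.py is uniform: each op's @reg.register_compute / @reg.register_schedule pair (including the long commented-out conv2d versions) is replaced by a single reg.register_strategy call, which defers the choice of compute and schedule to a per-target strategy function. The sketch below is not code from this commit; the import paths and the NCHW-only fallback are assumptions modelled on the strategy files this PR introduces.

import topi
from tvm.relay.op import op as _op
from tvm.relay.op.strategy.generic import wrap_compute_conv2d, wrap_topi_schedule

def conv2d_strategy_example(attrs, inputs, out_type, target):
    """Hypothetical single-implementation conv2d strategy (NCHW only)."""
    out_strategy = _op.OpStrategy()
    if attrs.data_layout == "NCHW":
        out_strategy.add_implement(
            wrap_compute_conv2d(topi.nn.conv2d_nchw),
            wrap_topi_schedule(topi.generic.schedule_conv2d_nchw))
    return out_strategy

# A backend attaches such a function to the generic entry point for its target
# keys, e.g. @conv2d_strategy.register(["cuda", "gpu"]) in cuda.py below; the op
# file then needs only the one-line registration seen in the diff above:
#     reg.register_strategy("nn.conv2d", strategy.conv2d_strategy)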

python/tvm/relay/op/strategy/cuda.py

Lines changed: 97 additions & 1 deletion
@@ -77,6 +77,102 @@ def schedule_l2_normalize_cuda(attrs, outs, target):
     with target:
         return topi.cuda.schedule_l2_normalize(outs)
 
+@conv2d_strategy.register(["cuda", "gpu"])
+def conv2d_strategy_cuda(attrs, inputs, out_type, target):
+    """conv2d cuda strategy"""
+    strategy = _op.OpStrategy()
+    data, kernel = inputs
+    dilation_h, dilation_w = attrs.get_int_tuple("dilation")
+    groups = attrs.groups
+    layout = attrs.data_layout
+    stride_h, stride_w = attrs.get_int_tuple("strides")
+    kernel_layout = attrs.kernel_layout
+    if dilation_h < 1 or dilation_w < 1:
+        raise ValueError("dilation should be positive value")
+
+    if groups == 1:
+        if layout == "NCHW":
+            # TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
+            assert kernel_layout == "OIHW"
+            strategy.add_implement(
+                wrap_compute_conv2d(topi.cuda.conv2d_nchw),
+                wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw))
+            _, _, kh, kw = get_const_tuple(kernel.shape)
+            if kh <= 7 and kw <= 7 and kh == kw and stride_h == 1 and stride_w == 1:
+                strategy.add_implement(
+                    wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd),
+                    wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
+                    15)
+        elif layout == "HWCN":
+            assert kernel_layout == "HWIO"
+            strategy.add_implement(
+                wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
+                wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn))
+        elif layout == "NHWC":
+            assert kernel_layout == ""
+            strategy.add_implement(
+                wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
+                wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc))
+        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
+            assert kernel_layout == "OIHW4o4i"
+            strategy.add_implement(
+                wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
+                wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8))
+        else:
+            raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
+        # add cudnn implementation
+        if target.target_name == "cuda" and "cudnn" in target.libs:
+            if layout in ["NCHW", "NHWC"]:
+                strategy.add_implement(
+                    wrap_compute_conv2d(topi.cuda.conv2d_cudnn, True),
+                    wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn), 5)
+    elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
+        if layout == "NCHW":
+            assert kernel_layout == "OIHW"
+            strategy.add_implement(
+                wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
+                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw))
+        elif layout == "NHWC":
+            assert kernel_layout == "HWOI"
+            strategy.add_implement(
+                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
+                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc))
+        else:
+            raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
+    else: # group_conv2d
+        if layout == 'NCHW':
+            # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
+            assert kernel_layout == "OIHW"
+            strategy.add_implement(
+                wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
+                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw))
+        elif layout == 'NCHW4c' and data.dtype in ["int8", "uint8"]:
+            assert kernel_layout == "OIHW4o4i"
+            strategy.add_implement(
+                wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
+                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8))
+        else:
+            raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
+    return strategy
+
+@conv2d_winograd_without_weight_transfrom_strategy.register(["cuda", "gpu"])
+def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_type, target):
+    dilation = attrs.get_int_tuple("dilation")
+    groups = attrs.get_int("groups")
+    layout = attrs.data_layout
+    assert dilation == (1, 1), "Do not support dilate now"
+    assert groups == 1, "Do not supoort arbitrary group number"
+    strategy = _op.OpStrategy()
+    if layout == "NCHW":
+        strategy.add_implement(
+            wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd_without_weight_transform),
+            wrap_topi_schedule(
+                topi.cuda.schedule_conv2d_nchw_winograd_without_weight_transform_cuda))
+    else:
+        raise RuntimeError("Unsupported conv2d_winograd_without_weight_transfrom layout {}".
+                           format(layout))
+    return strategy
+
 @deformable_conv2d_strategy.register(["cuda", "gpu"])
 def deformable_conv2d_strategy_cuda(attrs, inputs, out_type, target):
     """deformable_conv2d cuda strategy"""
@@ -108,7 +204,7 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
     assert layout in ["NCDHW", "NDHWC"], "Not support this layout {} yet".format(layout)
     if layout == "NCDHW":
        strategy.add_implement(wrap_compute_conv3d(topi.cuda.conv3d_ncdhw),
-                              _reg._wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw),
+                              wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw),
                               10)
     else: # layout == "NDHWC":
         strategy.add_implement(wrap_compute_conv3d(topi.cuda.conv3d_ndhwc),