@@ -87,100 +87,8 @@ def compute_sparse_transpose(attrs, inputs, out_type):
8787
8888
8989# conv2d
90- def _find_conv2d_op (op ):
91- """Find the op with conv2d in its tag by traversing."""
92- if 'conv2d' in op .tag :
93- return op
94- for tensor in op .input_tensors :
95- op_ = _find_conv2d_op (tensor .op )
96- if op_ is not None :
97- return op_
98- return None
99-
100- # @reg.register_compute("nn.conv2d")
101- # def compute_conv2d(attrs, inputs, out_type, target):
102- # """Compute definition of conv2d"""
103- # padding = get_const_tuple(attrs.padding)
104- # strides = get_const_tuple(attrs.strides)
105- # dilation = get_const_tuple(attrs.dilation)
106- # groups = attrs.groups
107- # layout = attrs.data_layout
108- # kernel_layout = attrs.kernel_layout
109- # out_dtype = attrs.out_dtype
110- # out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
111- # else out_dtype)
112- #
113- # assert layout in ["NCHW", "NHWC", "NCHW4c", "HWCN"]
114- # (dilation_h, dilation_w) = dilation
115- # if dilation_h < 1 or dilation_w < 1:
116- # raise ValueError("dilation should be positive value")
117- #
118- # def _get_out_depth():
119- # weight_shape = get_const_tuple(inputs[1].shape)
120- # # NHWC layout
121- # if kernel_layout.startswith("HW"):
122- # return weight_shape[2] * weight_shape[3]
123- # # NCHW layout.
124- # # in ARM CPU contrib_spatial_pack schedule, we will prepack weight layout
125- # if len(weight_shape) == 4:
126- # return weight_shape[0] * weight_shape[1]
127- # else:
128- # assert len(weight_shape) == 5
129- # C, M, _, _, VC = weight_shape
130- # return C * VC * M
131- #
132- # if groups == 1:
133- # out = topi.nn.conv2d(
134- # inputs[0], inputs[1], strides, padding,
135- # dilation, layout, out_dtype)
136- # elif layout == "NCHW" and _get_out_depth() == groups:
137- # out = topi.nn.depthwise_conv2d_nchw(
138- # inputs[0], inputs[1], strides, padding, dilation, out_dtype)
139- # elif layout == "NHWC" and kernel_layout == "HWOI" and _get_out_depth() == groups:
140- # out = topi.nn.depthwise_conv2d_nhwc(
141- # inputs[0], inputs[1], strides, padding, dilation, out_dtype)
142- # elif layout in ['NCHW', 'NCHW4c']:
143- # out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups,
144- # out_dtype)
145- # else:
146- # raise ValueError("not support arbitrary group number for now")
147- # return [out]
148-
149-
150- # @reg.register_schedule("nn.conv2d")
151- # def schedule_conv2d(attrs, outs, target):
152- # """Schedule definition of conv2d"""
153- # groups = attrs.groups
154- # layout = attrs.data_layout
155- # kernel_layout = attrs.kernel_layout
156- #
157- # with target:
158- # if groups == 1 and layout == "NCHW":
159- # return topi.generic.schedule_conv2d_nchw(outs)
160- # elif groups == 1 and layout == "NCHW4c":
161- # return topi.generic.schedule_conv2d_nchw(outs)
162- # elif groups == 1 and layout == "NHWC":
163- # return topi.generic.schedule_conv2d_nhwc(outs)
164- # elif groups == 1 and layout == "HWCN":
165- # return topi.generic.schedule_conv2d_hwcn(outs)
166- # elif groups != 1:
167- # # collect in_channels to distinguish depthwise and group conv2d
168- # op = _find_conv2d_op(outs[0].op)
169- # assert op is not None
170- #
171- # is_depthwise = 'depthwise' in op.tag
172- # if is_depthwise:
173- # if layout == "NCHW":
174- # # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d.
175- # return topi.generic.schedule_depthwise_conv2d_nchw(outs)
176- # if layout == "NHWC" and kernel_layout == "HWOI":
177- # return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
178- # else:
179- # if layout in ["NCHW", "NCHW4c"]:
180- # return topi.generic.schedule_group_conv2d_nchw(outs)
181- # raise ValueError("No compatible schedule")
182-
18390reg .register_strategy ("nn.conv2d" , strategy .conv2d_strategy )
91+ reg .register_pattern ("nn.conv2d" , OpPattern .OUT_ELEMWISE_FUSABLE )
18492
18593@reg .register_alter_op_layout ("nn.conv2d" )
18694def alter_op_layout_conv2d (attrs , inputs , tinfos , out_type ):
@@ -207,7 +115,6 @@ def legalize_conv2d(attrs, inputs, types):
207115 """
208116 return topi .nn .conv2d_legalize (attrs , inputs , types )
209117
210-
211118@reg .register_convert_op_layout ("nn.conv2d" )
212119def convert_conv2d (attrs , inputs , tinfos , desired_layout ):
213120 """Convert Layout pass registration for conv2d op.
@@ -248,8 +155,6 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout):
248155 return relay .nn .conv2d (data , weight , ** new_attrs )
249156 return None
250157
251- reg .register_pattern ("nn.conv2d" , OpPattern .OUT_ELEMWISE_FUSABLE )
252-
253158
254159# conv2d_transpose
255160reg .register_strategy ("nn.conv2d_transpose" , strategy .conv2d_transpose_strategy )
@@ -421,36 +326,9 @@ def compute_mirror_pad(attrs, inputs, out_dtype, target):
421326reg .register_strategy_broadcast ("nn.mirror_pad" )
422327
423328
424- # winograd related operators
425- @reg .register_compute ("nn.contrib_conv2d_winograd_without_weight_transform" )
426- def compute_contrib_conv2d_winograd_without_weight_transform (attrs , inputs , out_dtype ):
427- """Compute definition of conv2d_winograd_without_weight_transform"""
428- # pylint: disable=assignment-from-no-return
429- padding = attrs .get_int_tuple ("padding" )
430- strides = attrs .get_int_tuple ("strides" )
431- dilation = attrs .get_int_tuple ("dilation" )
432- groups = attrs .get_int ("groups" )
433- data_layout = attrs .get_str ("data_layout" )
434- out_dtype = attrs .get_str ("out_dtype" )
435- tile_size = attrs .get_int ("tile_size" )
436- out_dtype = inputs [0 ].dtype if out_dtype == "" else out_dtype
437- assert dilation == (1 , 1 ), "Do not support dilate now"
438- assert groups == 1 , "Do not supoort arbitrary group number"
439-
440- out = topi .nn .conv2d_winograd_without_weight_transform (
441- inputs [0 ], inputs [1 ], strides , padding , dilation , data_layout ,
442- out_dtype , tile_size )
443-
444- return [out ]
445-
446-
447- # @reg.register_schedule("nn.contrib_conv2d_winograd_without_weight_transform")
448- # def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target):
449- # """Schedule definition of conv2d_winograd_without_weight_transform"""
450- # with target:
451- # return topi.generic.schedule_conv2d_winograd_without_weight_transform(outs)
452-
453-
329+ # conv2d_winograd related operators
330+ reg .register_strategy ("nn.contrib_conv2d_winograd_without_weight_transform" ,
331+ strategy .conv2d_winograd_without_weight_transfrom_strategy )
454332reg .register_pattern ("nn.contrib_conv2d_winograd_without_weight_transform" ,
455333 OpPattern .OUT_ELEMWISE_FUSABLE )
456334
@@ -462,14 +340,8 @@ def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype):
462340 inputs [0 ], attrs .get_int ('tile_size' ))
463341 return [out ]
464342
465-
466- # @reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform")
467- # def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target):
468- # """Schedule definition of contrib_conv2d_winograd_weight_transform"""
469- # with target:
470- # return topi.generic.schedule_conv2d_winograd_weight_transform(outs)
471-
472-
343+ reg .register_schedule ("nn.contrib_conv2d_winograd_weight_transform" ,
344+ strategy .schedule_conv2d_winograd_weight_transform )
473345reg .register_pattern ("nn.contrib_conv2d_winograd_weight_transform" ,
474346 OpPattern .OUT_ELEMWISE_FUSABLE )
475347
@@ -535,31 +407,8 @@ def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_d
535407 OpPattern .OUT_ELEMWISE_FUSABLE )
536408
537409# depthwise_conv2d_NCHWc
538- @reg .register_compute ("nn.contrib_depthwise_conv2d_NCHWc" )
539- def compute_contrib_depthwise_conv2d_NCHWc (attrs , inputs , out_dtype , target ):
540- """Compute definition of depthwise conv2d NCHWc"""
541- # pylint: disable=assignment-from-no-return
542- padding = attrs .get_int_tuple ("padding" )
543- strides = attrs .get_int_tuple ("strides" )
544- dilation = attrs .get_int_tuple ("dilation" )
545- data_layout = attrs .get_str ("data_layout" )
546- out_layout = attrs .get_str ("out_layout" )
547- out_dtype = attrs .get_str ("out_dtype" )
548- out_dtype = inputs [0 ].dtype if out_dtype == "" else out_dtype
549-
550- out = topi .nn .depthwise_conv2d_NCHWc (inputs [0 ], inputs [1 ], strides , padding , dilation ,
551- data_layout , out_layout , out_dtype )
552- return [out ]
553-
554-
555- # @reg.register_schedule("nn.contrib_depthwise_conv2d_NCHWc")
556- # def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target):
557- # """Schedule definition of contrib_conv2d_NCHWc"""
558- # with target:
559- # return topi.generic.schedule_depthwise_conv2d_NCHWc(outs)
560-
561-
562- reg .register_strategy ("nn.contrib_depthwise_conv2d_NCHWc" , strategy .depthwise_conv2d_NCHWc_strategy )
410+ reg .register_strategy ("nn.contrib_depthwise_conv2d_NCHWc" ,
411+ strategy .depthwise_conv2d_NCHWc_strategy )
563412reg .register_pattern ("nn.contrib_depthwise_conv2d_NCHWc" ,
564413 OpPattern .OUT_ELEMWISE_FUSABLE )
565414
@@ -658,7 +507,9 @@ def compute_space_to_depth(attrs, inputs, out_dtype):
658507reg .register_pattern ("nn.space_to_depth" , OpPattern .INJECTIVE )
659508
660509
661- ############################### shape func #################################
510+ #####################
511+ # Shape functions #
512+ #####################
662513
663514@script
664515def _conv2d_NCHWc_shape_func (dshape , kshape , strides , padding , dilation , oc_bn ):
0 commit comments