
Commit 4df9991

[cuDNN] Add cuDNN grouped convolutions support
Signed-off-by: Wei Pan <weip@nvidia.com>
1 parent 3f03869 commit 4df9991

9 files changed (+170, -85 lines)


python/tvm/contrib/cudnn.py

Lines changed: 25 additions & 11 deletions
@@ -182,7 +182,8 @@ def conv_output_shape(tensor_format,
                       x_shape,
                       w_shape,
                       data_dtype,
-                      conv_dtype):
+                      conv_dtype,
+                      groups=1):
     """Get output shape of 2D or 3D convolution
 
     Paramters
@@ -205,6 +206,8 @@ def conv_output_shape(tensor_format,
         data type
     conv_dtype: str
         convolution type
+    groups: int
+        number of groups
 
     Returns
     -------
@@ -228,7 +231,8 @@ def conv_output_shape(tensor_format,
          _get_np_int32_array_handle(wshape),
          _get_np_int32_array_handle(oshape),
          data_dtype,
-         conv_dtype)
+         conv_dtype,
+         groups)
     return list(oshape)
 
 
@@ -240,7 +244,8 @@ def conv_find_algo(tensor_format,
                    w_shape,
                    y_shape,
                    data_dtype,
-                   conv_dtype):
+                   conv_dtype,
+                   groups=1):
     """Choose the best algo for the given input.
 
     Paramters
@@ -265,6 +270,8 @@ def conv_find_algo(tensor_format,
         data type
     conv_dtype: str
         convolution type
+    groups: int
+        number of groups
 
     Returns
     -------
@@ -287,7 +294,8 @@ def conv_find_algo(tensor_format,
                 _get_np_int32_array_handle(wshape),
                 _get_np_int32_array_handle(yshape),
                 data_dtype,
-                conv_dtype)
+                conv_dtype,
+                groups)
 
 
 def conv_forward(x,
@@ -298,7 +306,8 @@ def conv_forward(x,
                  conv_mode,
                  tensor_format,
                  algo,
-                 conv_dtype):
+                 conv_dtype,
+                 groups=1):
     """Create an extern op that compute 2D or 3D convolution with CuDNN
 
     Parameters
@@ -325,6 +334,8 @@ def conv_forward(x,
         if algo == -1, the best algo will be chosen by CUDNN
     conv_dtype: str
         convolution type
+    groups: int
+        the number of groups
 
     Returns
     -------
@@ -335,8 +346,7 @@ def conv_forward(x,
     assert dims in (4, 5)
 
     conv_dtype = x.dtype if conv_dtype is None else conv_dtype
-    pad, stride, dilation, _, _ = \
-        _prepare_global_func_params(dims - 2, pad, stride, dilation)
+    pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation)
 
     oshape = conv_output_shape(tensor_format,
                                pad,
@@ -345,7 +355,8 @@ def conv_forward(x,
                                list(x.shape),
                                list(w.shape),
                                x.dtype,
-                               conv_dtype)
+                               conv_dtype,
+                               groups)
     if algo == -1:
         # For now if we try to call `cudnnFindConvolutionForwardAlgorithm` when
         # using INT8 data type, CuDNN will crash down.
@@ -361,7 +372,8 @@ def conv_forward(x,
                               list(w.shape),
                               oshape,
                               x.dtype,
-                              conv_dtype)
+                              conv_dtype,
+                              groups)
 
     if dims == 4:
         return te.extern(
@@ -380,7 +392,8 @@ def conv_forward(x,
                 ins[0],
                 ins[1],
                 outs[0],
-                conv_dtype), name="y")
+                conv_dtype,
+                groups), name="y")
 
     return te.extern(
         oshape, [x, w],
@@ -401,7 +414,8 @@ def conv_forward(x,
                 ins[0],
                 ins[1],
                 outs[0],
-                conv_dtype), name="y")
+                conv_dtype,
+                groups), name="y")
 
 def softmax(x, axis=-1):
     """Compute softmax using CuDNN

python/tvm/relay/op/strategy/cuda.py

Lines changed: 18 additions & 2 deletions
@@ -161,7 +161,9 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
             if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \
                     padding[1] == padding[3]:
                 strategy.add_implementation(
-                    wrap_compute_conv2d(topi.cuda.conv2d_cudnn, True),
+                    wrap_compute_conv2d(topi.cuda.conv2d_cudnn,
+                                        need_data_layout=True,
+                                        has_groups=True),
                     wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn),
                     name="conv2d_cudnn.cuda",
                     plevel=15)
@@ -181,6 +183,20 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
         else:
             raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
     else: # group_conv2d
+        # add cudnn implementation, if any
+        cudnn_impl = False
+        if target.target_name == "cuda" and "cudnn" in target.libs:
+            if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \
+                    padding[1] == padding[3]:
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.cuda.conv2d_cudnn,
+                                        need_data_layout=True,
+                                        has_groups=True),
+                    wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn),
+                    name="conv2d_cudnn.cuda",
+                    plevel=15)
+                cudnn_impl = True
+
         if layout == 'NCHW':
             # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
             assert kernel_layout == "OIHW"
@@ -194,7 +210,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                 wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
                 wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
                 name="group_conv2d_NCHWc_int8.cuda")
-        else:
+        elif not cudnn_impl:
             raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
     return strategy
 
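
As a hedged end-to-end sketch (not part of the commit, again assuming a TVM build with CUDA and cuDNN enabled): with the strategy change above, a Relay conv2d with groups > 1 on a "cuda -libs=cudnn" target can be lowered through conv2d_cudnn instead of hitting the unsupported-layout error. The shapes below are illustrative; everything else is the standard Relay API.

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 32, 56, 56), dtype="float32")
weight = relay.var("weight", shape=(32, 8, 3, 3), dtype="float32")  # OIHW, groups=4
out = relay.nn.conv2d(data, weight, channels=32, kernel_size=(3, 3),
                      padding=(1, 1), groups=4)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

# With cuDNN among the target libs, the group_conv2d branch added above can
# select the conv2d_cudnn implementation (plevel=15) for this op.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="cuda -libs=cudnn")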

src/runtime/contrib/cudnn/conv_forward.cc

Lines changed: 25 additions & 12 deletions
@@ -35,6 +35,7 @@ void ConvolutionForward(
     int format,
     int algo,
     int dims,
+    int groups,
     const int pad[],
     const int stride[],
     const int dilation[],
@@ -62,8 +63,10 @@ void ConvolutionForward(
 
   // Note: For 2D tenor, using ND setters causes CUDNN_STATUS_NOT_SUPPORTED error
   // in following cudnnGetConvolutionForwardWorkspaceSize() when data type is fp16, int
+
+  CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
   if (dims == 2) {
-  // Set Desc
+    // Set Desc
     CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc,
                                                pad[0],
                                                pad[1],
@@ -183,6 +186,7 @@ void ConvolutionForward(
 void OutputShape(
     int format,
     int dims,
+    int groups,
     const int pad[],
     const int stride[],
     const int dilation[],
@@ -202,6 +206,7 @@ void OutputShape(
   int full_dims = dims + 2;
 
   // conv desc
+  CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
   CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc,
                                              dims,
                                              pad,
@@ -240,6 +245,7 @@ void OutputShape(
   // Set Input
   std::vector<int> tensor_stride(full_dims);
   GetCudnnStride(full_dims, x_dim, tensor_stride.data());
+
   CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc,
                                         data_type,
                                         full_dims,
@@ -264,6 +270,7 @@ void OutputShape(
 void FindAlgo(
     int format,
     int dims,
+    int groups,
     const int pad[],
     const int stride[],
     const int dilation[],
@@ -284,6 +291,7 @@ void FindAlgo(
   int full_dims = dims + 2;
 
   // conv desc
+  CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
   CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc,
                                              dims,
                                              pad,
@@ -360,16 +368,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
   int algo = args[2];
   int pad_v[2], stride_v[2], dilation_v[2];
   for (int i = 0; i < 2; i++) {
-      pad_v[i] = args[3 + i];
-      stride_v[i] = args[5 + i];
-      dilation_v[i] = args[7 + i];
+    pad_v[i] = args[3 + i];
+    stride_v[i] = args[5 + i];
+    dilation_v[i] = args[7 + i];
   }
   DLTensor* x = args[9];
   DLTensor* w = args[10];
   DLTensor* y = args[11];
   std::string conv_dtype = args[12];
+  int groups = args[13];
 
-  ConvolutionForward(mode, format, algo, 2, pad_v, stride_v, dilation_v, x, w, y, conv_dtype);
+  ConvolutionForward(mode, format, algo, 2, groups, pad_v, stride_v,
+                     dilation_v, x, w, y, conv_dtype);
 });
 
 
@@ -380,17 +390,18 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward")
   int algo = args[2];
   int pad_v[3], stride_v[3], dilation_v[3];
   for (int i = 0; i < 3; i++) {
-      pad_v[i] = args[3 + i];
-      stride_v[i] = args[6 + i];
-      dilation_v[i] = args[9 + i];
+    pad_v[i] = args[3 + i];
+    stride_v[i] = args[6 + i];
+    dilation_v[i] = args[9 + i];
   }
   DLTensor *x = args[12];
   DLTensor *w = args[13];
   DLTensor *y = args[14];
   std::string conv_dtype = args[15];
+  int groups = args[16];
 
-  ConvolutionForward(mode, format, algo, 3, pad_v, stride_v, dilation_v, x, w, y,
-                     conv_dtype);
+  ConvolutionForward(mode, format, algo, 3, groups, pad_v, stride_v,
+                     dilation_v, x, w, y, conv_dtype);
 });
 
 
@@ -406,8 +417,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape")
   void* out_shape = args[7];
   std::string data_dtype = args[8];
   std::string conv_dtype = args[9];
+  int groups = args[10];
 
-  OutputShape(format, dims, pad, stride, dilation, x_dim,
+  OutputShape(format, dims, groups, pad, stride, dilation, x_dim,
               w_dim, out_shape, data_dtype, conv_dtype);
 });
 
@@ -424,8 +436,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo")
   int* y_dim = static_cast<int*>(static_cast<void*>(args[7]));
   std::string data_dtype = args[8];
   std::string conv_dtype = args[9];
+  int groups = args[10];
 
-  FindAlgo(format, dims, pad, stride, dilation, x_dim,
+  FindAlgo(format, dims, groups, pad, stride, dilation, x_dim,
            w_dim, y_dim, data_dtype, conv_dtype, ret);
 });
 
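
For illustration (again a sketch, assuming a CUDA/cuDNN-enabled build), the new groups value travels as the trailing argument of these packed functions; for example, the Python wrapper shown earlier forwards it as args[10] of "tvm.contrib.cudnn.conv.output_shape" when querying the output shape of a grouped convolution:

from tvm.contrib import cudnn

# tensor_format=0 is NCHW; the OIHW filter has I = C / groups.
oshape = cudnn.conv_output_shape(tensor_format=0,
                                 pad=[1, 1],
                                 stride=[1, 1],
                                 dilation=[1, 1],
                                 x_shape=[1, 32, 56, 56],
                                 w_shape=[32, 8, 3, 3],
                                 data_dtype="float32",
                                 conv_dtype="float32",
                                 groups=4)
print(oshape)  # 3x3 kernel, pad 1, stride 1 -> [1, 32, 56, 56]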

src/runtime/contrib/cudnn/cudnn_utils.h

Lines changed: 0 additions & 1 deletion
@@ -78,7 +78,6 @@ struct ConvEntry {
   runtime::DeviceAPI *cuda_api;
   void *workspace{nullptr};
   size_t workspace_size{0};
-  int group_count {0};
   ConvEntry();
   ~ConvEntry();
   void UpdateWorkspace(const size_t wsize);
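
Presumably the cached group_count member can be dropped because the group count is now passed per call to ConvolutionForward, OutputShape, and FindAlgo and set directly on the convolution descriptor via cudnnSetConvolutionGroupCount, so ConvEntry no longer needs to carry it.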
