Skip to content

Commit

Permalink
don't rely on cudnn for compilation (#10495)
Browse files Browse the repository at this point in the history
  • Loading branch information
anwang2009 authored Mar 4, 2022
1 parent 3f96f3d commit db8cf2f
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions python/tvm/contrib/cudnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,18 +741,22 @@ def conv_backward_data(
tensor_format, pad, stride, dilation, dy.shape, w.shape, output_padding, groups
)

algo = conv_backward_data_find_algo(
tensor_format,
pad,
stride,
dilation,
list(dy.shape),
list(w.shape),
dx_shape,
dy.dtype,
conv_dtype,
groups,
)
if exists():
# When cudnn exists, find the backward data algo
algo = conv_backward_data_find_algo(
tensor_format,
pad,
stride,
dilation,
list(dy.shape),
list(w.shape),
dx_shape,
dy.dtype,
conv_dtype,
groups,
)
else:
algo = 1

return te.extern(
dx_shape,
Expand Down

0 comments on commit db8cf2f

Please sign in to comment.