Skip to content

Commit 5761515

Browse files
Merge pull request #9 from pxl-th/master
Add group support for convolutions
2 parents fe6a3ff + 21ea951 commit 5761515

File tree

3 files changed

+23
-29
lines changed

3 files changed

+23
-29
lines changed

ext/CUDAExt/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1111

1212
[compat]
1313
CUDA = "3.3.1"
14-
NNlib = "0.7.23"
14+
NNlib = "0.7.25"
1515
julia = "1.6"
1616

1717
[extras]

ext/CUDAExt/src/cudnn/conv.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
using NNlib: DenseConvDims
33
import NNlib: conv!, ∇conv_filter!, ∇conv_data!, conv_bias_act!
44

5-
using CUDA.CUDNN: scalingParameter, CUDNN_CONVOLUTION, convdims,
5+
using CUDA.CUDNN: scalingParameter, CUDNN_CONVOLUTION, convdims,
66
cudnnConvolutionDescriptor, cudnnConvolutionBwdDataAlgoPerf,
77
cudnnConvolutionForward!, cudnnConvolutionBwdFilterAlgoPerf,
88
cudnnConvolutionBackwardData, cudnnConvolutionBackwardFilter,
@@ -19,7 +19,7 @@ function cudnnConvolutionDescriptor(cdims::DenseConvDims, x::DenseCuArray{T}) wh
1919
cudnnDataType(T),
2020
math_mode(),
2121
CUDNN_DEFAULT_REORDER,
22-
Cint(1))
22+
Cint(NNlib.groupcount(cdims)))
2323
end
2424

2525
function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims::DenseConvDims;
@@ -34,15 +34,15 @@ function conv!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T}, cdims
3434
cudnnConvolutionForward!(y, w, x, d; alpha, beta, z=y)
3535
end
3636

37-
function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T},
37+
function conv_bias_act!(y::DenseCuArray{T}, x::DenseCuArray{T}, w::DenseCuArray{T},
3838
cdims::DenseConvDims, bias::DenseCuArray{T}, σ=identity;
3939
z::DenseCuArray{T}=y, alpha=1, beta=0, algo=-1) where T<:CUDNNFloat
4040
if cudnnversion() < v"6"
4141
all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
4242
end
4343
if algo != -1
4444
@warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
45-
end
45+
end
4646
d = cudnnConvolutionDescriptor(cdims, x)
4747
# only relu and identity are supported by cudnnConvolutionForward!
4848
activation === NNlib.relu ? CUDNN_ACTIVATION_RELU : CUDNN_ACTIVATION_IDENTITY)
@@ -60,7 +60,7 @@ function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray
6060
end
6161
if algo != -1
6262
@warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
63-
end
63+
end
6464
alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
6565
xDesc, yDesc, wDesc = cudnnTensorDescriptor(dx), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(w)
6666
convDesc = cudnnConvolutionDescriptor(cdims, dx)
@@ -78,7 +78,7 @@ function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArr
7878
end
7979
if algo != -1
8080
@warn "The algo option has been deprecated, the fastest algo is computed automatically" maxlog=1
81-
end
81+
end
8282
alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
8383
xDesc, yDesc, wDesc = cudnnTensorDescriptor(x), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(dw)
8484
convDesc = cudnnConvolutionDescriptor(cdims, x)

ext/CUDAExt/test/conv.jl

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,23 @@ using NNlib: DenseConvDims
99
@test ∇conv_filter(a, c, cdims) collect(∇conv_filter(da, dc, cdims))
1010

1111
# Test for agreement between CPU NNlib and CuDNN versions, across a variety of kwargs
12-
for num_spatial_dims in (1, 2, 3)
12+
options = Dict{Any, Any}.((
13+
(), (:dilation => 2), (:flipkernel => true), (:stride => 2),
14+
(:padding => 1),
15+
))
16+
C_in_ = 3
17+
C_out = 4
18+
batch_size = 1
19+
20+
for groups in (1, 2, 4), num_spatial_dims in (1, 2, 3)
21+
# Make `C_in = C_out` when using grouped convolution.
22+
C_in = groups == 1 ? C_in_ : C_out
1323
# Initialize data we'll run our tests over
14-
C_in = 3
15-
C_out = 4
16-
batch_size = 1
1724
x = rand(Float64, fill(8, num_spatial_dims)..., C_in, batch_size)
18-
w = rand(Float64, fill(2, num_spatial_dims)..., C_in, C_out)
19-
b = rand(Float64, fill(1, num_spatial_dims)..., C_in, C_out)
20-
options = (Dict(), Dict(:dilation => 2), Dict(:flipkernel => true), Dict(:stride => 2), Dict(:padding => 1))
21-
22-
# @denizyuret: algo option deprecated for nnlib, handling in cudnn
23-
# algos = (1, 0, 1, 1,)
24-
# for (opts, algo) in zip(options, algos)
25+
w = rand(Float64, fill(2, num_spatial_dims)..., C_in ÷ groups, C_out)
2526

26-
for opts in options
27+
for opts in options
28+
opts[:groups] = groups
2729
cdims = DenseConvDims(x, w; opts...)
2830
y = NNlib.conv(x, w, cdims)
2931

@@ -36,19 +38,11 @@ using NNlib: DenseConvDims
3638
gputest((x, w) -> NNlib.conv(x, w, cdims; alpha=2.0), x, w, checkgrad=false) # TODO
3739
gputest((y, w) -> NNlib.∇conv_data(y, w, cdims; alpha=2.0), y, w, checkgrad=false) # TODO
3840
gputest((x, y) -> NNlib.∇conv_filter(x, y, cdims; alpha=2.0), x, y, checkgrad=false) # TODO
39-
41+
4042
gputest((y, x, w) -> NNlib.conv!(copy(y), x, w, cdims; beta=2.0), y, x, w, checkgrad=false) # TODO
4143
# @test_broken gputest((x, y, w) -> NNlib.∇conv_data!(copy(x), y, w, cdims; beta=2.0), x, y, w, checkgrad=false) #TODO
4244
gputest((w, x, y) -> NNlib.∇conv_filter!(copy(w), x, y, cdims; beta=2.0), w, x, y, checkgrad=false) # TODO
4345
end
44-
45-
# CPU implementation of ∇conv_bias!
46-
db = zeros(Float64, 1, 1, 3, 1)
47-
dy = randn(Float64, 8, 8, 3, 1)
48-
function NNlibCUDA.∇conv_bias!(db, dy)
49-
db .= sum(dy, dims=1:(ndims(dy)-2))
50-
return db
51-
end
52-
gputest(NNlibCUDA.∇conv_bias!, db, dy, checkgrad=false)
5346
end
47+
5448
end

0 commit comments

Comments
 (0)