Commit b30ac5f

Merge pull request #738 from denizyuret/dy/rnncompat
copied the old rnn.jl->rnncompat.jl for Flux compatibility
2 parents d69be98 + 5c0c942

2 files changed, +205 -0 lines changed

lib/cudnn/CUDNN.jl

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ include("normalization.jl")
 # high-level integrations
 include("nnlib.jl")
 include("batchnorm.jl")
+include("rnncompat.jl")


 function math_mode(mode=CUDA.math_mode())

lib/cudnn/rnncompat.jl

Lines changed: 204 additions & 0 deletions
@@ -0,0 +1,204 @@
# CUDNN_RNN_RELU: Stock RNN with ReLu activation
# CUDNN_RNN_TANH: Stock RNN with tanh activation
# CUDNN_LSTM: LSTM with no peephole connections
# CUDNN_GRU: Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)

# param layout:
# RNN: [weight, bias] × [input, hidden]
# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem]
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]

using LinearAlgebra

function params(w::CuVector, input, hidden, n = 1)
  slice(offset, shape) = reshape(view(w, offset.+(1:prod(shape))), shape)
  wx = slice(0, (input, hidden*n))
  wh = slice(length(wx), (hidden, hidden*n))
  bias = view(w, length(wx)+length(wh) .+ (1:hidden*n))
  (wx, wh), bias
end

mutable struct RNNDesc{T}
  mode::cudnnRNNMode_t
  input::Int
  hidden::Int
  params::CuVector{T}
  weights::NTuple{2,CuMatrix{T}}
  bias::CuVector{T}
  ptr::Ptr{Nothing}
end

Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr

function rnnParamSize(T, r, input)
  size = Csize_t[0]
  cudnnGetRNNParamsSize(handle(), r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
  return Int(size[])÷sizeof(T)
end

ngates(mode) = [1, 1, 4, 3][mode+1]
ngates(r::RNNDesc) = ngates(r.mode)

function RNNDesc{T}(mode::cudnnRNNMode_t, input::Int, hidden::Int; layers = 1) where T
  d = [C_NULL]
  cudnnCreateRNNDescriptor(d)

  dropoutDesc = DropoutDesc(0)
  inputMode = CUDNN_LINEAR_INPUT
  direction = CUDNN_UNIDIRECTIONAL
  algo = CUDNN_RNN_ALGO_STANDARD
  cudnnSetRNNDescriptor_v6(handle(),d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))

  w = CUDA.zeros(T, rnnParamSize(T, d[], input))
  # TODO: avoid reserve allocation here
  rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
  finalizer(rd) do x
    cudnnDestroyRNNDescriptor(x)
  end
  return rd
end

function setweights!(d::RNNDesc, Wi, Wh, b)
  transpose!(d.weights[1], Wi)
  transpose!(d.weights[2], Wh)
  copyto!(d.bias, b)
  return
end

function cudnnGetRNNTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
  size = Csize_t[0]
  cudnnGetRNNTrainingReserveSize(handle(), r, seqlen, xdesc, size)
  return Int(size[])
end

function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod,
                         ho, cod, co, reserve=nothing) where T
  @workspace size=@argout(
      cudnnGetRNNWorkspaceSize(handle(), rnn, seqlen, xd,
                               out(Ref{Csize_t}()))
    )[] workspace->begin
      if reserve == nothing
        cudnnRNNForwardInference(handle(), rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y,
                                 hod, ho, cod, co, workspace, sizeof(workspace))
      else
        cudnnRNNForwardTraining(handle(), rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y,
                                hod, ho, cod, co, workspace, sizeof(workspace),
                                reserve, sizeof(reserve))
      end
    end
end

xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]

hDesc(h::Nothing) = C_NULL, CU_NULL
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
function hDesc(h::DenseCuArray)
  TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
end

# TODO: can we just manipulate strides here?
# TODO: should use repmat, but this isn't implemented.
hBatch(x::AbstractVector, h::CuVector) = h
hBatch(x::AbstractMatrix, h::CuVector) = h .* CUDA.ones(1, size(x, 2))
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* CUDA.ones(1, size(h,2) == 1 ? size(x,2) : 1)

function forward(rnn::RNNDesc{T}, x::DenseCuArray{T}, h_::DenseCuArray{T}, c_ = nothing, train = Val{false}) where T
  h = hBatch(x, h_)
  c = c_ == nothing ? nothing : hBatch(x, c_)
  @assert size(x, 1) == rnn.input
  @assert size(h, 1) == rnn.hidden
  @assert size(x, 2) == size(h, 2)
  seqLength = 1
  xdesc = xDesc(x)
  y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
  ho = similar(h)
  ydesc = xDesc(y)
  reserve = train == Val{true} ?
    CuVector{UInt8}(undef, cudnnGetRNNTrainingReserveSize(rnn, seqLength, xdesc)) :
    nothing
  co = c == nothing ? c : similar(c)
  cudnnRNNForward(rnn, seqLength,
                  xdesc, x,
                  hDesc(h)...,
                  hDesc(c)...,
                  FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
                  ydesc, y,
                  hDesc(ho)...,
                  hDesc(co)...,
                  reserve)
  result = c == nothing ? (y, ho) : (y, ho, co)
  return train == Val{true} ? (reserve, result) : result
end

forwardTrain(rnn::RNNDesc{T}, x::DenseCuArray{T}, h::DenseCuArray{T}, c = nothing) where T =
  forward(rnn, x, h, c, Val{true})

function cudnnRNNBackwardData(rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc,
                              dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc,
                              dx, dhxDesc, dhx, dcxDesc, dcx, reserve)
  @workspace size=@argout(
      cudnnGetRNNWorkspaceSize(handle(), rnnDesc, seqLength, dxDesc,
                               out(Ref{Csize_t}()))
    )[] workspace->begin
      cudnnRNNBackwardData(handle(), rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc,
                           dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc,
                           dx, dhxDesc, dhx, dcxDesc, dcx, workspace, sizeof(workspace),
                           reserve, sizeof(reserve))
    end
end

function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
  # Same as above, any more efficient way?
  dy = dy_ isa Integer ? zero(y) : dy_
  yd = xDesc(y)
  dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
  dh = similar(h)
  dc = c == nothing ? nothing : similar(c)
  cudnnRNNBackwardData(rnn, 1, yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
                       FilterDesc(T, (1, 1, length(rnn.params))), rnn.params, hDesc(h)...,
                       hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)..., reserve)
  return c == nothing ? (dx, dh) : (dx, dh, dc)
end

backwardData(rnn, y, dy, dho, hx, reserve) =
  backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)

function cudnnRNNBackwardWeights(rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc,
                                 y, dwDesc, dw, reserve)
  @workspace size=@argout(
      cudnnGetRNNWorkspaceSize(handle(), rnnDesc, seqLength, xDesc,
                               out(Ref{Csize_t}()))
    )[] workspace->begin
      cudnnRNNBackwardWeights(handle(), rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc,
                              y, workspace, sizeof(workspace), dwDesc, dw,
                              reserve, sizeof(reserve))
    end
end

function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
  dw = zero(rnn.params)
  cudnnRNNBackwardWeights(rnn, 1, xDesc(x), x, hDesc(h)..., xDesc(y), y,
                          FilterDesc(T, (1, 1, length(dw))), dw, reserve)
  return params(dw, rnn.input, rnn.hidden, ngates(rnn))
end

function pullback(rnn::RNNDesc{T}, x::DenseCuArray{T}, h::DenseCuArray{T}) where T <: Union{Float32,Float64}
  reserve, (y, ho) = CUDNN.forwardTrain(rnn, x, h)
  return (y, ho), function (dy, dho)
    h_ = CUDNN.hBatch(x, h)
    dx, dh = CUDNN.backwardData(rnn, y, dy, dho, h_, reserve)
    (dWi, dWh), db = CUDNN.backwardWeights(rnn, x, h_, y, reserve)
    return (x = dx, h = dh, Wi = dWi, Wh = dWh, b = db)
  end
end

function pullback(rnn::RNNDesc{T}, x::DenseCuArray{T}, h::DenseCuArray{T}, c::DenseCuArray{T}) where T <: Union{Float32,Float64}
  reserve, (y, ho, co) = CUDNN.forwardTrain(rnn, x, h, c)
  return (y, ho, co), function (dy, dho, dco)
    h_ = CUDNN.hBatch(x, h)
    c_ = CUDNN.hBatch(x, c)
    dx, dh, dc = CUDNN.backwardData(rnn, y, dy, dho, dco, h_, c_, reserve)
    (dWi, dWh), db = CUDNN.backwardWeights(rnn, x, h_, y, reserve)
    return (x = dx, h = dh, c = dc, Wi = dWi, Wh = dWh, b = db)
  end
end
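
For reference, a minimal usage sketch of the compatibility layer added by this commit (not part of the diff and untested; it assumes a cuDNN-capable GPU, that the names below are reachable through the CUDA.CUDNN submodule, and purely illustrative dimensions):

using CUDA
using CUDA.CUDNN: RNNDesc, CUDNN_RNN_TANH, setweights!, forward, pullback

input, hidden, batch = 8, 16, 4

# Describe a single-layer tanh RNN and copy Flux-style weights into its
# cuDNN parameter buffer (setweights! transposes Wi and Wh internally).
rd = RNNDesc{Float32}(CUDNN_RNN_TANH, input, hidden)
Wi = CUDA.rand(Float32, hidden, input)    # input-to-hidden weights
Wh = CUDA.rand(Float32, hidden, hidden)   # hidden-to-hidden weights
b  = CUDA.rand(Float32, hidden)           # bias
setweights!(rd, Wi, Wh, b)

# Inference: one step over a batch of column vectors; the hidden state is
# broadcast across the batch by hBatch inside forward.
x = CUDA.rand(Float32, input, batch)
h = CUDA.zeros(Float32, hidden)
y, ho = forward(rd, x, h)

# Training: forward pass plus a closure that maps output and hidden-state
# gradients to input, hidden-state, and parameter gradients.
(y, ho), back = pullback(rd, x, h)
grads = back(CUDA.ones(Float32, size(y)), CUDA.zeros(Float32, size(ho)))

For an LSTM descriptor the three-argument pullback(rnn, x, h, c) variant applies and additionally returns the cell-state gradient.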
