# CUDNN_RNN_RELU: Stock RNN with ReLU activation
# CUDNN_RNN_TANH: Stock RNN with tanh activation
# CUDNN_LSTM: LSTM with no peephole connections
# CUDNN_GRU: Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)

# param layout:
# RNN: [weight, bias] × [input, hidden]
# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem]
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]

using LinearAlgebra

# Carve the flat cuDNN parameter buffer into (Wi, Wh) weight views and a bias view,
# without copying. `n` is the number of gates for the RNN mode.
function params(w::CuVector, input, hidden, n = 1)
  slice(offset, shape) = reshape(view(w, offset .+ (1:prod(shape))), shape)
  wx = slice(0, (input, hidden*n))
  wh = slice(length(wx), (hidden, hidden*n))
  bias = view(w, length(wx)+length(wh) .+ (1:hidden*n))
  (wx, wh), bias
end
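
# Hedged usage sketch (illustrative only, not part of the original API): shows how
# `params` carves a flat LSTM parameter buffer (4 gates) into weight and bias views.
# The sizes below (input = 10, hidden = 20) are arbitrary example values.
function _example_lstm_params()
  input, hidden, gates = 10, 20, 4
  w = CUDA.zeros(Float32, input*hidden*gates + hidden*hidden*gates + hidden*gates)
  (wi, wh), b = params(w, input, hidden, gates)
  @assert size(wi) == (input, hidden*gates)   # (10, 80)
  @assert size(wh) == (hidden, hidden*gates)  # (20, 80)
  @assert length(b) == hidden*gates           # 80
  return (wi, wh), b
end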

mutable struct RNNDesc{T}
  mode::cudnnRNNMode_t
  input::Int
  hidden::Int
  params::CuVector{T}
  weights::NTuple{2,CuMatrix{T}}
  bias::CuVector{T}
  ptr::Ptr{Nothing}
end

Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr

function rnnParamSize(T, r, input)
  size = Csize_t[0]
  cudnnGetRNNParamsSize(handle(), r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
  return Int(size[])÷sizeof(T)
end

# Number of gates per cudnnRNNMode_t value (RNN_RELU, RNN_TANH, LSTM, GRU)
ngates(mode) = [1, 1, 4, 3][mode+1]
ngates(r::RNNDesc) = ngates(r.mode)

function RNNDesc{T}(mode::cudnnRNNMode_t, input::Int, hidden::Int; layers = 1) where T
  d = [C_NULL]
  cudnnCreateRNNDescriptor(d)

  dropoutDesc = DropoutDesc(0)
  inputMode = CUDNN_LINEAR_INPUT
  direction = CUDNN_UNIDIRECTIONAL
  algo = CUDNN_RNN_ALGO_STANDARD
  cudnnSetRNNDescriptor_v6(handle(), d[], hidden, layers, dropoutDesc, inputMode,
                           direction, mode, algo, cudnnDataType(T))

  w = CUDA.zeros(T, rnnParamSize(T, d[], input))
  # TODO: avoid reserve allocation here
  rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
  finalizer(rd) do x
    cudnnDestroyRNNDescriptor(x)
  end
  return rd
end

# Copy weights into the descriptor's parameter views. Wi and Wh are written
# transposed because the views are laid out as (input × hidden*gates) and
# (hidden × hidden*gates), whereas the incoming matrices have those dims swapped.
function setweights!(d::RNNDesc, Wi, Wh, b)
  transpose!(d.weights[1], Wi)
  transpose!(d.weights[2], Wh)
  copyto!(d.bias, b)
  return
end
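
# Hedged usage sketch (illustrative only, assumes a working CUDA device): creating a
# GRU descriptor and loading weights stored in a (hidden*gates) × input / × hidden
# layout. The dimensions are arbitrary example values.
function _example_setweights()
  input, hidden = 8, 16
  rd = RNNDesc{Float32}(CUDNN_GRU, input, hidden)
  Wi = CUDA.randn(Float32, 3hidden, input)    # 3 gates: reset, update, newmem
  Wh = CUDA.randn(Float32, 3hidden, hidden)
  b  = CUDA.zeros(Float32, 3hidden)
  setweights!(rd, Wi, Wh, b)
  return rd
end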

function cudnnGetRNNTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
  size = Csize_t[0]
  cudnnGetRNNTrainingReserveSize(handle(), r, seqlen, xdesc, size)
  return Int(size[])
end

function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod,
                         ho, cod, co, reserve=nothing) where T
  @workspace size=@argout(
      cudnnGetRNNWorkspaceSize(handle(), rnn, seqlen, xd,
                               out(Ref{Csize_t}()))
    )[] workspace->begin
      if reserve === nothing
        cudnnRNNForwardInference(handle(), rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y,
                                 hod, ho, cod, co, workspace, sizeof(workspace))
      else
        cudnnRNNForwardTraining(handle(), rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y,
                                hod, ho, cod, co, workspace, sizeof(workspace),
                                reserve, sizeof(reserve))
      end
    end
end

xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]

hDesc(h::Nothing) = C_NULL, CU_NULL
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
function hDesc(h::DenseCuArray)
  TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
end

# TODO: can we just manipulate strides here?
# TODO: should use `repeat`, but this isn't implemented.
hBatch(x::AbstractVector, h::CuVector) = h
hBatch(x::AbstractMatrix, h::CuVector) = h .* CUDA.ones(1, size(x, 2))
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* CUDA.ones(1, size(h,2) == 1 ? size(x,2) : 1)
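
# Hedged sketch (illustrative only): hBatch broadcasts a single hidden-state vector
# across the batch dimension of x, so cuDNN sees a hidden × batch matrix.
function _example_hbatch()
  x = CUDA.zeros(Float32, 8, 5)    # input × batch
  h = CUDA.zeros(Float32, 16)      # hidden state for a single sample
  hb = hBatch(x, h)                # 16 × 5 matrix, h repeated per column
  @assert size(hb) == (16, 5)
  return hb
end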

# Run a single forward step. Returns (y, ho) for plain RNN/GRU modes and (y, ho, co)
# for LSTM; when `train` is `Val{true}` the cuDNN reserve buffer is also returned.
function forward(rnn::RNNDesc{T}, x::DenseCuArray{T}, h_::DenseCuArray{T}, c_ = nothing, train = Val{false}) where T
  h = hBatch(x, h_)
  c = c_ === nothing ? nothing : hBatch(x, c_)
  @assert size(x, 1) == rnn.input
  @assert size(h, 1) == rnn.hidden
  @assert size(x, 2) == size(h, 2)
  seqLength = 1
  xdesc = xDesc(x)
  y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
  ho = similar(h)
  ydesc = xDesc(y)
  reserve = train == Val{true} ?
    CuVector{UInt8}(undef, cudnnGetRNNTrainingReserveSize(rnn, seqLength, xdesc)) :
    nothing
  co = c === nothing ? c : similar(c)
  cudnnRNNForward(rnn, seqLength,
                  xdesc, x,
                  hDesc(h)...,
                  hDesc(c)...,
                  FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
                  ydesc, y,
                  hDesc(ho)...,
                  hDesc(co)...,
                  reserve)
  result = c === nothing ? (y, ho) : (y, ho, co)
  return train == Val{true} ? (reserve, result) : result
end
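
# Hedged usage sketch (illustrative only, assumes a working CUDA device): a single
# inference step through a tanh RNN. The dimensions are arbitrary example values;
# `forward` returns the output and the new hidden state.
function _example_forward()
  input, hidden, batch = 8, 16, 4
  rd = RNNDesc{Float32}(CUDNN_RNN_TANH, input, hidden)
  x = CUDA.randn(Float32, input, batch)
  h = CUDA.zeros(Float32, hidden, batch)
  y, ho = forward(rd, x, h)
  @assert size(y) == (hidden, batch)
  return y, ho
end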

forwardTrain(rnn::RNNDesc{T}, x::DenseCuArray{T}, h::DenseCuArray{T}, c = nothing) where T =
  forward(rnn, x, h, c, Val{true})

function cudnnRNNBackwardData(rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc,
                              dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc,
                              dx, dhxDesc, dhx, dcxDesc, dcx, reserve)
  @workspace size=@argout(
      cudnnGetRNNWorkspaceSize(handle(), rnnDesc, seqLength, dxDesc,
                               out(Ref{Csize_t}()))
    )[] workspace->begin
      cudnnRNNBackwardData(handle(), rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc,
                           dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc,
                           dx, dhxDesc, dhx, dcxDesc, dcx, workspace, sizeof(workspace),
                           reserve, sizeof(reserve))
    end
end

function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
  # Same as above; is there a more efficient way?
  dy = dy_ isa Integer ? zero(y) : dy_
  yd = xDesc(y)
  dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
  dh = similar(h)
  dc = c === nothing ? nothing : similar(c)
  cudnnRNNBackwardData(rnn, 1, yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
                       FilterDesc(T, (1, 1, length(rnn.params))), rnn.params, hDesc(h)...,
                       hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)..., reserve)
  return c === nothing ? (dx, dh) : (dx, dh, dc)
end

backwardData(rnn, y, dy, dho, hx, reserve) =
  backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)

function cudnnRNNBackwardWeights(rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc,
                                 y, dwDesc, dw, reserve)
  @workspace size=@argout(
      cudnnGetRNNWorkspaceSize(handle(), rnnDesc, seqLength, xDesc,
                               out(Ref{Csize_t}()))
    )[] workspace->begin
      cudnnRNNBackwardWeights(handle(), rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc,
                              y, workspace, sizeof(workspace), dwDesc, dw,
                              reserve, sizeof(reserve))
    end
end

function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
  dw = zero(rnn.params)
  cudnnRNNBackwardWeights(rnn, 1, xDesc(x), x, hDesc(h)..., xDesc(y), y,
                          FilterDesc(T, (1, 1, length(dw))), dw, reserve)
  return params(dw, rnn.input, rnn.hidden, ngates(rnn))
end

function pullback(rnn::RNNDesc{T}, x::DenseCuArray{T}, h::DenseCuArray{T}) where T <: Union{Float32,Float64}
  reserve, (y, ho) = CUDNN.forwardTrain(rnn, x, h)
  return (y, ho), function (dy, dho)
    h_ = CUDNN.hBatch(x, h)
    dx, dh = CUDNN.backwardData(rnn, y, dy, dho, h_, reserve)
    (dWi, dWh), db = CUDNN.backwardWeights(rnn, x, h_, y, reserve)
    return (x = dx, h = dh, Wi = dWi, Wh = dWh, b = db)
  end
end

function pullback(rnn::RNNDesc{T}, x::DenseCuArray{T}, h::DenseCuArray{T}, c::DenseCuArray{T}) where T <: Union{Float32,Float64}
  reserve, (y, ho, co) = CUDNN.forwardTrain(rnn, x, h, c)
  return (y, ho, co), function (dy, dho, dco)
    h_ = CUDNN.hBatch(x, h)
    c_ = CUDNN.hBatch(x, c)
    dx, dh, dc = CUDNN.backwardData(rnn, y, dy, dho, dco, h_, c_, reserve)
    (dWi, dWh), db = CUDNN.backwardWeights(rnn, x, h_, y, reserve)
    return (x = dx, h = dh, c = dc, Wi = dWi, Wh = dWh, b = db)
  end
end
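
# Hedged usage sketch (illustrative only, assumes a working CUDA device and that the
# CUDNN-qualified helpers above resolve): one forward/backward step through an LSTM
# via the pullback defined above. Gradients come back as a named tuple.
function _example_pullback()
  input, hidden, batch = 8, 16, 4
  rd = RNNDesc{Float32}(CUDNN_LSTM, input, hidden)
  x  = CUDA.randn(Float32, input, batch)
  h  = CUDA.zeros(Float32, hidden, batch)
  c  = CUDA.zeros(Float32, hidden, batch)
  (y, ho, co), back = pullback(rd, x, h, c)
  dy = CUDA.ones(Float32, hidden, batch)       # seed the output gradient with ones
  grads = back(dy, zero(ho), zero(co))
  return grads.Wi, grads.Wh, grads.b
end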