-
-
Notifications
You must be signed in to change notification settings - Fork 615
Closed
Description
I'm implementing Listen, Attend and Spell: arxiv.org/abs/1508.01211 work in Flux and hitting a bug related to the handling of Adjoint. I tried really hard, but unfortunately, I was not able to reduce my minimum example to a few lines. I made it reproducible nevertheless.
Furthermore, if you replace ϕSᵢᵀ = m.attention_ϕ(m.state.decoding)' with ϕSᵢᵀ = permutedims(m.attention_ϕ(m.state.decoding)) in the code below, the error goes away.
using Flux
using Flux: flip, reset!, onecold, throttle, train!, @treelike, @epochs
using IterTools
using LinearAlgebra
import Base.Iterators
# Bidirectional LSTM
struct BLSTM{L,D}
forward :: L
backward :: L
dense :: D
end
@treelike BLSTM
function BLSTM(in::Integer, hidden::Integer, out::Integer, σ = identity)
forward = LSTM(in, hidden)
backward = LSTM(in, hidden)
dense = Dense(2hidden, out, σ)
return BLSTM(forward, backward, dense)
end
BLSTM(D_in::Integer, D_out::Integer) = BLSTM(D_in, ceil(Int, (D_in + D_out)/2), D_out)
(m::BLSTM)(xs::AbstractVector{<:AbstractVecOrMat})::AbstractVector{<:AbstractVecOrMat} = m.dense.(vcat.(m.forward.(xs), flip(m.backward, xs)))
# Flux.reset!(m::BLSTM) = reset!((m.forward, m.backward)) # not needed as taken care of by @treelike
function restack(xs::VV)::VV where VV <: AbstractVector{<:AbstractVecOrMat}
T = length(xs)
return vcat.(xs[1:2:T], xs[2:2:T])
end
"""
PBLSTM(D_in::Integer, D_out::Integer)
Pyramidal BLSTM is the same as BLSTM, with the addition that the outputs of BLSTM are concatenated at consecutive steps.
"""
function PBLSTM(D_in::Integer, D_out::Integer)
iseven(D_out) || throw("output dimension of the pyramidal BLSTM layer must be even")
D_out_blstm = Int(D_out/2)
return Chain(BLSTM(D_in, D_out_blstm), restack)
end
"""
Encoder(layer_sizes)
Encoder(D_in::Integer, D_out::Integer; nlayers::Integer = 4)
Encoder(D_in::Integer, D_out::Integer, hidden_sizes)
Encoder that consists of block of PBLSTMs. It accepts filter bank spectra as inputs and acts as acoustic model encoder.
"""
function Encoder(layer_sizes)
(length(layer_sizes) ≥ 3) || throw("number of layers of Encoder must be ≥ 2")
layer_dims = Tuple(partition(layer_sizes, 2, 1))
layers = ( PBLSTM(D_in, D_out) for (D_in, D_out) ∈ layer_dims[1:end-1] )
model = Chain(layers..., BLSTM(layer_dims[end]...))
return model
end
function Encoder(D_in::Integer, D_out::Integer; nlayers::Integer = 4)
layer_sizes = range(D_in, D_out; length=nlayers+1)
layer_sizes = map(layer_sizes) do x
n = ceil(Int, x)
n = iseven(n) ? n : (n + 1)
return n
end
layer_sizes[1] = D_in
layer_sizes[end] = D_out
return Encoder(layer_sizes)
end
Encoder(D_in::Integer, D_out::Integer, hidden_sizes) = Encoder((D_in, hidden_sizes..., D_out))
function MLP(layer_sizes, σs)
layers = Tuple(Dense(D_in, D_out, σ) for ((D_in, D_out), σ) ∈ zip(partition(layer_sizes, 2, 1), σs))
model = length(layers) == 1 ? first(layers) : Chain(layers...)
return model
end
function MLP(layer_sizes, σ::Function)
σs = ntuple(i -> σ, length(layer_sizes))
return MLP(layer_sizes, σs)
end
function MLP(D_in::Integer, D_out::Integer, σs)
layer_sizes = ceil.(Int, range(D_in, D_out; length=length(σs)+1))
return MLP(layer_sizes, σs)
end
function MLP(D_in::Integer, D_out::Integer, σ::Function=identity; nlayers::Integer = 1)
σs = ntuple(i -> σ, nlayers)
return MLP(D_in, D_out, σs)
end
function Decoder(layer_sizes)
layers = ( LSTM(D_in, D_out) for (D_in, D_out) ∈ partition(layer_sizes, 2, 1) )
model = Chain(layers...)
return model
end
function Decoder(D_in::Integer, D_out::Integer; nlayers::Integer = 2)
layer_sizes = ceil.(Int, range(D_in, D_out; length=nlayers+1))
return Decoder(layer_sizes)
end
Decoder(D_in::Integer, D_out::Integer, hidden_sizes) = Decoder((D_in, hidden_sizes..., D_out))
function CharacterDistribution(D_in::Integer, D_out::Integer, σ::Function; nlayers::Integer, applylog::Bool=true)
f = applylog ? logsoftmax : softmax
layer_sizes = ceil.(Int, range(D_in, D_out; length=nlayers+1))
layer_dims = Tuple(partition(layer_sizes, 2, 1))
layers = ( Dense(D_in, D_out, σ) for (D_in, D_out) ∈ layer_dims[1:end-1] )
return Chain(layers..., Dense(layer_dims[end]...), f)
end
CharacterDistribution(D_in::Integer, D_out::Integer; applylog::Bool=true) = Chain(Dense(D_in, D_out), applylog ? logsoftmax : softmax)
mutable struct State{M <: AbstractMatrix{<:Real}}
context :: M # last attention context
decoding :: M # last decoder state
prediction :: M # last prediction
# reset values
context₀ :: M
decoding₀ :: M
prediction₀ :: M
end
@treelike State
function State(D_c::Integer, D_d::Integer, D_p::Integer)
context₀ = param(zeros(Float32, D_c, 1))
decoding₀ = param(zeros(Float32, D_d, 1))
prediction₀ = param(zeros(Float32, D_p, 1))
return State(context₀, decoding₀, prediction₀, context₀, decoding₀, prediction₀)
end
function Flux.reset!(s::State)
s.context = s.context₀
s.decoding = s.decoding₀
s.prediction = s.prediction₀
return nothing
end
struct LAS{V, E, Dϕ, Dψ, L, C}
state :: State{V} # current state of the model
listen :: E # encoder function
attention_ϕ :: Dϕ # attention context function
attention_ψ :: Dψ # attention context function
spell :: L # RNN decoder
infer :: C # character distribution inference function
end
@treelike LAS
function LAS(D_in::Integer, D_out::Integer;
D_encoding::Integer,
D_attention::Integer,
D_decoding::Integer)
state = State(D_encoding, D_decoding, D_out)
listen = Encoder(D_in, D_encoding)
attention_ϕ = MLP(D_decoding, D_attention)
attention_ψ = MLP(D_encoding, D_attention)
spell = Decoder(D_encoding + D_decoding + D_out, D_decoding)
infer = CharacterDistribution(D_encoding + D_decoding, D_out)
las = LAS(state, listen, attention_ϕ, attention_ψ, spell, infer) |> gpu
return las
end
function (m::LAS{M})(xs::AbstractVector{<:AbstractMatrix}, maxT::Integer = length(xs))::AbstractArray{<:Real,3} where {M <: AbstractMatrix{<:Real}}
batch_size = size(first(xs), 2)
# compute input encoding
hs = m.listen(xs)
# concatenate sequence of D×N matrices into ssingle D×N×T 3-dimdimensional array
Hs = cat(hs...; dims=3)
# precompute ψ(H)
ψHs = m.attention_ψ.(hs)
# initialize prediction
ŷs = similar(xs, M, maxT)
# compute inital decoder state
O = gpu(zeros(Float32, size(m.state.decoding, 1), batch_size))
m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context]) .+ O
@inbounds for i ∈ eachindex(ŷs)
# compute ϕ(sᵢ)
ϕSᵢᵀ = m.attention_ϕ(m.state.decoding)'
# ϕSᵢᵀ = collect(m.attention_ϕ(m.state.decoding)') # workaround for bug in encountered during training
# compute attention context
Eᵢs = diag.(Ref(ϕSᵢᵀ) .* ψHs)
αᵢs = softmax(vcat(Eᵢs'...))
# compute attention context, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
m.state.context = dropdims(sum(reshape(αᵢs, 1, batch_size, :) .* Hs; dims=3); dims=3)
# predict probability distribution over character alphabet
m.state.prediction = m.infer([m.state.decoding; m.state.context])
ŷs[i] = m.state.prediction
# compute decoder state
m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
end
# concatenate sequence of D×N pprediction matrices into ssingle D×T×N 3-dimdimensional array
Ŷs = cat((ŷ -> reshape(ŷ, size(ŷ,1), 1, :)).(ŷs)...; dims=2)
reset!(m)
return Ŷs
end
function Flux.reset!(m::LAS)
reset!(m.state)
reset!(m.listen)
reset!(m.spell)
return nothing
end
Xs_train = [[rand(Float32, 16,3) for _ ∈ 1:32] for _ ∈ 1:4]
ys_train = [[[rand(1:4) for _ ∈ 1:32] for _ ∈ 1:3] for _ ∈ 1:4]
las = LAS(16, 4; D_encoding=8, D_attention=8, D_decoding=8)
function loss(xs::AbstractVector{<:AbstractMatrix{<:Real}}, ys::AbstractVector{<:AbstractVector{<:Integer}}, maxT::Integer = length(xs))::Real
Ŷs = las(gpu.(xs), maxT)
x, y, z = size(Ŷs)
colsrng = range(0; step=x, length=y)
slicesrng = range(0; step=x*y, length=z)
# true_linindices = vcat([y .+ colsrng[eachindex(y)] .+ slicesrng[n] for (n, y) ∈ enumerate(ys)]...)
true_linindices = mapreduce((n, y) -> y .+ colsrng[eachindex(y)] .+ slicesrng[n], vcat, eachindex(ys), ys)
l = -sum(Ŷs[true_linindices])
return l
end
train!(loss, params(las), zip(Xs_train, ys_train), ADAM())running the last line resulst in:
ERROR: MethodError: Cannot `convert` an object of type Array{Float32,2} to an object of type Adjoint{Float32,Array{Float32,2}}
Closest candidates are:
convert(::Type{Adjoint{T,S}}, ::Adjoint) where {T, S} at /Users/sabae/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.2/LinearAlgebra/src/adjtrans.jl:186
convert(::Type{T<:AbstractArray}, ::T<:AbstractArray) where T<:AbstractArray at abstractarray.jl:14
convert(::Type{T<:AbstractArray}, ::Factorization) where T<:AbstractArray at /Users/sabae/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.2/LinearAlgebra/src/factorization.jl:46
...
Stacktrace:
[1] setproperty!(::Tracker.Tracked{Adjoint{Float32,Array{Float32,2}}}, ::Symbol, ::Array{Float32,2}) at ./Base.jl:21
[2] back(::Tracker.Tracked{Adjoint{Float32,Array{Float32,2}}}, ::Array{Float32,2}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:53
[3] #13 at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
[4] foreach at ./abstractarray.jl:1921 [inlined]
[5] back_(::Tracker.Call{getfield(Tracker, Symbol("##509#510")){TrackedArray{…,Adjoint{Float32,Array{Float32,2}}},TrackedArray{…,Array{Float32,2}}},Tuple{Tracker.Tracked{Adjoint{Float32,Array{Float32,2}}},Tracker.Tracked{Array{Float32,2}}}}, ::Array{Float32,2}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
[6] back(::Tracker.Tracked{Array{Float32,2}}, ::Array{Float32,2}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
... (the last 4 lines are repeated 2 more times)
[15] (::getfield(Tracker, Symbol("##13#14")){Bool})(::Tracker.Tracked{Adjoint{Float32,Array{Float32,1}}}, ::Array{Float32,2}) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
[16] foreach(::Function, ::NTuple{4,Tracker.Tracked{Adjoint{Float32,Array{Float32,1}}}}, ::NTuple{4,Array{Float32,2}}) at ./abstractarray.jl:1921
... (the last 12 lines are repeated 1 more time)
[29] back_(::Tracker.Call{getfield(Tracker, Symbol("#back#548")){2,typeof(*),Tuple{TrackedArray{…,Array{Float32,3}},TrackedArray{…,Array{Float32,3}}}},Tuple{Tracker.Tracked{Array{Float32,3}},Tracker.Tracked{Array{Float32,3}}}}, ::Array{Float32,3}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
[30] back(::Tracker.Tracked{Array{Float32,3}}, ::Array{Float32,3}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
[31] foreach at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
[32] back_(::Tracker.Call{getfield(Tracker, Symbol("##482#483")){TrackedArray{…,Array{Float32,3}}},Tuple{Tracker.Tracked{Array{Float32,3}}}}, ::Array{Float32,3}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
[33] back(::Tracker.Tracked{Array{Float32,3}}, ::Array{Float32,3}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
[34] #13 at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
[35] foreach at ./abstractarray.jl:1921 [inlined]
[36] back_(::Tracker.Call{getfield(Tracker, Symbol("##460#461")){TrackedArray{…,Array{Float32,3}}},Tuple{Tracker.Tracked{Array{Float32,3}},Nothing}}, ::Array{Float32,2}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
[37] back(::Tracker.Tracked{Array{Float32,2}}, ::Array{Float32,2}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
[38] (::getfield(Tracker, Symbol("##13#14")){Bool})(::Tracker.Tracked{Array{Float32,2}}, ::Array{Float32,2}) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
[39] foreach(::Function, ::Tuple{Tracker.Tracked{Array{Float32,2}},Tracker.Tracked{Array{Float32,2}}}, ::Tuple{Array{Float32,2},Array{Float32,2}}) at ./abstractarray.jl:1921
... (the last 8 lines are repeated 1 more time)
[48] back_(::Tracker.Call{getfield(Tracker, Symbol("#back#548")){2,getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##1#3")),getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))}},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))}},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))}},typeof(+)},typeof(identity)},Tuple{TrackedArray{…,Array{Float32,2}},TrackedArray{…,Array{Float32,1}}}},Tuple{Tracker.Tracked{Array{Float32,2}},Tracker.Tracked{Array{Float32,1}}}}, ::Array{Float32,2}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
[49] back(::Tracker.Tracked{Array{Float32,2}}, ::Array{Float32,2}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
[50] foreach at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
[51] back_(::Tracker.Call{getfield(Tracker, Symbol("##513#514")){TrackedArray{…,Array{Float32,2}}},Tuple{Tracker.Tracked{Array{Float32,2}}}}, ::Array{Float32,2}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
[52] back(::Tracker.Tracked{Array{Float32,2}}, ::Array{Float32,2}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
[53] #13 at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
[54] foreach at ./abstractarray.jl:1921 [inlined]
[55] back_(::Tracker.Call{getfield(Tracker, Symbol("##460#461")){TrackedArray{…,Array{Float32,2}}},Tuple{Tracker.Tracked{Array{Float32,2}},Nothing}}, ::Array{Float32,3}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
[56] back(::Tracker.Tracked{Array{Float32,3}}, ::Array{Float32,3}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
[57] (::getfield(Tracker, Symbol("##13#14")){Bool})(::Tracker.Tracked{Array{Float32,3}}, ::Array{Float32,3}) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
[58] foreach(::Function, ::NTuple{32,Tracker.Tracked{Array{Float32,3}}}, ::NTuple{32,Array{Float32,3}}) at ./abstractarray.jl:1921
[59] back_(::Tracker.Call{getfield(Tracker, Symbol("##450#455")){Int64,NTuple{32,TrackedArray{…,Array{Float32,3}}}},NTuple{32,Tracker.Tracked{Array{Float32,3}}}}, ::Array{Float32,3}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
[60] back(::Tracker.Tracked{Array{Float32,3}}, ::Array{Float32,3}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
[61] #13 at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
[62] foreach at ./abstractarray.jl:1921 [inlined]
[63] back_(::Tracker.Call{getfield(Tracker, Symbol("##376#378")){TrackedArray{…,Array{Float32,3}},Tuple{Array{Int64,1}}},Tuple{Tracker.Tracked{Array{Float32,3}},Nothing}}, ::Array{Float32,1}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
[64] back(::Tracker.Tracked{Array{Float32,1}}, ::Array{Float32,1}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
[65] foreach at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
... (the last 3 lines are repeated 1 more time)
[69] back_(::Tracker.Call{getfield(Tracker, Symbol("##199#200")),Tuple{Tracker.Tracked{Float32}}}, ::Float32, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
[70] back(::Tracker.Tracked{Float32}, ::Int64, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
[71] back!(::Tracker.TrackedReal{Float32}) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:77
[72] gradient_(::getfield(Flux.Optimise, Symbol("##14#20")){typeof(loss),Tuple{Array{Array{Float32,2},1},Array{Array{Int64,1},1}}}, ::Tracker.Params) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:4
[73] #gradient#24(::Bool, ::typeof(Tracker.gradient), ::Function, ::Tracker.Params) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:164
[74] gradient at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:164 [inlined]
[75] macro expansion at /Users/Azamat/.julia/packages/Flux/dkJUV/src/optimise/train.jl:71 [inlined]
[76] macro expansion at /Users/Azamat/.julia/packages/Juno/oLB1d/src/progress.jl:119 [inlined]
[77] #train!#12(::getfield(Flux.Optimise, Symbol("##16#22")), ::typeof(train!), ::Function, ::Tracker.Params, ::Base.Iterators.Zip{Tuple{Array{Array{Array{Float32,2},1},1},Array{Array{Array{Int64,1},1},1}}}, ::ADAM) at /Users/Azamat/.julia/packages/Flux/dkJUV/src/optimise/train.jl:69
[78] train!(::Function, ::Tracker.Params, ::Base.Iterators.Zip{Tuple{Array{Array{Array{Float32,2},1},1},Array{Array{Array{Int64,1},1},1}}}, ::ADAM) at /Users/Azamat/.julia/packages/Flux/dkJUV/src/optimise/train.jl:67Metadata
Metadata
Assignees
Labels
No labels