Skip to content

Strange bug with Adjoint #866

@AzamatB

Description

@AzamatB

I'm implementing Listen, Attend and Spell: arxiv.org/abs/1508.01211 work in Flux and hitting a bug related to the handling of Adjoint. I tried really hard, but unfortunately, I was not able to reduce my minimum example to a few lines. I made it reproducible nevertheless.

Furthermore, if you replace ϕSᵢᵀ = m.attention_ϕ(m.state.decoding)' with ϕSᵢᵀ = permutedims(m.attention_ϕ(m.state.decoding)) in the code below, the error goes away.

using Flux
using Flux: flip, reset!, onecold, throttle, train!, @treelike, @epochs
using IterTools
using LinearAlgebra
import Base.Iterators

# Bidirectional LSTM
# A bidirectional LSTM layer: a forward recurrent pass and a backward recurrent
# pass over the same sequence, whose per-step hidden states are concatenated and
# projected by a shared Dense layer. `L` is the recurrent-layer type, `D` the
# dense-layer type.
struct BLSTM{L,D}
   forward  :: L
   backward :: L
   dense    :: D
end

# register the fields as trainable parameters with Flux (pre-Functors API)
@treelike BLSTM

"""
   BLSTM(in::Integer, hidden::Integer, out::Integer, σ = identity)

Construct a bidirectional LSTM with `in` input features, `hidden` units per
direction, and a Dense projection from the `2hidden` concatenated states to
`out` features with activation `σ`.
"""
function BLSTM(in::Integer, hidden::Integer, out::Integer, σ = identity)
   return BLSTM(LSTM(in, hidden), LSTM(in, hidden), Dense(2hidden, out, σ))
end

# Two-argument form: the hidden size per direction defaults to the
# rounded-up mean of the input and output dimensions.
function BLSTM(D_in::Integer, D_out::Integer)
   return BLSTM(D_in, ceil(Int, (D_in + D_out)/2), D_out)
end

# Apply the BLSTM to a sequence `xs` of per-time-step feature matrices (or
# vectors): run the forward LSTM over `xs` in order, the backward LSTM over the
# reversed sequence via `flip`, vcat the two hidden states at each step, and
# project each with the dense layer. The in-order broadcast matters: the
# recurrent layers carry state across calls.
(m::BLSTM)(xs::AbstractVector{<:AbstractVecOrMat})::AbstractVector{<:AbstractVecOrMat} = m.dense.(vcat.(m.forward.(xs), flip(m.backward, xs)))

# Flux.reset!(m::BLSTM) = reset!((m.forward, m.backward)) # not needed as taken care of by @treelike

"""
   restack(xs)

Pairwise-concatenate consecutive elements of the sequence `xs`: element `2k-1`
is stacked on top of element `2k` with `vcat`, halving the sequence length and
doubling the feature dimension. Errors if `length(xs)` is odd, because the two
stride-2 slices then differ in length.
"""
function restack(xs::VV)::VV where VV <: AbstractVector{<:AbstractVecOrMat}
   odds  = xs[1:2:end]
   evens = xs[2:2:end]
   return vcat.(odds, evens)
end

"""
   PBLSTM(D_in::Integer, D_out::Integer)

Pyramidal BLSTM is the same as BLSTM, with the addition that the outputs of
BLSTM are concatenated at consecutive steps (via `restack`), halving the time
resolution.

Throws an `ArgumentError` if `D_out` is odd, since the restacking doubles the
BLSTM's feature dimension.
"""
function PBLSTM(D_in::Integer, D_out::Integer)
   # restack doubles the feature dimension, so the inner BLSTM emits D_out ÷ 2
   iseven(D_out) || throw(ArgumentError("output dimension of the pyramidal BLSTM layer must be even; got $D_out"))
   return Chain(BLSTM(D_in, D_out ÷ 2), restack)
end

"""
   Encoder(layer_sizes)
   Encoder(D_in::Integer, D_out::Integer; nlayers::Integer = 4)
   Encoder(D_in::Integer, D_out::Integer, hidden_sizes)

Encoder that consists of a block of PBLSTMs followed by a final BLSTM. It
accepts filter bank spectra as inputs and acts as acoustic model encoder.

Throws an `ArgumentError` when fewer than 3 sizes are given (n sizes yield
n-1 layers, and at least 2 layers are required).
"""
function Encoder(layer_sizes)
   length(layer_sizes) ≥ 3 || throw(ArgumentError("number of layers of Encoder must be ≥ 2"))
   # consecutive (D_in, D_out) pairs, e.g. (a,b,c) -> ((a,b), (b,c))
   layer_dims = Tuple(partition(layer_sizes, 2, 1))
   # all but the last layer are pyramidal; the last is a plain BLSTM
   layers = ( PBLSTM(D_in, D_out) for (D_in, D_out) ∈ layer_dims[1:end-1] )
   model = Chain(layers..., BLSTM(layer_dims[end]...))
   return model
end

"""
   Encoder(D_in::Integer, D_out::Integer; nlayers::Integer = 4)

Build an encoder whose intermediate layer sizes are linearly interpolated
between `D_in` and `D_out`, each rounded up to the nearest even integer
(required by the pyramidal layers); the endpoints are forced back to the
exact `D_in` and `D_out`.
"""
function Encoder(D_in::Integer, D_out::Integer; nlayers::Integer = 4)
   # 2 * cld(n, 2) rounds an integer up to the nearest even value
   sizes = [2 * cld(ceil(Int, x), 2) for x ∈ range(D_in, D_out; length=nlayers+1)]
   sizes[1]   = D_in
   sizes[end] = D_out
   return Encoder(sizes)
end

# Convenience constructor: explicit hidden layer sizes between D_in and D_out.
Encoder(D_in::Integer, D_out::Integer, hidden_sizes) = Encoder((D_in, hidden_sizes..., D_out))


"""
   MLP(layer_sizes, σs)

Multilayer perceptron: one `Dense` layer per consecutive pair of sizes in
`layer_sizes`, with the activation of the k-th layer taken from `σs[k]`.
A single layer is returned bare instead of wrapped in a `Chain`.
"""
function MLP(layer_sizes, σs)
   layers = Tuple(Dense(D_in, D_out, σ) for ((D_in, D_out), σ) ∈ zip(partition(layer_sizes, 2, 1), σs))
   # skip the Chain wrapper when there is only one layer
   model = length(layers) == 1 ? first(layers) : Chain(layers...)
   return model
end

"""
   MLP(layer_sizes, σ::Function)

Build an MLP over `layer_sizes` that uses the same activation `σ` for every
layer.
"""
function MLP(layer_sizes, σ::Function)
   return MLP(layer_sizes, ntuple(_ -> σ, length(layer_sizes)))
end

"""
   MLP(D_in::Integer, D_out::Integer, σs)

Build an MLP with `length(σs)` layers whose sizes are linearly interpolated
(rounded up) between `D_in` and `D_out`.
"""
function MLP(D_in::Integer, D_out::Integer, σs)
   sizes = ceil.(Int, range(D_in, D_out; length=length(σs)+1))
   return MLP(sizes, σs)
end

"""
   MLP(D_in::Integer, D_out::Integer, σ::Function=identity; nlayers::Integer = 1)

Build an `nlayers`-deep MLP from `D_in` to `D_out` with activation `σ` on
every layer.
"""
function MLP(D_in::Integer, D_out::Integer, σ::Function=identity; nlayers::Integer = 1)
   return MLP(D_in, D_out, ntuple(_ -> σ, nlayers))
end


"""
   Decoder(layer_sizes)

Stack of LSTM layers, one per consecutive pair of sizes in `layer_sizes`,
chained together. Acts as the speller RNN of the LAS model.
"""
function Decoder(layer_sizes)
   layers = ( LSTM(D_in, D_out) for (D_in, D_out) ∈ partition(layer_sizes, 2, 1) )
   model = Chain(layers...)
   return model
end

"""
   Decoder(D_in::Integer, D_out::Integer; nlayers::Integer = 2)

Build an `nlayers`-deep LSTM decoder whose sizes are linearly interpolated
(rounded up) between `D_in` and `D_out`.
"""
function Decoder(D_in::Integer, D_out::Integer; nlayers::Integer = 2)
   return Decoder(ceil.(Int, range(D_in, D_out; length=nlayers+1)))
end

# Convenience constructor: explicit hidden layer sizes between D_in and D_out.
Decoder(D_in::Integer, D_out::Integer, hidden_sizes) = Decoder((D_in, hidden_sizes..., D_out))


"""
   CharacterDistribution(D_in::Integer, D_out::Integer, σ::Function; nlayers::Integer, applylog::Bool=true)

MLP mapping a `D_in`-dimensional decoder/context vector to a distribution over
a `D_out`-letter alphabet: `nlayers - 1` hidden `Dense` layers with activation
`σ`, a final linear `Dense` layer, then `logsoftmax` (default) or `softmax`.
"""
function CharacterDistribution(D_in::Integer, D_out::Integer, σ::Function; nlayers::Integer, applylog::Bool=true)
   f = applylog ? logsoftmax : softmax
   layer_sizes = ceil.(Int, range(D_in, D_out; length=nlayers+1))
   layer_dims = Tuple(partition(layer_sizes, 2, 1))
   # hidden layers use σ; the last layer is linear before the (log)softmax
   layers = ( Dense(D_in, D_out, σ) for (D_in, D_out) ∈ layer_dims[1:end-1] )
   return Chain(layers..., Dense(layer_dims[end]...), f)
end

# Single-layer variant: one linear projection followed by (log)softmax.
CharacterDistribution(D_in::Integer, D_out::Integer; applylog::Bool=true) = Chain(Dense(D_in, D_out), applylog ? logsoftmax : softmax)


# Mutable running state of the LAS model. All six fields share the single
# matrix type `M`, so assigning a value of a different concrete type (e.g. an
# `Adjoint`-typed tracked array) forces a `convert` — NOTE(review): this typing
# looks related to the Adjoint conversion error reported below; confirm.
mutable struct State{M <: AbstractMatrix{<:Real}}
   context     :: M   # last attention context
   decoding    :: M   # last decoder state
   prediction  :: M   # last prediction
   # reset values (trainable initial states)
   context₀    :: M
   decoding₀   :: M
   prediction₀ :: M
end

# register fields as trainable parameters with Flux (pre-Functors API)
@treelike State

"""
   State(D_c::Integer, D_d::Integer, D_p::Integer)

Construct an initial `State` with trainable zero column vectors of sizes
`D_c` (context), `D_d` (decoding), and `D_p` (prediction). The current slots
start out aliased to the reset slots.
"""
function State(D_c::Integer, D_d::Integer, D_p::Integer)
   c₀ = param(zeros(Float32, D_c, 1))
   d₀ = param(zeros(Float32, D_d, 1))
   p₀ = param(zeros(Float32, D_p, 1))
   return State(c₀, d₀, p₀, c₀, d₀, p₀)
end

# Restore the running state slots to their trainable initial values.
# Note: this assigns the reset fields themselves (no copy); the slots alias
# the parameters until the next forward-pass assignment replaces them.
function Flux.reset!(s::State)
   s.context    = s.context₀
   s.decoding   = s.decoding₀
   s.prediction = s.prediction₀
   return nothing
end


# Listen, Attend and Spell model (Chan et al., arXiv:1508.01211).
struct LAS{V, E, Dϕ, Dψ, L, C}
   state       :: State{V} # current state of the model
   listen      :: E   # encoder function
   attention_ϕ :: Dϕ  # attention context function (applied to the decoder state)
   attention_ψ :: Dψ  # attention context function (applied to the encodings)
   spell       :: L   # RNN decoder
   infer       :: C   # character distribution inference function
end

# register fields as trainable parameters with Flux (pre-Functors API)
@treelike LAS

"""
   LAS(D_in::Integer, D_out::Integer; D_encoding, D_attention, D_decoding)

Assemble a full LAS model: encoder (`listen`), the two attention projections
(`attention_ϕ` for the decoder state, `attention_ψ` for the encodings), the
speller RNN, and the character-distribution head. The assembled model is
moved to the GPU when one is available.
"""
function LAS(D_in::Integer, D_out::Integer;
             D_encoding::Integer,
             D_attention::Integer,
             D_decoding::Integer)
   model = LAS(
      State(D_encoding, D_decoding, D_out),
      Encoder(D_in, D_encoding),
      MLP(D_decoding, D_attention),
      MLP(D_encoding, D_attention),
      Decoder(D_encoding + D_decoding + D_out, D_decoding),
      CharacterDistribution(D_encoding + D_decoding, D_out)
   )
   return gpu(model)
end

"""
    (m::LAS)(xs, maxT = length(xs))

Run the listen–attend–spell forward pass over the input sequence `xs` (a
vector of per-time-step feature matrices, batch in columns) and return a
3-dimensional array of per-step output distributions, indexed as
(alphabet, time, batch). Resets the model state before returning.
"""
function (m::LAS{M})(xs::AbstractVector{<:AbstractMatrix}, maxT::Integer = length(xs))::AbstractArray{<:Real,3} where {M <: AbstractMatrix{<:Real}}
   batch_size = size(first(xs), 2)
   # compute input encoding
   hs = m.listen(xs)
   # concatenate sequence of D×N matrices into single D×N×T 3-dimensional array
   Hs = cat(hs...; dims=3)
   # precompute ψ(H)
   ψHs = m.attention_ψ.(hs)
   # initialize prediction
   ŷs = similar(xs, M, maxT)
   # compute initial decoder state; adding the zeros matrix O broadcasts the
   # singleton-column initial state up to the batch size
   O = gpu(zeros(Float32, size(m.state.decoding, 1), batch_size))
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context]) .+ O
   @inbounds for i ∈ eachindex(ŷs)
      # compute ϕ(sᵢ)
      ϕSᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # ϕSᵢᵀ = collect(m.attention_ϕ(m.state.decoding)') # workaround for bug encountered during training
      # compute attention energies
      Eᵢs = diag.(Ref(ϕSᵢᵀ) .* ψHs)
      αᵢs = softmax(vcat(Eᵢs'...))
      # compute attention context, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      m.state.context = dropdims(sum(reshape(αᵢs, 1, batch_size, :) .* Hs; dims=3); dims=3)
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      ŷs[i] = m.state.prediction
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
   end
   # concatenate sequence of D×N prediction matrices into single D×T×N 3-dimensional array
   Ŷs = cat((ŷ -> reshape(ŷ, size(ŷ,1), 1, :)).(ŷs)...; dims=2)
   reset!(m)
   return Ŷs
end


# Reset the LAS running state and the recurrent layers of the encoder and
# speller to their initial values.
function Flux.reset!(m::LAS)
   reset!(m.state)
   reset!(m.listen)
   reset!(m.spell)
   return nothing
end


# Synthetic training data: 4 samples. Each input is a length-32 sequence of
# 16×3 Float32 matrices (16 features, batch of 3); each target is 3 label
# sequences of length 32 drawn from a 4-letter alphabet.
Xs_train = [[rand(Float32, 16,3) for _ ∈ 1:32] for _ ∈ 1:4]
ys_train = [[[rand(1:4) for _ ∈ 1:32] for _ ∈ 1:3] for _ ∈ 1:4]

# instantiate a small LAS model matching the synthetic data dimensions
las = LAS(16, 4; D_encoding=8, D_attention=8, D_decoding=8)

"""
    loss(xs, ys, maxT = length(xs))

Negative log-likelihood of the true label sequences `ys` under the model's
per-step output distributions for input `xs`. Labels are gathered from the
(alphabet, time, batch) output array via linear indices.
"""
function loss(xs::AbstractVector{<:AbstractMatrix{<:Real}}, ys::AbstractVector{<:AbstractVector{<:Integer}}, maxT::Integer = length(xs))::Real
   Ŷs = las(gpu.(xs), maxT)
   D, T, N = size(Ŷs)
   # linear-index offsets of each column (time step) and each slice (batch item)
   coloffsets   = range(0; step=D,   length=T)
   sliceoffsets = range(0; step=D*T, length=N)
   # linear indices into Ŷs of the true labels, batch item by batch item
   true_linindices = vcat([y .+ coloffsets[eachindex(y)] .+ sliceoffsets[n] for (n, y) ∈ enumerate(ys)]...)
   return -sum(Ŷs[true_linindices])
end

# one training pass over the (xs, ys) pairs — this call triggers the reported error
train!(loss, params(las), zip(Xs_train, ys_train), ADAM())

Running the last line results in:

ERROR: MethodError: Cannot `convert` an object of type Array{Float32,2} to an object of type Adjoint{Float32,Array{Float32,2}}
Closest candidates are:
  convert(::Type{Adjoint{T,S}}, ::Adjoint) where {T, S} at /Users/sabae/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.2/LinearAlgebra/src/adjtrans.jl:186
  convert(::Type{T<:AbstractArray}, ::T<:AbstractArray) where T<:AbstractArray at abstractarray.jl:14
  convert(::Type{T<:AbstractArray}, ::Factorization) where T<:AbstractArray at /Users/sabae/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.2/LinearAlgebra/src/factorization.jl:46
  ...
Stacktrace:
 [1] setproperty!(::Tracker.Tracked{Adjoint{Float32,Array{Float32,2}}}, ::Symbol, ::Array{Float32,2}) at ./Base.jl:21
 [2] back(::Tracker.Tracked{Adjoint{Float32,Array{Float32,2}}}, ::Array{Float32,2}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:53
 [3] #13 at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
 [4] foreach at ./abstractarray.jl:1921 [inlined]
 [5] back_(::Tracker.Call{getfield(Tracker, Symbol("##509#510")){TrackedArray{,Adjoint{Float32,Array{Float32,2}}},TrackedArray{,Array{Float32,2}}},Tuple{Tracker.Tracked{Adjoint{Float32,Array{Float32,2}}},Tracker.Tracked{Array{Float32,2}}}}, ::Array{Float32,2}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
 [6] back(::Tracker.Tracked{Array{Float32,2}}, ::Array{Float32,2}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
 ... (the last 4 lines are repeated 2 more times)
 [15] (::getfield(Tracker, Symbol("##13#14")){Bool})(::Tracker.Tracked{Adjoint{Float32,Array{Float32,1}}}, ::Array{Float32,2}) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
 [16] foreach(::Function, ::NTuple{4,Tracker.Tracked{Adjoint{Float32,Array{Float32,1}}}}, ::NTuple{4,Array{Float32,2}}) at ./abstractarray.jl:1921
 ... (the last 12 lines are repeated 1 more time)
 [29] back_(::Tracker.Call{getfield(Tracker, Symbol("#back#548")){2,typeof(*),Tuple{TrackedArray{,Array{Float32,3}},TrackedArray{,Array{Float32,3}}}},Tuple{Tracker.Tracked{Array{Float32,3}},Tracker.Tracked{Array{Float32,3}}}}, ::Array{Float32,3}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
 [30] back(::Tracker.Tracked{Array{Float32,3}}, ::Array{Float32,3}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
 [31] foreach at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
 [32] back_(::Tracker.Call{getfield(Tracker, Symbol("##482#483")){TrackedArray{,Array{Float32,3}}},Tuple{Tracker.Tracked{Array{Float32,3}}}}, ::Array{Float32,3}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
 [33] back(::Tracker.Tracked{Array{Float32,3}}, ::Array{Float32,3}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
 [34] #13 at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
 [35] foreach at ./abstractarray.jl:1921 [inlined]
 [36] back_(::Tracker.Call{getfield(Tracker, Symbol("##460#461")){TrackedArray{,Array{Float32,3}}},Tuple{Tracker.Tracked{Array{Float32,3}},Nothing}}, ::Array{Float32,2}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
 [37] back(::Tracker.Tracked{Array{Float32,2}}, ::Array{Float32,2}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
 [38] (::getfield(Tracker, Symbol("##13#14")){Bool})(::Tracker.Tracked{Array{Float32,2}}, ::Array{Float32,2}) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
 [39] foreach(::Function, ::Tuple{Tracker.Tracked{Array{Float32,2}},Tracker.Tracked{Array{Float32,2}}}, ::Tuple{Array{Float32,2},Array{Float32,2}}) at ./abstractarray.jl:1921
 ... (the last 8 lines are repeated 1 more time)
 [48] back_(::Tracker.Call{getfield(Tracker, Symbol("#back#548")){2,getfield(Base.Broadcast, Symbol("##2#4")){getfield(Base.Broadcast, Symbol("##8#10")){getfield(Base.Broadcast, Symbol("##1#3")),getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##5#6")){getfield(Base.Broadcast, Symbol("##7#9"))}},getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##11#12")){getfield(Base.Broadcast, Symbol("##13#14"))}},getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##15#16")){getfield(Base.Broadcast, Symbol("##17#18"))}},typeof(+)},typeof(identity)},Tuple{TrackedArray{,Array{Float32,2}},TrackedArray{,Array{Float32,1}}}},Tuple{Tracker.Tracked{Array{Float32,2}},Tracker.Tracked{Array{Float32,1}}}}, ::Array{Float32,2}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
 [49] back(::Tracker.Tracked{Array{Float32,2}}, ::Array{Float32,2}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
 [50] foreach at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
 [51] back_(::Tracker.Call{getfield(Tracker, Symbol("##513#514")){TrackedArray{,Array{Float32,2}}},Tuple{Tracker.Tracked{Array{Float32,2}}}}, ::Array{Float32,2}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
 [52] back(::Tracker.Tracked{Array{Float32,2}}, ::Array{Float32,2}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
 [53] #13 at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
 [54] foreach at ./abstractarray.jl:1921 [inlined]
 [55] back_(::Tracker.Call{getfield(Tracker, Symbol("##460#461")){TrackedArray{,Array{Float32,2}}},Tuple{Tracker.Tracked{Array{Float32,2}},Nothing}}, ::Array{Float32,3}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
 [56] back(::Tracker.Tracked{Array{Float32,3}}, ::Array{Float32,3}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
 [57] (::getfield(Tracker, Symbol("##13#14")){Bool})(::Tracker.Tracked{Array{Float32,3}}, ::Array{Float32,3}) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
 [58] foreach(::Function, ::NTuple{32,Tracker.Tracked{Array{Float32,3}}}, ::NTuple{32,Array{Float32,3}}) at ./abstractarray.jl:1921
 [59] back_(::Tracker.Call{getfield(Tracker, Symbol("##450#455")){Int64,NTuple{32,TrackedArray{,Array{Float32,3}}}},NTuple{32,Tracker.Tracked{Array{Float32,3}}}}, ::Array{Float32,3}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
 [60] back(::Tracker.Tracked{Array{Float32,3}}, ::Array{Float32,3}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
 [61] #13 at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
 [62] foreach at ./abstractarray.jl:1921 [inlined]
 [63] back_(::Tracker.Call{getfield(Tracker, Symbol("##376#378")){TrackedArray{,Array{Float32,3}},Tuple{Array{Int64,1}}},Tuple{Tracker.Tracked{Array{Float32,3}},Nothing}}, ::Array{Float32,1}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
 [64] back(::Tracker.Tracked{Array{Float32,1}}, ::Array{Float32,1}, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
 [65] foreach at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38 [inlined]
 ... (the last 3 lines are repeated 1 more time)
 [69] back_(::Tracker.Call{getfield(Tracker, Symbol("##199#200")),Tuple{Tracker.Tracked{Float32}}}, ::Float32, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:38
 [70] back(::Tracker.Tracked{Float32}, ::Int64, ::Bool) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:58
 [71] back!(::Tracker.TrackedReal{Float32}) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:77
 [72] gradient_(::getfield(Flux.Optimise, Symbol("##14#20")){typeof(loss),Tuple{Array{Array{Float32,2},1},Array{Array{Int64,1},1}}}, ::Tracker.Params) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:4
 [73] #gradient#24(::Bool, ::typeof(Tracker.gradient), ::Function, ::Tracker.Params) at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:164
 [74] gradient at /Users/Azamat/.julia/packages/Tracker/SAr25/src/back.jl:164 [inlined]
 [75] macro expansion at /Users/Azamat/.julia/packages/Flux/dkJUV/src/optimise/train.jl:71 [inlined]
 [76] macro expansion at /Users/Azamat/.julia/packages/Juno/oLB1d/src/progress.jl:119 [inlined]
 [77] #train!#12(::getfield(Flux.Optimise, Symbol("##16#22")), ::typeof(train!), ::Function, ::Tracker.Params, ::Base.Iterators.Zip{Tuple{Array{Array{Array{Float32,2},1},1},Array{Array{Array{Int64,1},1},1}}}, ::ADAM) at /Users/Azamat/.julia/packages/Flux/dkJUV/src/optimise/train.jl:69
 [78] train!(::Function, ::Tracker.Params, ::Base.Iterators.Zip{Tuple{Array{Array{Array{Float32,2},1},1},Array{Array{Array{Int64,1},1},1}}}, ::ADAM) at /Users/Azamat/.julia/packages/Flux/dkJUV/src/optimise/train.jl:67

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions