
[Bug] Embedding forward pass breaks for onehotbatch with multiple batch dimensions #2160

Closed
@reachtarunhere

To reproduce:

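using Flux: Embedding, onehotbatch, onehot
emb = Embedding(26, 2)              # 26-symbol vocabulary, 2-dimensional embeddings
x1 = rand('a':'z', 10)              # 10 random characters
x1 = onehotbatch(x1, 'a':'z')       # 26×10 OneHotMatrix
emb(x1)                             # works: returns a 2×10 matrix

# Breaking example below
x2 = rand('a':'z', (3, 10))         # 3×10 matrix of characters
x2 = onehotbatch(x2, 'a':'z')       # 26×3×10 OneHotArray
emb(x2)                             # breaks!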

Error:

ArgumentError: invalid index: false of type Bool

to_index(i::Bool) at indices.jl:293
to_index(A::Matrix{Float32}, i::Bool) at indices.jl:277
to_indices at indices.jl:333 [inlined]
to_indices at multidimensional.jl:823 [inlined]
to_indices at indices.jl:324 [inlined]
view at subarray.jl:176 [inlined]
_view(X::Matrix{Float32}, colons::Tuple{Colon}, k::Bool) at scatter.jl:38
gather!(dst::Matrix{Float32}, src::Matrix{Float32}, idx::Base.ReshapedArray{Bool, 1, Flux.OneHotArray{UInt32, 26, 2, 3, Matrix{UInt32}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}) at gather.jl:27
gather at gather.jl:77 [inlined]
Embedding at basic.jl:533 [inlined]
(::Flux.Embedding{Matrix{Float32}})(x::Flux.OneHotArray{UInt32, 26, 2, 3, Matrix{UInt32}}) at basic.jl:534
top-level scope at mlp.jl:41

The issue is that the second call dispatches to the wrong method:

(m::Embedding)(x::AbstractArray{Bool}) = reshape(m(reshape(x, size(x,1), :)), :, size(x)[2:end]...)

If it had instead dispatched to a method using the same approach as the matrix case (emb.weight * x2), it would work:
(m::Embedding)(x::AbstractMatrix{Bool}) = m.weight * x # usually OneHotMatrix

Shouldn't a OneHotMatrix and a higher-dimensional one-hot array be treated the same way?
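A minimal sketch of how the `m.weight * x` approach could extend to any number of batch dimensions: flatten the trailing dimensions into one, multiply, then restore the batch shape. This is plain Julia with no Flux dependency. `embed_onehot` is a hypothetical helper, not part of Flux's API, and the one-hot input is assumed to be a plain Bool array (a `BitArray` here) rather than Flux's `OneHotArray`.

```julia
# Hypothetical helper (not Flux API): embed a one-hot Bool array of size
# (vocab, batch_dims...) with a single matrix multiply, as in the OneHotMatrix case.
function embed_onehot(W::AbstractMatrix, x::AbstractArray{Bool})
    vocab = size(W, 2)
    size(x, 1) == vocab ||
        throw(DimensionMismatch("one-hot length $(size(x, 1)) does not match vocabulary size $vocab"))
    flat = reshape(x, vocab, :)     # vocab × (product of batch dims)
    out = W * flat                  # embedding_dim × (product of batch dims)
    return reshape(out, size(W, 1), size(x)[2:end]...)  # embedding_dim × batch_dims...
end

# Usage: 2-dimensional embeddings for 26 symbols, batch shape (3, 10).
W = rand(Float32, 2, 26)            # stands in for emb.weight
idx = rand(1:26, 3, 10)             # integer labels, like the characters in x2
x = falses(26, 3, 10)               # Bool one-hot array, 26 × 3 × 10
for I in CartesianIndices(idx)
    x[idx[I], I] = true             # set the one-hot entry for each batch element
end

y = embed_onehot(W, x)
println(size(y))                    # (2, 3, 10)
```

Flattening first means one matrix-multiply path covers every batch shape, just as the 2-D method already does. Whether the same reshape-then-multiply works directly on a `OneHotArray` depends on how `reshape` is defined for that type, which this sketch does not exercise.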

Thanks for reading!
