To reproduce:
using Flux: Embedding, onehotbatch, onehot
emb = Embedding(26, 2)
x1 = rand('a':'z', 10)
x1 = onehotbatch(x1, 'a':'z')
emb(x1) # works perfectly
# Breaking example below
x2 = rand('a':'z', (3, 10))
x2 = onehotbatch(x2, 'a':'z')
emb(x2) # breaks!
Error:
ArgumentError: invalid index: false of type Bool
to_index(i::Bool) at indices.jl:293
to_index(A::Matrix{Float32}, i::Bool) at indices.jl:277
to_indices at indices.jl:333 [inlined]
to_indices at multidimensional.jl:823 [inlined]
to_indices at indices.jl:324 [inlined]
view at subarray.jl:176 [inlined]
_view(X::Matrix{Float32}, colons::Tuple{Colon}, k::Bool) at scatter.jl:38
gather!(dst::Matrix{Float32}, src::Matrix{Float32}, idx::Base.ReshapedArray{Bool, 1, Flux.OneHotArray{UInt32, 26, 2, 3, Matrix{UInt32}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}) at gather.jl:27
gather at gather.jl:77 [inlined]
Embedding at basic.jl:533 [inlined]
(::Flux.Embedding{Matrix{Float32}})(x::Flux.OneHotArray{UInt32, 26, 2, 3, Matrix{UInt32}}) at basic.jl:534
top-level scope at mlp.jl:41
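The problem is that the second call dispatches to the wrong method (line 702 in 7ae67a3). That method flattens the one-hot array into a Bool vector, so gather ends up receiving Bool values as indices, which produces the error above.
If it had instead dispatched to a method that takes the same approach as the one-hot matrix case, emb.weight * x2 (line 701 in 7ae67a3), it would work correctly.
Shouldn't a OneHotMatrix and a higher-dimensional one-hot tensor be treated the same way?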
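The matrix-multiply approach extends to N-D one-hot inputs by flattening all trailing dimensions into columns, multiplying, and reshaping back. Below is a minimal plain-Julia sketch. It assumes the one-hot input is a Bool array with the vocabulary along the first dimension, which is how onehotbatch lays it out. It does not use Flux: W stands in for emb.weight, and embed_onehot is a hypothetical helper, not a Flux function.

```julia
using LinearAlgebra  # matrix multiplication

vocab, dim = 26, 2
W = randn(Float32, dim, vocab)  # stands in for emb.weight (dim × vocab)

# Build a (vocab, 3, 10) one-hot Bool tensor from a 3×10 matrix of random letters.
chars = rand('a':'z', 3, 10)
onehot3 = falses(vocab, size(chars)...)
for I in CartesianIndices(chars)
    onehot3[chars[I] - 'a' + 1, I] = true
end

# Hypothetical helper: flatten trailing dims to (vocab, N), multiply by W,
# then restore the trailing dims, giving (dim, 3, 10) here.
embed_onehot(W, x) = reshape(W * reshape(x, size(x, 1), :), size(W, 1), size(x)[2:end]...)

out = embed_onehot(W, onehot3)

# Reference: direct column lookup by integer index should give the same values.
idx = chars .- 'a' .+ 1
ref = reshape(W[:, vec(idx)], dim, size(idx)...)
```

The final lines compare the result with a plain column lookup by integer index. That lookup is what an embedding layer computes, so the two should agree.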
Thanks for reading!