Closed
Labels: broadcast (Applying a function over a collection), collections (Data structures holding multiple items, e.g. sets)
Description
Consider a 2D generator:

```julia
julia> g = (1 for x in 1:2, y in 2:3)
Base.Generator{Base.Iterators.ProductIterator{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, var"#47#48"}(var"#47#48"(), Base.Iterators.ProductIterator{Tuple{UnitRange{Int64}, UnitRange{Int64}}}((1:2, 2:3)))

julia> size(g)
(2, 2)

julia> collect(g)
2×2 Matrix{Int64}:
 1  1
 1  1
```
All is well.
But for a `Broadcasted` object the shape is not preserved, and `collect` gives back something flat:

```julia
julia> h = Base.broadcasted(sqrt, [1 2; 3 4])
Base.Broadcast.Broadcasted(sqrt, ([1 2; 3 4],))

julia> size(h)
(2, 2)

julia> collect(h)
4-element Vector{Float64}:
 1.0
 1.7320508075688772
 1.4142135623730951
 2.0
```
This can be fixed by materializing first:

```julia
julia> collect(Base.materialize(h))
2×2 Matrix{Float64}:
 1.0      1.41421
 1.73205  2.0
```
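Materializing sidesteps the problem because, in Base, `materialize(bc)` is essentially `copy(Base.Broadcast.instantiate(bc))`, and `instantiate` is the step that computes the axes. A minimal sketch, assuming only those two Base functions, suggests that collecting the instantiated object already keeps the shape without first building the materialized copy:

```julia
A = [1 2; 3 4]
h = Base.broadcasted(sqrt, A)       # lazy Broadcasted; axes not computed yet

# instantiate fills in the axes from the arguments; no sqrt is evaluated yet.
hi = Base.Broadcast.instantiate(h)

println(h.axes)                     # nothing
println(hi.axes)                    # a tuple of two Base.OneTo ranges
println(Base.IteratorSize(hi))      # the HasShape trait, since axes are now OneTo

# collect now allocates a 2×2 result, matching the materialized one.
M = collect(hi)
println(size(M))                    # (2, 2)
println(M == Base.materialize(h))   # true
```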
But why does it not work in the first place?
I suspect it is because the `IteratorSize` trait isn't set to `HasShape`.
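For context, `collect` picks the shape of its result from the `IteratorSize` trait: with `HasLength()` it only knows a length and allocates a `Vector`, while with `HasShape{N}()` it asks the iterator for its `axes` (derived from `size` by default). A minimal sketch with two hypothetical wrapper types (`FlatWrap` and `ShapedWrap`, not part of Base) that differ only in that trait:

```julia
# Hypothetical wrappers around a Matrix{Int}, used only for illustration.
struct FlatWrap
    data::Matrix{Int}
end

struct ShapedWrap
    data::Matrix{Int}
end

# Same iteration protocol for both: delegate to the wrapped matrix,
# which iterates in column-major order.
Base.iterate(w::FlatWrap) = iterate(w.data)
Base.iterate(w::FlatWrap, state) = iterate(w.data, state)
Base.length(w::FlatWrap) = length(w.data)
Base.eltype(::Type{FlatWrap}) = Int

Base.iterate(w::ShapedWrap) = iterate(w.data)
Base.iterate(w::ShapedWrap, state) = iterate(w.data, state)
Base.length(w::ShapedWrap) = length(w.data)
Base.size(w::ShapedWrap) = size(w.data)
Base.eltype(::Type{ShapedWrap}) = Int

# Only ShapedWrap opts into HasShape; FlatWrap keeps the HasLength() fallback.
Base.IteratorSize(::Type{ShapedWrap}) = Base.HasShape{2}()

A = [1 2; 3 4]
flat = collect(FlatWrap(A))       # Vector{Int}: [1, 3, 2, 4]
shaped = collect(ShapedWrap(A))   # 2×2 Matrix{Int}, equal to A
```

This mirrors what happens with `h`: the data and iteration order are identical, and only the trait decides whether the result keeps its shape.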
But why isn't it?
```julia
julia> Base.IteratorSize(g)
Base.HasShape{2}()

julia> Base.IteratorSize(h)
Base.HasLength()
```
It seems to just be hitting the default fallback definition of `IteratorSize` rather than this method (line 264 in 591f066):

```julia
Base.IteratorSize(::Type{<:Broadcasted{<:Any,<:NTuple{N,Base.OneTo}}}) where {N} = Base.HasShape{N}()
```
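We do have:

```julia
julia> typeof(h)
Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(sqrt), Tuple{Matrix{Int64}}}
```

The second type parameter, which holds the axes, is `Nothing`: `Base.broadcasted` builds a lazy object whose axes are only computed later (by `Base.Broadcast.instantiate`), so the `NTuple{N,Base.OneTo}` method never matches. The first parameter, `Base.Broadcast.DefaultArrayStyle{2}`, already tells us what we need to know to define `HasShape`.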
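To check that the style alone carries the dimensionality, here is a small sketch that reads `N` back out of the `Broadcasted` type by dispatch, while the axes are still `nothing`. The helper `style_ndims` is hypothetical (not a Base function) and deliberately leaves `IteratorSize` alone, so it adds no methods to Base:

```julia
using Base.Broadcast: Broadcasted, AbstractArrayStyle

# Hypothetical helper: extract N from a Broadcasted whose style is an
# AbstractArrayStyle{N}. Styles outside that hierarchy give a MethodError.
style_ndims(::Type{<:Broadcasted{<:AbstractArrayStyle{N}}}) where {N} = N
style_ndims(bc::Broadcasted) = style_ndims(typeof(bc))

h = Base.broadcasted(sqrt, [1 2; 3 4])   # matrix argument -> DefaultArrayStyle{2}
v1 = Base.broadcasted(+, [1, 2], 3)      # vector and scalar -> DefaultArrayStyle{1}

style_ndims(h)    # 2
style_ndims(v1)   # 1
h.axes            # nothing: the lazy object has not been instantiated
```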
It seems like we could define:

```julia
IteratorSize(::Type{<:Broadcasted{<:AbstractArrayStyle{N}}}) where {N} = HasShape{N}()
```
And indeed that does seem to work:

```julia
julia> Base.IteratorSize(::Type{<:Base.Broadcast.Broadcasted{<:Base.Broadcast.AbstractArrayStyle{N}}}) where {N} = Base.HasShape{N}()

julia> collect(h)
2×2 Matrix{Float64}:
 1.0      1.41421
 1.73205  2.0
```