
collect doesn't preserve shape on Broadcasted objects #43847

@oxinabox

Consider a 2D generator:

julia> g = (1 for x in 1:2, y in 2:3)
Base.Generator{Base.Iterators.ProductIterator{Tuple{UnitRange{Int64}, UnitRange{Int64}}}, var"#47#48"}(var"#47#48"(), Base.Iterators.ProductIterator{Tuple{UnitRange{Int64}, UnitRange{Int64}}}((1:2, 2:3)))

julia> size(g)
(2, 2)

julia> collect(g)
2×2 Matrix{Int64}:
 1  1
 1  1

All is well.
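For reference, collect chooses its output container from the IteratorSize trait. The sketch below uses two hypothetical iterators, Grid and FlatGrid (names invented for illustration), that yield the same numbers. Only the one that declares HasShape{2} and defines size comes back as a Matrix; the one that only declares HasLength comes back as a flat Vector, which mirrors the generator vs. Broadcasted contrast in this issue.

```julia
# Two hypothetical iterators yielding 1, 2, ..., nrows*ncols.
struct Grid        # declares a 2D shape
    nrows::Int
    ncols::Int
end

struct FlatGrid    # same data, but only declares a length
    nrows::Int
    ncols::Int
end

const AnyGrid = Union{Grid, FlatGrid}

Base.length(g::AnyGrid) = g.nrows * g.ncols
Base.eltype(::Type{<:AnyGrid}) = Int

# Yield the running index until the length is exhausted.
function Base.iterate(g::AnyGrid, i::Int = 1)
    i > length(g) && return nothing
    return (i, i + 1)
end

# HasShape{2} plus size tells collect to allocate a 2D array.
Base.IteratorSize(::Type{Grid}) = Base.HasShape{2}()
Base.size(g::Grid) = (g.nrows, g.ncols)

# HasLength (also the default) makes collect allocate a Vector.
Base.IteratorSize(::Type{FlatGrid}) = Base.HasLength()

println(size(collect(Grid(2, 3))))      # (2, 3)
println(size(collect(FlatGrid(2, 3))))  # (6,)
```

The Matrix is filled in column-major order, so collect(Grid(2, 3)) is [1 3 5; 2 4 6], just as collect(h) in the Vector case below lists the square roots column by column.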

But for a Broadcasted object, the shape is not preserved: collect returns a flat Vector.

julia> h = Base.broadcasted(sqrt, [1 2; 3 4])
Base.Broadcast.Broadcasted(sqrt, ([1 2; 3 4],))

julia> size(h)
(2, 2)

julia> collect(h)
4-element Vector{Float64}:
 1.0
 1.7320508075688772
 1.4142135623730951
 2.0

This can be worked around by materializing first:

julia> collect(Base.materialize(h))
2×2 Matrix{Float64}:
 1.0      1.41421
 1.73205  2.0
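Note that materialize on its own already returns a Matrix with the input's shape, so the outer collect in the workaround only makes a copy. A minimal sketch of that, using nothing beyond Base.Broadcast.broadcasted and Base.Broadcast.materialize:

```julia
using Base.Broadcast: broadcasted, materialize

A = [1 2; 3 4]
h = broadcasted(sqrt, A)

# materialize already returns a 2×2 Matrix{Float64};
# wrapping it in collect would only copy it.
m = materialize(h)
println(typeof(m), " ", size(m))

# sqrt.(A) lowers to materialize(broadcasted(sqrt, A)), so the results agree.
println(m == sqrt.(A))
```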

But why does it not work in the first place?
I suspect it is because the IteratorSize trait isn't HasShape for Broadcasted objects.
But why isn't it?

julia> Base.IteratorSize(g)
Base.HasShape{2}()

julia> Base.IteratorSize(h)
Base.HasLength()

It seems to just be hitting the default fallback definition of IteratorSize,
rather than the more specific method:

Base.IteratorSize(::Type{<:Broadcasted{<:Any,<:NTuple{N,Base.OneTo}}}) where {N} = Base.HasShape{N}()
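That method only matches once the axes slot of the Broadcasted type holds a tuple of Base.OneTo, which happens after Base.Broadcast.instantiate. The sketch below shows this for h. Since the exact IteratorSize methods in Base may differ between Julia versions, it only asserts the behaviour of the instantiated object.

```julia
using Base.Broadcast: broadcasted, instantiate

h = broadcasted(sqrt, [1 2; 3 4])

# broadcasted leaves the axes slot empty (the `Nothing` type parameter).
println(h.axes === nothing)

# instantiate computes the axes, here (Base.OneTo(2), Base.OneTo(2)),
# so the type now matches Broadcasted{<:Any,<:NTuple{N,Base.OneTo}}.
hi = instantiate(h)
println(hi.axes)
println(Base.IteratorSize(hi))

# With HasShape{2}, collect allocates a 2×2 Matrix.
println(size(collect(hi)))
```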

We do have:

julia> typeof(h)
Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{2}, Nothing, typeof(sqrt), Tuple{Matrix{Int64}}}

And that Base.Broadcast.DefaultArrayStyle{2} tells us what we need to know to define HasShape{2}.

It seems like we could define

IteratorSize(::Type{<:Broadcasted{<:AbstractArrayStyle{N}}}) where {N} = HasShape{N}()

and indeed that does seem to work:

julia> Base.IteratorSize(::Type{<:Base.Broadcast.Broadcasted{<:Base.Broadcast.AbstractArrayStyle{N}}}) where {N} = Base.HasShape{N}()

julia> collect(h)
2×2 Matrix{Float64}:
 1.0      1.41421
 1.73205  2.0
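The proposed method relies on the N in the style agreeing with the dimensionality of the result. A minimal sketch of that check, without adding the method to Base: style_ndims is a hypothetical helper that extracts N from the style parameter the same way the proposed IteratorSize method would, and the inputs are a few arbitrary examples. Styles declared as AbstractArrayStyle{Any}, such as Base.Broadcast.ArrayConflict, would give HasShape{Any} under the proposed definition and would presumably need separate handling.

```julia
using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize

# Hypothetical helper: read N off the style, as the proposed
# IteratorSize method does.
style_ndims(::Broadcasted{<:AbstractArrayStyle{N}}) where {N} = N

cases = (
    broadcasted(sqrt, [1 2; 3 4]),       # Matrix
    broadcasted(+, [1, 2], [10 20 30]),  # Vector .+ row Matrix gives 2×3
    broadcasted(*, [1, 2, 3], 2),        # Vector .* scalar
)

# Compare the style's N with ndims of the materialized result.
for bc in cases
    println(style_ndims(bc), " => ", ndims(materialize(bc)))
end
```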

Labels: broadcast, collections