Skip to content

Commit 938da26

Browse files
authored
Merge pull request #44061 from BSnelling/bes/collect_broadcasted_2
Preserve shape when collecting broadcasted objects
2 parents 7e54f9a + 70fc3cd commit 938da26

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

base/broadcast.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ Base.IndexStyle(::Type{<:Broadcasted{<:Any}}) = IndexCartesian()
244244

245245
Base.LinearIndices(bc::Broadcasted{<:Any,<:Tuple{Any}}) = LinearIndices(axes(bc))::LinearIndices{1}
246246

247-
Base.ndims(::Broadcasted{<:Any,<:NTuple{N,Any}}) where {N} = N
247+
Base.ndims(bc::Broadcasted) = ndims(typeof(bc))
248248
Base.ndims(::Type{<:Broadcasted{<:Any,<:NTuple{N,Any}}}) where {N} = N
249249

250250
Base.size(bc::Broadcasted) = map(length, axes(bc))
@@ -261,7 +261,20 @@ Base.@propagate_inbounds function Base.iterate(bc::Broadcasted, s)
261261
return (bc[i], (s[1], newstate))
262262
end
263263

264-
Base.IteratorSize(::Type{<:Broadcasted{<:Any,<:NTuple{N,Base.OneTo}}}) where {N} = Base.HasShape{N}()
264+
Base.IteratorSize(::Type{T}) where {T<:Broadcasted} = Base.HasShape{ndims(T)}()
265+
Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, 2))
266+
Base.ndims(::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N<:Integer} = N
267+
268+
_maxndims(T::Type{<:Tuple}) = reduce(max, (ntuple(n -> _ndims(fieldtype(T, n)), Base._counttuple(T))))
269+
_maxndims(::Type{<:Tuple{T}}) where {T} = ndims(T)
270+
_maxndims(::Type{<:Tuple{T}}) where {T<:Tuple} = _ndims(T)
271+
function _maxndims(::Type{<:Tuple{T, S}}) where {T, S}
272+
return T<:Tuple || S<:Tuple ? max(_ndims(T), _ndims(S)) : max(ndims(T), ndims(S))
273+
end
274+
275+
_ndims(x) = ndims(x)
276+
_ndims(::Type{<:Tuple}) = 1
277+
265278
Base.IteratorEltype(::Type{<:Broadcasted}) = Base.EltypeUnknown()
266279

267280
## Instantiation fills in the "missing" fields in Broadcasted.

test/broadcast.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,39 @@ let
855855
@test ndims(copy(bc)) == ndims([v for v in bc]) == ndims(collect(bc)) == ndims(bc)
856856
end
857857

858+
# issue 43847: collect preserves shape of broadcasted
859+
let
860+
bc = Broadcast.broadcasted(*, [1 2; 3 4], 2)
861+
@test collect(Iterators.product(bc, bc)) == collect(Iterators.product(copy(bc), copy(bc)))
862+
863+
a1 = AD1(rand(2,3))
864+
bc1 = Broadcast.broadcasted(*, a1, 2)
865+
@test collect(Iterators.product(bc1, bc1)) == collect(Iterators.product(copy(bc1), copy(bc1)))
866+
867+
# using ndims of second arg
868+
bc2 = Broadcast.broadcasted(*, 2, a1)
869+
@test collect(Iterators.product(bc2, bc2)) == collect(Iterators.product(copy(bc2), copy(bc2)))
870+
871+
# >2 args
872+
bc3 = Broadcast.broadcasted(*, a1, 3, a1)
873+
@test collect(Iterators.product(bc3, bc3)) == collect(Iterators.product(copy(bc3), copy(bc3)))
874+
875+
# including a tuple and custom array type
876+
bc4 = Broadcast.broadcasted(*, (1,2,3), AD1(rand(3)))
877+
@test collect(Iterators.product(bc4, bc4)) == collect(Iterators.product(copy(bc4), copy(bc4)))
878+
879+
# testing ArrayConflict
880+
@test Broadcast.broadcasted(+, AD1(rand(3)), AD2(rand(3))) isa Broadcast.Broadcasted{Broadcast.ArrayConflict}
881+
@test Broadcast.broadcasted(+, AD1(rand(3)), AD2(rand(3))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}
882+
883+
@test @inferred(Base.IteratorSize(Broadcast.broadcasted((1,2,3),a1,zeros(3,3,3)))) === Base.HasShape{3}()
884+
885+
# inference on nested
886+
bc = Base.broadcasted(+, AD1(randn(3)), AD1(randn(3)))
887+
bc_nest = Base.broadcasted(+, bc , bc)
888+
@test @inferred(Base.IteratorSize(bc_nest)) === Base.HasShape{1}()
889+
end
890+
858891
# issue #31295
859892
let a = rand(5), b = rand(5), c = copy(a)
860893
view(identity(a), 1:3) .+= view(b, 1:3)

0 commit comments

Comments
 (0)