Skip to content

Commit

Permalink
LinearAlgebra: Type-stability in broadcasting numbers over Bidiagonal (
Browse files Browse the repository at this point in the history
…#54067)

This makes the following type-stable:
```julia
julia> B = Bidiagonal(rand(3), rand(2), :U);

julia> @inferred (B -> B .* 2)(B)
3×3 Bidiagonal{Float64, Vector{Float64}}:
 0.3929  1.93165   ⋅
  ⋅      1.61301  1.00202
  ⋅       ⋅       1.96483
```
Similarly, for other operations involving a single `Bidiagonal` and
numbers. This is not type-stable on master, as the number of
`Bidiagonal` matrices in a broadcast operation is not tracked (even
though this is used in promoting the `uplo`). Since the `uplo` can't be
constant-propagated, we count this by introducing an additional flag in
the promotion mechanism, which is entirely determined by the types of
the terms in the broadcast operation.

---------

Co-authored-by: N5N3 <2642243996@qq.com>
(cherry picked from commit 685f527)
  • Loading branch information
jishnub authored and KristofferC committed May 6, 2024
1 parent c7bfbec commit bc9c928
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
4 changes: 3 additions & 1 deletion stdlib/LinearAlgebra/src/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ find_uplo(bc::Broadcasted) = mapfoldl(find_uplo, merge_uplos, Broadcast.cat_nest
function structured_broadcast_alloc(bc, ::Type{Bidiagonal}, ::Type{ElType}, n) where {ElType}
uplo = n > 0 ? find_uplo(bc) : 'U'
n1 = max(n - 1, 0)
if uplo == 'T'
if count_structedmatrix(Bidiagonal, bc) > 1 && uplo == 'T'
return Tridiagonal(Array{ElType}(undef, n1), Array{ElType}(undef, n), Array{ElType}(undef, n1))
end
return Bidiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n1), uplo)
Expand Down Expand Up @@ -134,6 +134,8 @@ iszerodefined(::Type) = false
iszerodefined(::Type{<:Number}) = true
iszerodefined(::Type{<:AbstractArray{T}}) where T = iszerodefined(T)

count_structedmatrix(T, bc::Broadcasted) = sum(Base.Fix2(isa, T), Broadcast.cat_nested(bc); init = 0)

fzeropreserving(bc) = (v = fzero(bc); !ismissing(v) && (iszerodefined(typeof(v)) ? iszero(v) : v == 0))
# Like sparse matrices, we assume that the zero-preservation property of a broadcasted
# expression is stable. We can test the zero-preservability by applying the function
Expand Down
18 changes: 18 additions & 0 deletions stdlib/LinearAlgebra/test/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,24 @@ using Test, LinearAlgebra
@test broadcast!(*, Z, X, Y) == broadcast(*, fX, fY)
end
end

@testset "type-stability in Bidiagonal" begin
B2 = @inferred (B -> .- B)(B)
@test B2 isa Bidiagonal
@test B2 == -1 * B
B2 = @inferred (B -> B .* 2)(B)
@test B2 isa Bidiagonal
@test B2 == B + B
B2 = @inferred (B -> 2 .* B)(B)
@test B2 isa Bidiagonal
@test B2 == B + B
B2 = @inferred (B -> B ./ 1)(B)
@test B2 isa Bidiagonal
@test B2 == B
B2 = @inferred (B -> 1 .\ B)(B)
@test B2 isa Bidiagonal
@test B2 == B
end
end

@testset "broadcast! where the destination is a structured matrix" begin
Expand Down

0 comments on commit bc9c928

Please sign in to comment.