From e4d1962808ff0a59bcf0cef0559ef554a073cf09 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Sun, 7 Jan 2018 06:19:36 -0600 Subject: [PATCH] Centralize broadcast support for structured matrices --- base/broadcast.jl | 3 +++ base/linalg/bidiag.jl | 27 ++++++++++++-------- base/linalg/diagonal.jl | 10 +++++++- base/linalg/linalg.jl | 2 ++ base/linalg/tridiag.jl | 47 +++++++++++++++++++---------------- base/sparse/higherorderfns.jl | 16 ++++++++++-- test/sparse/higherorderfns.jl | 2 +- 7 files changed, 71 insertions(+), 36 deletions(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index 0a995022c76b8..feff37e5867b5 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -119,6 +119,9 @@ BroadcastStyle(::Type{<:Ref}) = DefaultArrayStyle{0}() # 3 or more arguments still return an `ArrayConflict`. struct ArrayConflict <: AbstractArrayStyle{Any} end +# This will be used for Diagonal, Bidiagonal, Tridiagonal, and SymTridiagonal +struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end + ### Binary BroadcastStyle rules """ BroadcastStyle(::Style1, ::Style2) = Style3() diff --git a/base/linalg/bidiag.jl b/base/linalg/bidiag.jl index 7c6ad8b1bd2eb..c2f1efd79ecac 100644 --- a/base/linalg/bidiag.jl +++ b/base/linalg/bidiag.jl @@ -172,7 +172,23 @@ Bidiagonal{T}(A::Bidiagonal) where {T} = # When asked to convert Bidiagonal to AbstractMatrix{T}, preserve structure by converting to Bidiagonal{T} <: AbstractMatrix{T} AbstractMatrix{T}(A::Bidiagonal) where {T} = convert(Bidiagonal{T}, A) -broadcast(::typeof(big), B::Bidiagonal) = Bidiagonal(big.(B.dv), big.(B.ev), B.uplo) +function copyto!(dest::Bidiagonal, bc::Broadcasted{PromoteToSparse}) + axs = axes(dest) + axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + for i in axs[1] + dest.dv[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i)) + end + if dest.uplo == 'U' + for i = 1:size(dest, 1)-1 + dest.ev[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1)) + end + else + for i = 1:size(dest, 1)-1 + dest.ev[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i+1, i)) + end + end + dest +end # For B<:Bidiagonal, similar(B[, neweltype]) should yield a Bidiagonal matrix. # On the other hand, similar(B, [neweltype,] shape...) should yield a sparse matrix. @@ -234,18 +250,9 @@ function size(M::Bidiagonal, d::Integer) end #Elementary operations -broadcast(::typeof(abs), M::Bidiagonal) = Bidiagonal(abs.(M.dv), abs.(M.ev), M.uplo) -broadcast(::typeof(round), M::Bidiagonal) = Bidiagonal(round.(M.dv), round.(M.ev), M.uplo) -broadcast(::typeof(trunc), M::Bidiagonal) = Bidiagonal(trunc.(M.dv), trunc.(M.ev), M.uplo) -broadcast(::typeof(floor), M::Bidiagonal) = Bidiagonal(floor.(M.dv), floor.(M.ev), M.uplo) -broadcast(::typeof(ceil), M::Bidiagonal) = Bidiagonal(ceil.(M.dv), ceil.(M.ev), M.uplo) for func in (:conj, :copy, :real, :imag) @eval ($func)(M::Bidiagonal) = Bidiagonal(($func)(M.dv), ($func)(M.ev), M.uplo) end -broadcast(::typeof(round), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(round.(T, M.dv), round.(T, M.ev), M.uplo) -broadcast(::typeof(trunc), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(trunc.(T, M.dv), trunc.(T, M.ev), M.uplo) -broadcast(::typeof(floor), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(floor.(T, M.dv), floor.(T, M.ev), M.uplo) -broadcast(::typeof(ceil), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(ceil.(T, M.dv), ceil.(T, M.ev), M.uplo) transpose(M::Bidiagonal) = Bidiagonal(M.dv, M.ev, M.uplo == 'U' ? :L : :U) adjoint(M::Bidiagonal) = Bidiagonal(conj(M.dv), conj(M.ev), M.uplo == 'U' ? :L : :U) diff --git a/base/linalg/diagonal.jl b/base/linalg/diagonal.jl index 25600afa2293f..6de2416f8262c 100644 --- a/base/linalg/diagonal.jl +++ b/base/linalg/diagonal.jl @@ -111,10 +111,18 @@ isposdef(D::Diagonal) = all(x -> x > 0, D.diag) factorize(D::Diagonal) = D -broadcast(::typeof(abs), D::Diagonal) = Diagonal(abs.(D.diag)) real(D::Diagonal) = Diagonal(real(D.diag)) imag(D::Diagonal) = Diagonal(imag(D.diag)) +function copyto!(dest::Diagonal, bc::Broadcasted{PromoteToSparse}) + axs = axes(dest) + axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + for i in axs[1] + dest.diag[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i)) + end + dest +end + istriu(D::Diagonal) = true istril(D::Diagonal) = true function triu!(D::Diagonal,k::Integer=0) diff --git a/base/linalg/linalg.jl b/base/linalg/linalg.jl index 99d308b764426..7d2189dd0e672 100644 --- a/base/linalg/linalg.jl +++ b/base/linalg/linalg.jl @@ -17,6 +17,8 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as StridedReshapedArray, strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec using Base: hvcat_fill, iszero, IndexLinear, _length, promote_op, promote_typeof, @propagate_inbounds, @pure, reduce, typed_vcat +using Base.Broadcast: Broadcasted, PromoteToSparse + # We use `_length` because of non-1 indices; releases after julia 0.5 # can go back to `length`. `_length(A)` is equivalent to `length(linearindices(A))`. diff --git a/base/linalg/tridiag.jl b/base/linalg/tridiag.jl index 0bc8448e50c0d..388f8c7ae9cf6 100644 --- a/base/linalg/tridiag.jl +++ b/base/linalg/tridiag.jl @@ -113,19 +113,22 @@ end similar(S::SymTridiagonal, ::Type{T}) where {T} = SymTridiagonal(similar(S.dv, T), similar(S.ev, T)) similar(S::SymTridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = spzeros(T, dims...) +function copyto!(dest::SymTridiagonal, bc::Broadcasted{PromoteToSparse}) + axs = axes(dest) + axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + for i in axs[1] + dest.dv[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i)) + end + for i = 1:size(dest, 1)-1 + dest.ev[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1)) + end + dest +end + #Elementary operations -broadcast(::typeof(abs), M::SymTridiagonal) = SymTridiagonal(abs.(M.dv), abs.(M.ev)) -broadcast(::typeof(round), M::SymTridiagonal) = SymTridiagonal(round.(M.dv), round.(M.ev)) -broadcast(::typeof(trunc), M::SymTridiagonal) = SymTridiagonal(trunc.(M.dv), trunc.(M.ev)) -broadcast(::typeof(floor), M::SymTridiagonal) = SymTridiagonal(floor.(M.dv), floor.(M.ev)) -broadcast(::typeof(ceil), M::SymTridiagonal) = SymTridiagonal(ceil.(M.dv), ceil.(M.ev)) for func in (:conj, :copy, :real, :imag) @eval ($func)(M::SymTridiagonal) = SymTridiagonal(($func)(M.dv), ($func)(M.ev)) end -broadcast(::typeof(round), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(round.(T, M.dv), round.(T, M.ev)) -broadcast(::typeof(trunc), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(trunc.(T, M.dv), trunc.(T, M.ev)) -broadcast(::typeof(floor), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(floor.(T, M.dv), floor.(T, M.ev)) -broadcast(::typeof(ceil), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(ceil.(T, M.dv), ceil.(T, M.ev)) transpose(M::SymTridiagonal) = M #Identity operation adjoint(M::SymTridiagonal) = conj(M) @@ -500,24 +503,11 @@ similar(M::Tridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = spz copyto!(dest::Tridiagonal, src::Tridiagonal) = (copyto!(dest.dl, src.dl); copyto!(dest.d, src.d); copyto!(dest.du, src.du); dest) #Elementary operations -broadcast(::typeof(abs), M::Tridiagonal) = Tridiagonal(abs.(M.dl), abs.(M.d), abs.(M.du)) -broadcast(::typeof(round), M::Tridiagonal) = Tridiagonal(round.(M.dl), round.(M.d), round.(M.du)) -broadcast(::typeof(trunc), M::Tridiagonal) = Tridiagonal(trunc.(M.dl), trunc.(M.d), trunc.(M.du)) -broadcast(::typeof(floor), M::Tridiagonal) = Tridiagonal(floor.(M.dl), floor.(M.d), floor.(M.du)) -broadcast(::typeof(ceil), M::Tridiagonal) = Tridiagonal(ceil.(M.dl), ceil.(M.d), ceil.(M.du)) for func in (:conj, :copy, :real, :imag) @eval function ($func)(M::Tridiagonal) Tridiagonal(($func)(M.dl), ($func)(M.d), ($func)(M.du)) end end -broadcast(::typeof(round), ::Type{T}, M::Tridiagonal) where {T<:Integer} = - Tridiagonal(round.(T, M.dl), round.(T, M.d), round.(T, M.du)) -broadcast(::typeof(trunc), ::Type{T}, M::Tridiagonal) where {T<:Integer} = - Tridiagonal(trunc.(T, M.dl), trunc.(T, M.d), trunc.(T, M.du)) -broadcast(::typeof(floor), ::Type{T}, M::Tridiagonal) where {T<:Integer} = - Tridiagonal(floor.(T, M.dl), floor.(T, M.d), floor.(T, M.du)) -broadcast(::typeof(ceil), ::Type{T}, M::Tridiagonal) where {T<:Integer} = - Tridiagonal(ceil.(T, M.dl), ceil.(T, M.d), ceil.(T, M.du)) transpose(M::Tridiagonal) = Tridiagonal(M.du, M.d, M.dl) adjoint(M::Tridiagonal) = conj(transpose(M)) @@ -576,6 +566,19 @@ function Base.replace_in_print_matrix(A::Tridiagonal,i::Integer,j::Integer,s::Ab i==j-1||i==j||i==j+1 ? s : Base.replace_with_centered_mark(s) end +function copyto!(dest::Tridiagonal, bc::Broadcasted{PromoteToSparse}) + axs = axes(dest) + axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + for i in axs[1] + dest.d[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i)) + end + for i = 1:size(dest, 1)-1 + dest.du[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1)) + dest.dl[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i+1, i)) + end + dest +end + #tril and triu istriu(M::Tridiagonal) = iszero(M.dl) diff --git a/base/sparse/higherorderfns.jl b/base/sparse/higherorderfns.jl index 18c2f73e73e98..67c66fb62fba9 100644 --- a/base/sparse/higherorderfns.jl +++ b/base/sparse/higherorderfns.jl @@ -9,7 +9,7 @@ import Base: map, map!, broadcast, copy, copyto! using Base: TupleLL, TupleLLEnd, front, tail, to_shape using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector, AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange -using Base.Broadcast: BroadcastStyle, Broadcasted, flatten +using Base.Broadcast: BroadcastStyle, Broadcasted, PromoteToSparse, Args1, Args2, flatten # This module is organized as follows: # (0) Define BroadcastStyle rules and convenience types for dispatch @@ -54,7 +54,6 @@ SparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() Broadcast.BroadcastStyle(::SparseMatStyle, ::SparseVecStyle) = SparseMatStyle() -struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal} Broadcast.BroadcastStyle(::Type{<:StructuredMatrix}) = PromoteToSparse() @@ -969,6 +968,7 @@ function _copy(::Any, bc::Broadcasted{<:SPVM}) parevalf, passedargstup = capturescalars(bcf.f, args) return broadcast(parevalf, passedargstup...) end + function _shapecheckbc(bc::Broadcasted) args = Tuple(bc.args) _aresameshape(bc.args) ? _noshapecheck_map(bc.f, args...) : _diffshape_broadcast(bc.f, args...) @@ -1044,10 +1044,22 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f( # and rebroadcast. otherwise, divert to generic AbstractArray broadcast code. function copy(bc::Broadcasted{PromoteToSparse}) + if bc.args isa Args1{<:StructuredMatrix} || bc.args isa Args2{<:Type,<:StructuredMatrix} + if _iszero(fzero(bc.f, bc.args)) + T = Broadcast.combine_eltypes(bc.f, bc.args) + M = get_matrix(bc.args) + dest = similar(M, T) + return copyto!(dest, bc) + end + end bcf = flatten(bc) As = Tuple(bcf.args) broadcast(bcf.f, map(_sparsifystructured, As)...) end +get_matrix(args::Args1{<:StructuredMatrix}) = args.head +get_matrix(args::Args2{<:Type,<:StructuredMatrix}) = args.rest.head +fzero(f::Tf, args::Args1{<:StructuredMatrix}) where Tf = f(zero(eltype(get_matrix(args)))) +fzero(f::Tf, args::Args2{<:Type, <:StructuredMatrix}) where Tf = f(args.head, zero(eltype(get_matrix(args)))) function copyto!(dest::SparseVecOrMat, bc::Broadcasted{PromoteToSparse}) bcf = flatten(bc) diff --git a/test/sparse/higherorderfns.jl b/test/sparse/higherorderfns.jl index b2b0ea46524fe..fb408efb1f416 100644 --- a/test/sparse/higherorderfns.jl +++ b/test/sparse/higherorderfns.jl @@ -382,7 +382,7 @@ end structuredarrays = (D, B, T, S) fstructuredarrays = map(Array, structuredarrays) for (X, fX) in zip(structuredarrays, fstructuredarrays) - @test (Q = broadcast(sin, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(sin, fX))) + @test (Q = broadcast(sin, X); typeof(Q) == typeof(X) && Q == sparse(broadcast(sin, fX))) @test broadcast!(sin, Z, X) == sparse(broadcast(sin, fX)) @test (Q = broadcast(cos, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(cos, fX))) @test broadcast!(cos, Z, X) == sparse(broadcast(cos, fX))