Skip to content

_fill_dot support general vectors #229

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 7 additions & 23 deletions src/fillalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,38 +160,22 @@ function *(a::Transpose{T, <:AbstractVector{T}}, b::ZerosVector{T}) where T<:Rea
end
*(a::Transpose{T, <:AbstractMatrix{T}}, b::ZerosVector{T}) where T<:Real = mult_zeros(a, b)

# treat zero separately to support ∞-vectors
function _zero_dot(a, b)
axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))"))
zero(promote_type(eltype(a),eltype(b)))
end

_fill_dot(a::Zeros, b::Zeros) = _zero_dot(a, b)
_fill_dot(a::Zeros, b) = _zero_dot(a, b)
_fill_dot(a, b::Zeros) = _zero_dot(a, b)
_fill_dot(a::Zeros, b::AbstractFill) = _zero_dot(a, b)
_fill_dot(a::AbstractFill, b::Zeros) = _zero_dot(a, b)

function _fill_dot(a::AbstractFill, b::AbstractFill)
axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))"))
getindex_value(a)getindex_value(b)*length(b)
end

# support types with fast sum
function _fill_dot(a::AbstractFill, b)
# infinite cases should be supported in InfiniteArrays.jl
# type issues of Bool dot are ignored at present.
function _fill_dot(a::AbstractFillVector{T}, b::AbstractVector{V}) where {T,V}
axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))"))
getindex_value(a)sum(b)
dot(getindex_value(a), sum(b))
end

function _fill_dot(a, b::AbstractFill)
function _fill_dot_rev(a::AbstractVector{T}, b::AbstractFillVector{V}) where {T,V}
axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))"))
sum(a)getindex_value(b)
dot(sum(a), getindex_value(b))
end


dot(a::AbstractFillVector, b::AbstractFillVector) = _fill_dot(a, b)
dot(a::AbstractFillVector, b::AbstractVector) = _fill_dot(a, b)
dot(a::AbstractVector, b::AbstractFillVector) = _fill_dot(a, b)
dot(a::AbstractVector, b::AbstractFillVector) = _fill_dot_rev(a, b)

function dot(u::AbstractVector, E::Eye, v::AbstractVector)
length(u) == size(E,1) && length(v) == size(E,2) ||
Expand Down
62 changes: 38 additions & 24 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,20 +306,32 @@ end
# type, and produce numerically correct results.
as_array(x::AbstractArray) = Array(x)
as_array(x::UniformScaling) = x
function test_addition_and_subtraction(As, Bs, Tout::Type)
equal_or_undef(a::Number, b::Number) = (a == b) || isequal(a, b)
equal_or_undef(a, b) = all(equal_or_undef.(a, b))
function test_addition_subtraction_dot(As, Bs, Tout::Type)
for A in As, B in Bs
@testset "$(typeof(A)) ± $(typeof(B))" begin
@testset "$(typeof(A)) and $(typeof(B))" begin
@test A + B isa Tout{promote_type(eltype(A), eltype(B))}
@test as_array(A + B) == as_array(A) + as_array(B)
@test equal_or_undef(as_array(A + B), as_array(A) + as_array(B))

@test A - B isa Tout{promote_type(eltype(A), eltype(B))}
@test as_array(A - B) == as_array(A) - as_array(B)
@test equal_or_undef(as_array(A - B), as_array(A) - as_array(B))

@test B + A isa Tout{promote_type(eltype(B), eltype(A))}
@test as_array(B + A) == as_array(B) + as_array(A)
@test equal_or_undef(as_array(B + A), as_array(B) + as_array(A))

@test B - A isa Tout{promote_type(eltype(B), eltype(A))}
@test as_array(B - A) == as_array(B) - as_array(A)
@test equal_or_undef(as_array(B - A), as_array(B) - as_array(A))

# Julia 1.6 doesn't support dot(UniformScaling)
if VERSION < v"1.6.0" || VERSION >= v"1.8.0"
d1 = dot(A, B)
d2 = dot(as_array(A), as_array(B))
d3 = dot(B, A)
d4 = dot(as_array(B), as_array(A))
@test d1 ≈ d2 || d1 ≡ d2
@test d3 ≈ d4 || d3 ≡ d4
end
end
end
end
Expand Down Expand Up @@ -349,37 +361,37 @@ end
@test -A_fill === Fill(-A_fill.value, 5)

# FillArray +/- FillArray should construct a new FillArray.
test_addition_and_subtraction((A_fill, B_fill), (A_fill, B_fill), Fill)
test_addition_subtraction_dot((A_fill, B_fill), (A_fill, B_fill), Fill)
test_addition_and_subtraction_dim_mismatch(A_fill, Fill(randn(rng), 5, 2))

# FillArray + Array (etc) should construct a new Array using `getindex`.
A_dense, B_dense = randn(rng, 5), [5, 4, 3, 2, 1]
test_addition_and_subtraction((A_fill, B_fill), (A_dense, B_dense), Array)
B_dense = (randn(rng, 5), [5, 4, 3, 2, 1], fill(Inf, 5), fill(NaN, 5))
test_addition_subtraction_dot((A_fill, B_fill), B_dense, Array)
test_addition_and_subtraction_dim_mismatch(A_fill, randn(rng, 5, 2))

# FillArray + StepLenRange / UnitRange (etc) should yield an AbstractRange.
A_ur, B_ur = 1.0:5.0, 6:10
test_addition_and_subtraction((A_fill, B_fill), (A_ur, B_ur), AbstractRange)
test_addition_subtraction_dot((A_fill, B_fill), (A_ur, B_ur), AbstractRange)
test_addition_and_subtraction_dim_mismatch(A_fill, 1.0:6.0)
test_addition_and_subtraction_dim_mismatch(A_fill, 5:10)

# FillArray + UniformScaling should yield a Matrix in general
As_fill_square = (Fill(randn(rng, Float64), 3, 3), Fill(5, 4, 4))
Bs_us = (UniformScaling(2.3), UniformScaling(3))
test_addition_and_subtraction(As_fill_square, Bs_us, Matrix)
test_addition_subtraction_dot(As_fill_square, Bs_us, Matrix)
As_fill_nonsquare = (Fill(randn(rng, Float64), 3, 2), Fill(5, 3, 4))
for A in As_fill_nonsquare, B in Bs_us
test_addition_and_subtraction_dim_mismatch(A, B)
end

# FillArray + StaticArray should not have ambiguities
A_svec, B_svec = SVector{5}(rand(5)), SVector(1, 2, 3, 4, 5)
test_addition_and_subtraction((A_fill, B_fill, Zeros(5)), (A_svec, B_svec), SVector{5})
test_addition_subtraction_dot((A_fill, B_fill, Zeros(5)), (A_svec, B_svec), SVector{5})

# Issue #224
A_matmat, B_matmat = Fill(rand(3,3),5), [rand(3,3) for n=1:5]
test_addition_and_subtraction((A_matmat,), (A_matmat,), Fill)
test_addition_and_subtraction((B_matmat,), (A_matmat,), Vector)
test_addition_subtraction_dot((A_matmat,), (A_matmat,), Fill)
test_addition_subtraction_dot((B_matmat,), (A_matmat,), Vector)

# Optimizations for Zeros and RectOrDiagonal{<:Any, <:AbstractFill}
As_special_square = (
Expand All @@ -389,7 +401,7 @@ end
RectDiagonal(Fill(randn(rng, Float64), 3), 3, 3), RectDiagonal(Fill(3, 4), 4, 4)
)
DiagonalAbstractFill{T} = Diagonal{T, <:AbstractFill{T, 1}}
test_addition_and_subtraction(As_special_square, Bs_us, DiagonalAbstractFill)
test_addition_subtraction_dot(As_special_square, Bs_us, DiagonalAbstractFill)
As_special_nonsquare = (
Zeros(3, 2), Zeros{Int}(3, 4),
Eye(3, 2), Eye{Int}(3, 4),
Expand Down Expand Up @@ -514,7 +526,7 @@ end
@test [SVector(1,2)', SVector(2,3)', SVector(3,4)']' * Zeros{Int}(3) === SVector(0,0)
@test_throws DimensionMismatch randn(4)' * Zeros(3)
@test Zeros(5)' * randn(5,3) ≡ Zeros(5)'*Zeros(5,3) ≡ Zeros(5)'*Ones(5,3) ≡ Zeros(3)'
@test Zeros(5)' * randn(5) ≡ Zeros(5)' * Zeros(5) ≡ Zeros(5)' * Ones(5) ≡ 0.0
@test abs(Zeros(5)' * randn(5))abs(Zeros(5)' * Zeros(5))abs(Zeros(5)' * Ones(5)) ≡ 0.0
@test Zeros(5) * Zeros(6)' ≡ Zeros(5,1) * Zeros(6)' ≡ Zeros(5,6)
@test randn(5) * Zeros(6)' ≡ randn(5,1) * Zeros(6)' ≡ Zeros(5,6)
@test Zeros(5) * randn(6)' ≡ Zeros(5,6)
Expand All @@ -529,7 +541,7 @@ end
@test transpose([1, 2, 3]) * Zeros{Int}(3) === zero(Int)
@test_throws DimensionMismatch transpose(randn(4)) * Zeros(3)
@test transpose(Zeros(5)) * randn(5,3) ≡ transpose(Zeros(5))*Zeros(5,3) ≡ transpose(Zeros(5))*Ones(5,3) ≡ transpose(Zeros(3))
@test transpose(Zeros(5)) * randn(5) ≡ transpose(Zeros(5)) * Zeros(5) ≡ transpose(Zeros(5)) * Ones(5) ≡ 0.0
@test abs(transpose(Zeros(5)) * randn(5))abs(transpose(Zeros(5)) * Zeros(5))abs(transpose(Zeros(5)) * Ones(5)) ≡ 0.0
@test randn(5) * transpose(Zeros(6)) ≡ randn(5,1) * transpose(Zeros(6)) ≡ Zeros(5,6)
@test Zeros(5) * transpose(randn(6)) ≡ Zeros(5,6)
@test transpose(randn(5)) * Zeros(5) ≡ 0.0
Expand All @@ -547,13 +559,13 @@ end
@test +(z1) === z1
@test -(z1) === z1

test_addition_and_subtraction((z1, z2), (z1, z2), Zeros)
test_addition_subtraction_dot((z1, z2), (z1, z2), Zeros)
test_addition_and_subtraction_dim_mismatch(z1, Zeros{Float64}(4, 2))
end

# `Zeros` +/- `Fill`s should yield `Fills`.
fill1, fill2 = Fill(5.0, 4), Fill(5, 4)
test_addition_and_subtraction((z1, z2), (fill1, fill2), Fill)
test_addition_subtraction_dot((z1, z2), (fill1, fill2), Fill)
test_addition_and_subtraction_dim_mismatch(z1, Fill(5, 5))

X = randn(3, 5)
Expand Down Expand Up @@ -1291,17 +1303,19 @@ end
Random.seed!(5)
u = rand(n)
v = rand(n)
c = rand(ComplexF16, n)

@test dot(u, D, v) == dot(u, v)
@test dot(u, 2D, v) == 2dot(u, v)
@test dot(u, Z, v) == 0

@test dot(Zeros(5), Zeros{ComplexF16}(5)) ≡ zero(ComplexF64)
@test dot(Zeros(5), Ones{ComplexF16}(5)) ≡ zero(ComplexF64)
@test dot(Ones{ComplexF16}(5), Zeros(5)) ≡ zero(ComplexF64)
@test dot(randn(5), Zeros{ComplexF16}(5)) ≡ dot(Zeros{ComplexF16}(5), randn(5)) ≡ zero(ComplexF64)
@test @inferred(dot(Zeros(5), Zeros{ComplexF16}(5))) ≡ zero(ComplexF64)
@test @inferred(dot(Zeros(5), Ones{ComplexF16}(5))) ≡ zero(ComplexF64)
@test abs(@inferred(dot(Ones{ComplexF16}(5), Zeros(5))))abs(@inferred(dot(randn(5), Zeros{ComplexF16}(5)))) ≡ abs(@inferred(dot(Zeros{ComplexF16}(5), randn(5)))) ≡ zero(Float64) # 0.0 !≡ -0.0
@test @inferred(dot(c, Fill(1 + im, 15)))(@inferred(dot(Fill(1 + im, 15), c)))' ≈ @inferred(dot(c, fill(1 + im, 15)))

@test dot(Fill(1,5), Fill(2.0,5)) ≡ 10.0
@test @inferred(dot(Fill(1,5), Fill(2.0,5))) ≡ 10.0
@test_skip dot(Fill(true,5), Fill(Int8(1),5)) isa Int8 # not working at present

let N = 2^big(1000) # fast dot for fast sum
@test dot(Fill(2,N),1:N) == dot(Fill(2,N),1:N) == dot(1:N,Fill(2,N)) == 2*sum(1:N)
Expand Down