Skip to content

Commit

Permalink
take 1
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Apr 13, 2021
1 parent 74c4bf6 commit e95f3e8
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/Unitful.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,6 @@ include("logarithm.jl")
include("complex.jl")
include("pkgdefaults.jl")
include("dates.jl")
include("linearalgebra.jl")

end
42 changes: 42 additions & 0 deletions src/linearalgebra.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using LinearAlgebra

# This function is re-defined during testing, to check we hit the fast path:
linearalgebra_count() = nothing

function LinearAlgebra.mul!(C::StridedVecOrMat{<:AbstractQuantity{T}},
A::StridedMatrix{<:AbstractQuantity{T}},
B::StridedVecOrMat{<:AbstractQuantity{T}},
alpha::Bool, beta::Bool) where {T<:Base.HWNumber}
# This is exactly how A * B creates C = similar(B, T, ...)
eltype(C) == Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) || error("bad eltypes")
C0 = ustrip(C)
A0 = ustrip(A)
B0 = ustrip(B)
mul!(C0, A0, B0)
linearalgebra_count()
return C
end

function LinearAlgebra.mul!(C::StridedVecOrMat{<:AbstractQuantity{T}},
A::LinearAlgebra.AdjOrTransAbsMat{<:AbstractQuantity{T}, <:StridedMatrix},
B::StridedVecOrMat{<:AbstractQuantity{T}},
alpha::Bool, beta::Bool) where {T<:Base.HWNumber}

eltype(C) == Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) || error("bad eltypes")
C0 = ustrip(C)
A0 = A isa Adjoint ? adjoint(ustrip(parent(A))) : transpose(ustrip(parent(A)))
B0 = ustrip(B)
mul!(C0, A0, B0)
linearalgebra_count()
return C
end

function LinearAlgebra.dot(A::StridedArray{<:AbstractQuantity{T}},
B::StridedArray{<:AbstractQuantity{T}}) where {T<:Base.HWNumber}
A0 = ustrip(A)
B0 = ustrip(B)
C0 = dot(A0, B0)
linearalgebra_count()
C = C0 * oneunit(eltype(A)) * oneunit(eltype(B)) # surely there is an official way
return C
end
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ julia> a[1] = 3u"m"; b
2
```
"""
@inline ustrip(A::Array{Q}) where {Q <: Quantity} = reinterpret(numtype(Q), A)
@inline ustrip(A::StridedArray{Q}) where {Q <: Quantity} = reinterpret(numtype(Q), A)

@deprecate(ustrip(A::AbstractArray{T}) where {T<:Number}, ustrip.(A))

Expand Down
34 changes: 34 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,40 @@ end
@test ConstructionBase.constructorof(typeof(1.0m))(2) === 2m
end

@testset "LinearAlgebra functions" begin
CNT = Ref(0)
Unitful.linearalgebra_count() = (CNT[] += 1; nothing)
@testset "> Matrix multiplication: *" begin
M = rand(3,3) .* u"m"
M_ = view(M,:,1:3)
v = rand(3) .* u"V"
v_ = view(v, 1:3)

CNT[] = 0

@test unit(first(M * M)) == u"m*m"
@test M * M == M_ * M == M * M_ == M_ * M_

@test unit(first(M * v)) == u"m*V"
@test M * v == M_ * v == M * v_ == M_ * v_

@test CNT[] == 10

@test unit(first(v' * M)) == u"m*V"
@test v' * M == v_' * M == v_' * M == v_' * M_

@test CNT[] == 15

@test unit(v' * v) == u"V*V"
@test v' * v == v_' * v == v_' * v == v_' * v_

@test CNT[] == 20
end
@testset "> Matrix multiplication: mul!" begin

end
end

@testset "Types" begin
@test Base.complex(Quantity{Float64,NoDims,NoUnits}) ==
Quantity{Complex{Float64},NoDims,NoUnits}
Expand Down

0 comments on commit e95f3e8

Please sign in to comment.