Skip to content

Commit

Permalink
LinearAlgebra.givens with unitful arguments (#36430)
Browse files Browse the repository at this point in the history
  • Loading branch information
KlausC authored Jul 3, 2020
1 parent f9496fb commit b79a6d1
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
15 changes: 14 additions & 1 deletion stdlib/LinearAlgebra/src/givens.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ end
convert(::Type{T}, r::T) where {T<:AbstractRotation} = r
convert(::Type{T}, r::AbstractRotation) where {T<:AbstractRotation} = T(r)

Givens(i1, i2, c, s) = Givens(i1, i2, promote(c, s)...)
Givens{T}(G::Givens{T}) where {T} = G
Givens{T}(G::Givens) where {T} = Givens(G.i1, G.i2, convert(T, G.c), convert(T, G.s))
Rotation{T}(R::Rotation{T}) where {T} = R
Expand Down Expand Up @@ -248,6 +249,18 @@ function givensAlgorithm(f::Complex{T}, g::Complex{T}) where T<:AbstractFloat
return cs, sn, r
end

# enable for unitful quantities
function givensAlgorithm(f::T, g::T) where T
fs = f / oneunit(T)
gs = g / oneunit(T)
typeof(fs) === T && typeof(gs) === T &&
!isa(fs, Union{AbstractFloat,Complex{<:AbstractFloat}}) &&
throw(MethodError(givensAlgorithm, (fs, gs)))

c, s, r = givensAlgorithm(fs, gs)
return c, s, r * oneunit(T)
end

givensAlgorithm(f, g) = givensAlgorithm(promote(float(f), float(g))...)

"""
Expand Down Expand Up @@ -280,7 +293,7 @@ function givens(f::T, g::T, i1::Integer, i2::Integer) where T
s = -conj(s)
i1,i2 = i2,i1
end
Givens(i1, i2, convert(T, c), convert(T, s)), r
Givens(i1, i2, c, s), r
end
"""
givens(A::AbstractArray, i1::Integer, i2::Integer, j::Integer) -> (G::Givens, r)
Expand Down
32 changes: 31 additions & 1 deletion stdlib/LinearAlgebra/test/givens.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
module TestGivens

using Test, LinearAlgebra, Random
using LinearAlgebra: rmul!, lmul!
using LinearAlgebra: rmul!, lmul!, Givens

# Test givens rotations
@testset for elty in (Float32, Float64, ComplexF32, ComplexF64)
Expand Down Expand Up @@ -70,4 +70,34 @@ using LinearAlgebra: rmul!, lmul!
end
end

# 36430
# dimensional correctness:
const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")
isdefined(Main, :Furlongs) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Furlongs.jl"))
using .Main.Furlongs

@testset "testing dimensions with Furlongs" begin
@test_throws MethodError givens(Furlong(1.0), Furlong(2.0), 1, 2)
end

const TNumber = Union{Float64,ComplexF64}
struct MockUnitful{T<:TNumber} <: Number
data::T
MockUnitful(data::T) where T<:TNumber = new{T}(data)
end
import Base: *, /, one, oneunit
*(a::MockUnitful{T}, b::T) where T<:TNumber = MockUnitful(a.data * b)
*(a::T, b::MockUnitful{T}) where T<:TNumber = MockUnitful(a * b.data)
*(a::MockUnitful{T}, b::MockUnitful{T}) where T<:TNumber = MockUnitful(a.data * b.data)
/(a::MockUnitful{T}, b::MockUnitful{T}) where T<:TNumber = a.data / b.data
one(::Type{<:MockUnitful{T}}) where T = one(T)
oneunit(::Type{<:MockUnitful{T}}) where T = MockUnitful(one(T))

@testset "unitful givens rotation unitful $T " for T in (Float64, ComplexF64)
g, r = givens(MockUnitful(T(3)), MockUnitful(T(4)), 1, 2)
@test g.c 3/5
@test g.s 4/5
@test r.data 5.0
end

end # module TestGivens

0 comments on commit b79a6d1

Please sign in to comment.