Skip to content

Make LU factorization work for more types #26344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions stdlib/LinearAlgebra/src/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,25 @@ function lu(A::Union{AbstractMatrix{T}, AbstractMatrix{Complex{T}}},
lu!(copy(A), pivot)
end

function lutype(T::Type)
# In generic_lufact!, the elements of the lower part of the matrix are
# obtained using the division of two matrix elements. Hence their type can
# be different (e.g. the division of two types with the same unit is a type
# without unit).
# The elements of the upper part are obtained by U - U * L
# where U is an upper part element and L is a lower part element.
# Therefore, the types LT, UT should be invariant under the map:
# (LT, UT) -> begin
# L = oneunit(UT) / oneunit(UT)
# U = oneunit(UT) - oneunit(UT) * L
# typeof(L), typeof(U)
# end
# The following should handle most cases
UT = typeof(oneunit(T) - oneunit(T) * (oneunit(T) / (oneunit(T) + zero(T))))
LT = typeof(oneunit(UT) / oneunit(UT))
S = promote_type(T, LT, UT)
end

# for all other types we must promote to a type which is stable under division
"""
lu(A, pivot=Val(true)) -> F::LU
Expand Down Expand Up @@ -191,14 +210,14 @@ true
```
"""
function lu(A::AbstractMatrix{T}, pivot::Union{Val{false}, Val{true}}) where T
S = typeof(zero(T)/one(T))
S = lutype(T)
AA = similar(A, S)
copyto!(AA, A)
lu!(AA, pivot)
end
# We can't assume an ordered field so we first try without pivoting
function lu(A::AbstractMatrix{T}) where T
S = typeof(zero(T)/one(T))
S = lutype(T)
AA = similar(A, S)
copyto!(AA, A)
F = lu!(AA, Val(false))
Expand Down
11 changes: 11 additions & 0 deletions stdlib/LinearAlgebra/test/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -271,4 +271,15 @@ end
@test allnames == ["L", "P", "U", "factors", "info", "ipiv", "p"]
end

include("trickyarithmetic.jl")

@testset "lu with type whose sum is another type" begin
A = TrickyArithmetic.A[1 2; 3 4]
ElT = TrickyArithmetic.D{TrickyArithmetic.C,TrickyArithmetic.C}
B = lu(A)
@test B isa LinearAlgebra.LU{ElT,Matrix{ElT}}
C = lu(A, Val(false))
@test C isa LinearAlgebra.LU{ElT,Matrix{ElT}}
end

end # module TestLU
60 changes: 60 additions & 0 deletions stdlib/LinearAlgebra/test/trickyarithmetic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
module TrickyArithmetic
struct A
x::Int
end
A(a::A) = a
Base.convert(::Type{A}, i::Int) = A(i)
Base.zero(::Union{A, Type{A}}) = A(0)
Base.one(::Union{A, Type{A}}) = A(1)
struct B
x::Int
end
struct C
x::Int
end
C(a::A) = C(a.x)
Base.zero(::Union{C, Type{C}}) = C(0)
Base.one(::Union{C, Type{C}}) = C(1)

Base.:(*)(x::Int, a::A) = B(x*a.x)
Base.:(*)(a::A, x::Int) = B(a.x*x)
Base.:(*)(a::Union{A,B}, b::Union{A,B}) = B(a.x*b.x)
Base.:(*)(a::Union{A,B,C}, b::Union{A,B,C}) = C(a.x*b.x)
Base.:(+)(a::Union{A,B,C}, b::Union{A,B,C}) = C(a.x+b.x)
Base.:(-)(a::Union{A,B,C}, b::Union{A,B,C}) = C(a.x-b.x)

struct D{NT, DT}
n::NT
d::DT
end
D{NT, DT}(d::D{NT, DT}) where {NT, DT} = d # called by oneunit
Base.zero(::Union{D{NT, DT}, Type{D{NT, DT}}}) where {NT, DT} = zero(NT) / one(DT)
Base.one(::Union{D{NT, DT}, Type{D{NT, DT}}}) where {NT, DT} = one(NT) / one(DT)
Base.convert(::Type{D{NT, DT}}, a::Union{A, B, C}) where {NT, DT} = NT(a) / one(DT)
#Base.convert(::Type{D{NT, DT}}, a::D) where {NT, DT} = NT(a.n) / DT(a.d)

Base.:(*)(a::D, b::D) = (a.n*b.n) / (a.d*b.d)
Base.:(*)(a::D, b::Union{A,B,C}) = (a.n * b) / a.d
Base.:(*)(a::Union{A,B,C}, b::D) = b * a
Base.inv(a::Union{A,B,C}) = A(1) / a
Base.inv(a::D) = a.d / a.n
Base.:(/)(a::Union{A,B,C}, b::Union{A,B,C}) = D(a, b)
Base.:(/)(a::D, b::Union{A,B,C}) = a.n / (a.d*b)
Base.:(/)(a::Union{A,B,C,D}, b::D) = a * inv(b)
Base.:(+)(a::Union{A,B,C}, b::D) = (a*b.d+b.n) / b.d
Base.:(+)(a::D, b::Union{A,B,C}) = b + a
Base.:(+)(a::D, b::D) = (a.n*b.d+a.d*b.n) / (a.d*b.d)
Base.:(-)(a::Union{A,B,C}) = typeof(a)(a.x)
Base.:(-)(a::D) = (-a.n) / a.d
Base.:(-)(a::Union{A,B,C,D}, b::Union{A,B,C,D}) = a + (-b)

Base.promote_rule(::Type{A}, ::Type{B}) = B
Base.promote_rule(::Type{B}, ::Type{A}) = B
Base.promote_rule(::Type{A}, ::Type{C}) = C
Base.promote_rule(::Type{C}, ::Type{A}) = C
Base.promote_rule(::Type{B}, ::Type{C}) = C
Base.promote_rule(::Type{C}, ::Type{B}) = C
Base.promote_rule(::Type{D{NT,DT}}, T::Type{<:Union{A,B,C}}) where {NT,DT} = D{promote_type(NT,T),DT}
Base.promote_rule(T::Type{<:Union{A,B,C}}, ::Type{D{NT,DT}}) where {NT,DT} = D{promote_type(NT,T),DT}
Base.promote_rule(::Type{D{NS,DS}}, ::Type{D{NT,DT}}) where {NS,DS,NT,DT} = D{promote_type(NS,NT),promote_type(DS,DT)}
end