-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Cholesky factorisation #202
Changes from all commits
9353ecf
221ec84
7b80bc3
0c36a9d
83bdaaf
ce8769b
78b180d
585c946
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "NamedDims" | ||
uuid = "356022a1-0364-5f58-8944-0da4b18d706f" | ||
authors = ["Invenia Technical Computing Corporation"] | ||
version = "0.2.49" | ||
version = "0.3.0" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. major bump due to constrainting julia to 1.6 and above There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in the future, this should be a patch bump to be non-breaking (since the package is pre-1.0) unless one intends to support backports: https://github.com/SciML/ColPrac/pull/20/files |
||
|
||
[deps] | ||
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" | ||
|
@@ -20,7 +20,7 @@ ChainRulesTestUtils = "1" | |
CovarianceEstimation = "0.2.4" | ||
Requires = "0.5, 1" | ||
Tracker = "0.2.2" | ||
julia = "1" | ||
julia = "1.6" | ||
|
||
[extras] | ||
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,12 +23,13 @@ Base.size(named::NamedFactorization) = size(parent(named)) | |
Base.propertynames(named::NamedFactorization; kwargs...) = propertynames(parent(named)) | ||
|
||
# Factorization type specific initial iterate calls | ||
Base.iterate(named::NamedFactorization{L, T, <:LU}) where {L, T} = (named.L, Val(:U)) | ||
Base.iterate(named::NamedFactorization{L, T, <:LQ}) where {L, T} = (named.L, Val(:Q)) | ||
Base.iterate(named::NamedFactorization{L, T, <:SVD}) where {L, T} = (named.U, Val(:S)) | ||
Base.iterate(named::NamedFactorization{L,T,<:LU}) where {L,T} = (named.L, Val(:U)) | ||
Base.iterate(named::NamedFactorization{L,T,<:LQ}) where {L,T} = (named.L, Val(:Q)) | ||
Base.iterate(named::NamedFactorization{L,T,<:SVD}) where {L,T} = (named.U, Val(:S)) | ||
Base.iterate(named::NamedFactorization{L,T,<:Cholesky}) where {L,T} = (named.L, Val(:U)) | ||
function Base.iterate( | ||
named::NamedFactorization{L, T, <:Union{QR, LinearAlgebra.QRCompactWY, QRPivoted}} | ||
) where {L, T} | ||
named::NamedFactorization{L,T,<:Union{QR,LinearAlgebra.QRCompactWY,QRPivoted}} | ||
) where {L,T} | ||
return (named.Q, Val(:R)) | ||
end | ||
|
||
|
@@ -40,9 +41,11 @@ function Base.iterate(named::NamedFactorization, st::Val{D}) where D | |
end | ||
|
||
# Convenience constructors | ||
for func in (:lu, :lu!, :lq, :lq!, :svd, :svd!, :qr, :qr!) | ||
for func in (:lu, :lu!, :lq, :lq!, :svd, :svd!, :qr, :qr!, :cholesky) | ||
@eval begin | ||
function LinearAlgebra.$func(nda::NamedDimsArray{L, T}, args...; kwargs...) where {L, T} | ||
function LinearAlgebra.$func( | ||
nda::NamedDimsArray{L,T}, args...; kwargs... | ||
) where {L,T} | ||
return NamedFactorization{L}($func(parent(nda), args...; kwargs...)) | ||
end | ||
end | ||
|
@@ -82,8 +85,24 @@ function Base.getproperty(fact::NamedFactorization{L, T, <:LQ}, d::Symbol) where | |
end | ||
end | ||
|
||
# cholesky | ||
|
||
function Base.getproperty(fact::NamedFactorization{L,T,<:Cholesky}, d::Symbol) where {L,T} | ||
inner = getproperty(parent(fact), d) | ||
n1, n2 = L | ||
return d in (:L, :U) ? NamedDimsArray{(n1, n2)}(inner) : inner | ||
end | ||
function NamedFactorization{L}(fact::Cholesky{T}) where {L,T} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Define a new constructor to handle the dimension mismatch |
||
n1, n2 = L | ||
return if isequal(n1, n2) | ||
NamedFactorization{L,T,Cholesky{T}}(fact) | ||
else | ||
throw(DimensionMismatch("$n1 != $n2")) | ||
end | ||
end | ||
AlexRobson marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
## svd | ||
function Base.getproperty(fact::NamedFactorization{L, T, <:SVD}, d::Symbol) where {L, T} | ||
function Base.getproperty(fact::NamedFactorization{L,T,<:SVD}, d::Symbol) where {L,T} | ||
inner = getproperty(parent(fact), d) | ||
n1, n2 = L | ||
# Naming based off the SVD visualization on wikipedia | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,13 +10,19 @@ if !isdefined(@__MODULE__, :ColumnNorm) | |
NoPivot() = Val(false) | ||
end | ||
|
||
function baseline_tests(fact, identity) | ||
_test_data(::Val{:rectangle}) = [1.0 2 3; 4 5 6]; | ||
function _test_data(::Val{:pdmat}) | ||
return [8.0 7.0 6.0 5.0; 7.0 8.0 6.0 6.0; 6.0 6.0 6.0 5.0; 5.0 6.0 5.0 5.0] | ||
end | ||
_test_names(::Val{:rectangle}) = (:foo, :bar) | ||
_test_names(::Val{:pdmat}) = (:foo, :foo) | ||
AlexRobson marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
function baseline_tests(fact, identity; test_data_type=:rectangle) | ||
# A set of generic tests to ensure that our components don't accidentally reverse the | ||
# `:foo` and `:bar` labels for any components | ||
@testset "Baseline" begin | ||
names = (:foo, :bar) | ||
sz = (2, 3) | ||
data = [1.0 2 3; 4 5 6] | ||
names = _test_names(Val{test_data_type}()) | ||
data = _test_data(Val{test_data_type}()) | ||
nda = NamedDimsArray{names}(data) | ||
|
||
base_fact = fact(data) | ||
|
@@ -38,10 +44,11 @@ function baseline_tests(fact, identity) | |
@test size(_base) == size(_named) | ||
|
||
# If our property is a NamedDimsArray make sure that the names make sense | ||
_named isa NamedDimsArray && @testset "Test name for dim $d" for d in 1:ndims(_named) | ||
_named isa NamedDimsArray && @testset "Test name for dim $d" for d in | ||
1:ndims(_named) | ||
# Don't think it make sense for an factorization to produce properties with | ||
# dimension sizes outside 1, 2 or 3 | ||
@test size(_named, d) in (1, 2, 3) | ||
AlexRobson marked this conversation as resolved.
Show resolved
Hide resolved
|
||
@test d in (1, 2, 3) | ||
|
||
if size(_named, d) == 1 | ||
# Neither name makes sense here | ||
|
@@ -52,6 +59,9 @@ function baseline_tests(fact, identity) | |
elseif size(_named, d) == 3 | ||
# Name must either be :bar or :_ | ||
@test dimnames(_named, d) in (:bar, :_) | ||
elseif size(_named, d) == 4 | ||
# Name can only be foo, as this is the pdmat case | ||
@test dimnames(_named, d) in (:foo,) | ||
end | ||
end | ||
end | ||
|
@@ -143,6 +153,21 @@ end | |
end | ||
end | ||
|
||
@testset "cholesky" begin | ||
baseline_tests(cholesky, S -> S.L * S.L'; test_data_type=:pdmat) | ||
baseline_tests(cholesky, S -> S.U' * S.U; test_data_type=:pdmat) | ||
|
||
# Explicit `dimnames` tests for readability | ||
nda = NamedDimsArray{(:foo, :foo)}(_test_data(Val{:pdmat}())) | ||
nda_mismatch = NamedDimsArray{(:foo, :bar)}(_test_data(Val{:pdmat}())) | ||
x = cholesky(nda) | ||
@test size(x) == size(parent(x)) | ||
@test dimnames(x.L) == (:foo, :foo) | ||
@test dimnames(x.U) == (:foo, :foo) | ||
|
||
@test_throws DimensionMismatch cholesky(nda_mismatch) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. now throw an error if there is a dim-mismatch. |
||
end | ||
|
||
@testset "#164 factorization eltype not same as input eltype" begin | ||
# https://github.com/invenia/NamedDims.jl/issues/164 | ||
nda = NamedDimsArray{(:foo, :bar)}([1 2 3; 4 5 6; 7 8 9]) # Int eltype | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
drop julia < 1.6 in CI