Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions ext/MatrixAlgebraKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,9 @@ for eig in (:eig, :eigh)
alg::TruncatedAlgorithm
)
Ac = copy_input($eig_f, A)
D, V = $(eig_f!)(Ac, DV, alg.alg)
ind = findtruncated(diagview(D), alg.trunc)
return (Diagonal(diagview(D)[ind]), V[:, ind]),
$(_make_eig_t_pb)(A, (D, V), ind)
DV = $(eig_f!)(Ac, DV, alg.alg)
DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
return DV′, $(_make_eig_t_pb)(A, DV, ind)
end
function $(_make_eig_t_pb)(A, DV, ind)
function $eig_t_pb(ΔDV)
Expand Down Expand Up @@ -163,10 +162,9 @@ function ChainRulesCore.rrule(
alg::TruncatedAlgorithm
)
Ac = copy_input(svd_compact, A)
U, S, Vᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
ind = findtruncated_svd(diagview(S), alg.trunc)
return (U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]),
_make_svd_trunc_pullback(A, (U, S, Vᴴ), ind)
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
return USVᴴ′, _make_svd_trunc_pullback(A, USVᴴ, ind)
end
function _make_svd_trunc_pullback(A, USVᴴ, ind)
function svd_trunc_pullback(ΔUSVᴴ)
Expand Down
2 changes: 1 addition & 1 deletion src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ export notrunc, truncrank, trunctol, truncerror, truncfilter
eval(
Expr(
:public, :TruncationByOrder, :TruncationByFilter, :TruncationByValue,
:TruncationByError, :TruncationIntersection
:TruncationByError, :TruncationIntersection, :truncate
)
)
eval(
Expand Down
20 changes: 12 additions & 8 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ this function may return `nothing`.

Supertype to denote different strategies for truncated decompositions that are implemented via post-truncation.

See also [`truncate!`](@ref)
See also [`truncate`](@ref)
"""
abstract type TruncationStrategy end

Expand Down Expand Up @@ -166,7 +166,7 @@ end
Generic interface for finding truncated values of the spectrum of a decomposition
based on the `strategy`. The output should be a collection of indices specifying
which values to keep. `MatrixAlgebraKit.findtruncated` is used inside of the default
implementation of [`truncate!`](@ref) to perform the truncation. It does not assume that the
implementation of [`truncate`](@ref) to perform the truncation. It does not assume that the
values are sorted. For a version that assumes the values are reverse sorted (which is the
standard case for SVD) see [`MatrixAlgebraKit.findtruncated_svd`](@ref).
""" findtruncated
Expand All @@ -179,6 +179,16 @@ sorted in descending order, as typically obtained by the SVD. This assumption is
checked, and this is used in the default implementation of [`svd_trunc!`](@ref).
""" findtruncated_svd

@doc """
truncate(::typeof(f), F, strategy::TruncationStrategy) -> F′, ind

Given a factorization function `f` and truncation `strategy`, truncate the factors `F` such
that the rows or columns at the indices `ind` are kept.

See also [`findtruncated`](@ref) and [`findtruncated_svd`](@ref) for determining the indices.
"""
function truncate end

"""
TruncatedAlgorithm(alg::AbstractAlgorithm, trunc::TruncationAlgorithm)

Expand All @@ -190,12 +200,6 @@ struct TruncatedAlgorithm{A, T} <: AbstractAlgorithm
trunc::T
end

@doc """
truncate!(f, out, strategy::TruncationStrategy)

Generic interface for post-truncating a decomposition, specified in `out`.
""" truncate!

# Utility macros
# --------------

Expand Down
2 changes: 1 addition & 1 deletion src/implementations/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ end

function eig_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm)
D, V = eig_full!(A, DV, alg.alg)
return truncate!(eig_trunc!, (D, V), alg.trunc)
return first(truncate(eig_trunc!, (D, V), alg.trunc))
end

# Diagonal logic
Expand Down
2 changes: 1 addition & 1 deletion src/implementations/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ end

function eigh_trunc!(A::AbstractMatrix, DV, alg::TruncatedAlgorithm)
D, V = eigh_full!(A, DV, alg.alg)
return truncate!(eigh_trunc!, (D, V), alg.trunc)
return first(truncate(eigh_trunc!, (D, V), alg.trunc))
end

# Diagonal logic
Expand Down
4 changes: 2 additions & 2 deletions src/implementations/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ function left_null_svd!(A, N, alg, trunc)
trunc′ = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
throw(ArgumentError("Unknown truncation strategy: $trunc"))
return truncate!(left_null!, (U, S), trunc′)
return first(truncate(left_null!, (U, S), trunc′))
end

function right_null!(
Expand Down Expand Up @@ -287,5 +287,5 @@ function right_null_svd!(A, Nᴴ, alg, trunc)
trunc′ = trunc isa TruncationStrategy ? trunc :
trunc isa NamedTuple ? null_truncation_strategy(; trunc...) :
throw(ArgumentError("Unknown truncation strategy: $trunc"))
return truncate!(right_null!, (S, Vᴴ), trunc′)
return first(truncate(right_null!, (S, Vᴴ), trunc′))
end
4 changes: 2 additions & 2 deletions src/implementations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ end

function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm)
USVᴴ′ = svd_compact!(A, USVᴴ, alg.alg)
return truncate!(svd_trunc!, USVᴴ′, alg.trunc)
return first(truncate(svd_trunc!, USVᴴ′, alg.trunc))
end

# Diagonal logic
Expand Down Expand Up @@ -383,7 +383,7 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)
# TODO: make this controllable using a `gaugefix` keyword argument
gaugefix!(svd_trunc!, U, S, Vᴴ, size(A)...)
return truncate!(svd_trunc!, USVᴴ, alg.trunc)
return first(truncate(svd_trunc!, USVᴴ, alg.trunc))
end

function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
Expand Down
24 changes: 12 additions & 12 deletions src/implementations/truncation.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
# truncate!
# ---------
# truncate
# --------
# Generic implementation: `findtruncated` followed by indexing
function truncate!(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy)
function truncate(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy)
ind = findtruncated_svd(diagview(S), strategy)
return U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]
return (U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]), ind
end
function truncate!(::typeof(eig_trunc!), (D, V), strategy::TruncationStrategy)
function truncate(::typeof(eig_trunc!), (D, V), strategy::TruncationStrategy)
ind = findtruncated(diagview(D), strategy)
return Diagonal(diagview(D)[ind]), V[:, ind]
return (Diagonal(diagview(D)[ind]), V[:, ind]), ind
end
function truncate!(::typeof(eigh_trunc!), (D, V), strategy::TruncationStrategy)
function truncate(::typeof(eigh_trunc!), (D, V), strategy::TruncationStrategy)
ind = findtruncated(diagview(D), strategy)
return Diagonal(diagview(D)[ind]), V[:, ind]
return (Diagonal(diagview(D)[ind]), V[:, ind]), ind
end
function truncate!(::typeof(left_null!), (U, S), strategy::TruncationStrategy)
function truncate(::typeof(left_null!), (U, S), strategy::TruncationStrategy)
# TODO: avoid allocation?
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2))))
ind = findtruncated(extended_S, strategy)
return U[:, ind]
return U[:, ind], ind
end
function truncate!(::typeof(right_null!), (S, Vᴴ), strategy::TruncationStrategy)
function truncate(::typeof(right_null!), (S, Vᴴ), strategy::TruncationStrategy)
# TODO: avoid allocation?
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 2) - size(S, 1))))
ind = findtruncated(extended_S, strategy)
return Vᴴ[ind, :]
return Vᴴ[ind, :], ind
end

# findtruncated
Expand Down