Skip to content

Commit 37fb249

Browse files
committed
Some fixes
1 parent 87309bf commit 37fb249

File tree

4 files changed

+53
-13
lines changed

4 files changed

+53
-13
lines changed

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ end
227227

228228
for (f, f_ne, pb, pf, adj) in (
229229
(:eig_trunc, :eig_trunc_no_error, :eig_trunc_pullback!, :eig_trunc_pushforward!, :eig_trunc_adjoint),
230-
(:eigh_trunc, :eigh_trunc_no_error, :eigh_trunc_pullback!, :eigh_trunc_pushforward, :eigh_trunc_adjoint),
230+
(:eigh_trunc, :eigh_trunc_no_error, :eigh_trunc_pullback!, :eigh_trunc_pushforward!, :eigh_trunc_adjoint),
231231
)
232232
@eval begin
233233
@is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
@@ -254,7 +254,20 @@ for (f, f_ne, pb, pf, adj) in (
254254
end
255255
return output_codual, $adj
256256
end
257-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne), Any, MatrixAlgebraKit.AbstractAlgorithm}
257+
function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual)
258+
# compute primal
259+
A, dA = arrayify(A_dA)
260+
alg = Mooncake.primal(alg_dalg)
261+
output = $f(A, alg)
262+
output_dual = Mooncake.zero_dual(output)
263+
dD_ = Mooncake.tangent(output_dual)[1]
264+
dV_ = Mooncake.tangent(output_dual)[2]
265+
D, dD = arrayify(output[1], dD_)
266+
V, dV = arrayify(output[2], dV_)
267+
$pf(dA, A, (D, V), (dD, dV))
268+
return output_dual
269+
end
270+
@is_primitive Mooncake.DefaultCtx Tuple{typeof($f_ne), Any, MatrixAlgebraKit.AbstractAlgorithm}
258271
function Mooncake.rrule!!(::CoDual{typeof($f_ne)}, A_dA::CoDual, alg_dalg::CoDual)
259272
# compute primal
260273
A, dA = arrayify(A_dA)
@@ -277,11 +290,11 @@ for (f, f_ne, pb, pf, adj) in (
277290
end
278291
return output_codual, $adj
279292
end
280-
function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual)
293+
function Mooncake.frule!!(::Dual{typeof($f_ne)}, A_dA::Dual, alg_dalg::Dual)
281294
# compute primal
282295
A, dA = arrayify(A_dA)
283296
alg = Mooncake.primal(alg_dalg)
284-
output = $f(A, alg)
297+
output = $f_ne(A, alg)
285298
output_dual = Mooncake.zero_dual(output)
286299
dD_ = Mooncake.tangent(output_dual)[1]
287300
dV_ = Mooncake.tangent(output_dual)[2]
@@ -478,9 +491,7 @@ end
478491

479492
@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm}
480493
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
481-
A_ = Mooncake.primal(A_dA)
482-
dA_ = Mooncake.tangent(A_dA)
483-
A, dA = arrayify(A_, dA_)
494+
A, dA = arrayify(A_dA)
484495
alg = Mooncake.primal(alg_dalg)
485496
output = svd_trunc(A, alg)
486497
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
@@ -531,12 +542,10 @@ function Mooncake.frule!!(::Dual{typeof(svd_trunc)}, A_dA::Dual, alg_dalg::Dual)
531542
end
532543

533544

534-
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm}
545+
@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm}
535546
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual)
536547
# compute primal
537-
A_ = Mooncake.primal(A_dA)
538-
dA_ = Mooncake.tangent(A_dA)
539-
A, dA = arrayify(A_, dA_)
548+
A, dA = arrayify(A_dA)
540549
alg = Mooncake.primal(alg_dalg)
541550
output = svd_trunc_no_error(A, alg)
542551
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
@@ -559,4 +568,29 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
559568
return output_codual, svd_trunc_adjoint
560569
end
561570

571+
function Mooncake.frule!!(::Dual{typeof(svd_trunc_no_error)}, A_dA::Dual, alg_dalg::Dual)
572+
# compute primal
573+
A, dA = arrayify(A_dA)
574+
alg = Mooncake.primal(alg_dalg)
575+
USVᴴ = svd_compact(A, alg.alg)
576+
U, S, Vᴴ = USVᴴ
577+
dUfull = zeros(eltype(U), size(U))
578+
dSfull = Diagonal(zeros(eltype(S), length(diagview(S))))
579+
dVᴴfull = zeros(eltype(Vᴴ), size(Vᴴ))
580+
svd_pushforward!(dA, A, (U, S, Vᴴ), (dUfull, dSfull, dVᴴfull))
581+
582+
USVᴴtrunc, ind = truncate(svd_trunc!, USVᴴ, alg.trunc)
583+
output = USVᴴtrunc
584+
output_dual = Mooncake.zero_dual(output)
585+
Utrunc, Strunc, Vᴴtrunc = output
586+
dU_, dS_, dVᴴ_ = Mooncake.tangent(output_dual)
587+
Utrunc, dU = arrayify(Utrunc, dU_)
588+
Strunc, dS = arrayify(Strunc, dS_)
589+
Vᴴtrunc, dVᴴ = arrayify(Vᴴtrunc, dVᴴ_)
590+
dU .= view(dUfull, :, ind)
591+
diagview(dS) .= view(diagview(dSfull), ind)
592+
dVᴴ .= view(dVᴴfull, ind, :)
593+
return output_dual
594+
end
595+
562596
end

src/pushforwards/eig.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,7 @@ function eig_pushforward!(ΔA, A, DV, ΔDV; kwargs...)
1313
end
1414

1515
function eig_trunc_pushforward!(ΔA, A, DV, ΔDV; kwargs...) end
16+
17+
function eig_vals_pushforward!(ΔA, A, DV, ΔD; kwargs...)
18+
return eig_pushforward!(ΔA, A, DV, ΔD; kwargs...)
19+
end

src/pushforwards/eigh.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,5 @@ function eigh_pushforward!(dA, A, DV, dDV; kwargs...)
1717
end
1818

1919
function eigh_trunc_pushforward!(dA, A, DV, dDV; kwargs...) end
20+
21+
function eigh_vals_pushforward!(dA, A, DV, dDV, ind = Colon(); kwargs...) end

src/pushforwards/svd.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,12 @@ function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ; rank_atol = default_pullback
4949
fill!(UÃÃV, 0)
5050
view(UÃÃV, (1:size(aUAV, 1)), size(aUAV, 1) .+ (1:size(aUAV, 2))) .= aUAV
5151
view(UÃÃV, size(aUAV, 1) .+ (1:size(aUAV, 2)), 1:size(aUAV, 1)) .= aUAV'
52-
rhs = vcat(adjoint(Uperp, ΔA, V), Vᴴperp * ΔA' * U)
52+
rhs = vcat(adjoint(Uperp * ΔA * Vᴴ), Vᴴperp * ΔA' * U)
5353
superKM = -sylvester(UÃÃV, Smat, rhs)
5454
K̇perp = view(superKM, 1:size(aUAV, 2))
5555
Ṁperp = view(superKM, (size(aUAV, 2) + 1):(size(aUAV, 1) + size(aUAV, 2)))
5656
∂U .+= Uperp * K̇perp
57-
∂V .+= Vperp * Ṁperp
57+
∂V .+= Vᴴperp * Ṁperp
5858
else
5959
ImUU = (LinearAlgebra.diagm(ones(eltype(U), m)) - vU * vU')
6060
ImVV = (LinearAlgebra.diagm(ones(eltype(Vᴴ), n)) - vV * vVᴴ)

0 commit comments

Comments
 (0)