227227
228228for (f, f_ne, pb, pf, adj) in (
229229 (:eig_trunc , :eig_trunc_no_error , :eig_trunc_pullback! , :eig_trunc_pushforward! , :eig_trunc_adjoint ),
230- (:eigh_trunc , :eigh_trunc_no_error , :eigh_trunc_pullback! , :eigh_trunc_pushforward , :eigh_trunc_adjoint ),
230+ (:eigh_trunc , :eigh_trunc_no_error , :eigh_trunc_pullback! , :eigh_trunc_pushforward! , :eigh_trunc_adjoint ),
231231 )
232232 @eval begin
233233 @is_primitive Mooncake. DefaultCtx Tuple{typeof ($ f), Any, MatrixAlgebraKit. AbstractAlgorithm}
@@ -254,7 +254,20 @@ for (f, f_ne, pb, pf, adj) in (
254254 end
255255 return output_codual, $ adj
256256 end
257- @is_primitive Mooncake. DefaultCtx Mooncake. ReverseMode Tuple{typeof ($ f_ne), Any, MatrixAlgebraKit. AbstractAlgorithm}
257+ function Mooncake. frule!! (:: Dual{typeof($f)} , A_dA:: Dual , alg_dalg:: Dual )
258+ # compute primal
259+ A, dA = arrayify (A_dA)
260+ alg = Mooncake. primal (alg_dalg)
261+ output = $ f (A, alg)
262+ output_dual = Mooncake. zero_dual (output)
263+ dD_ = Mooncake. tangent (output_dual)[1 ]
264+ dV_ = Mooncake. tangent (output_dual)[2 ]
265+ D, dD = arrayify (output[1 ], dD_)
266+ V, dV = arrayify (output[2 ], dV_)
267+ $ pf (dA, A, (D, V), (dD, dV))
268+ return output_dual
269+ end
270+ @is_primitive Mooncake. DefaultCtx Tuple{typeof ($ f_ne), Any, MatrixAlgebraKit. AbstractAlgorithm}
258271 function Mooncake. rrule!! (:: CoDual{typeof($f_ne)} , A_dA:: CoDual , alg_dalg:: CoDual )
259272 # compute primal
260273 A, dA = arrayify (A_dA)
@@ -277,11 +290,11 @@ for (f, f_ne, pb, pf, adj) in (
277290 end
278291 return output_codual, $ adj
279292 end
280- function Mooncake. frule!! (:: Dual{typeof($f )} , A_dA:: Dual , alg_dalg:: Dual )
293+ function Mooncake. frule!! (:: Dual{typeof($f_ne )} , A_dA:: Dual , alg_dalg:: Dual )
281294 # compute primal
282295 A, dA = arrayify (A_dA)
283296 alg = Mooncake. primal (alg_dalg)
284- output = $ f (A, alg)
297+ output = $ f_ne (A, alg)
285298 output_dual = Mooncake. zero_dual (output)
286299 dD_ = Mooncake. tangent (output_dual)[1 ]
287300 dV_ = Mooncake. tangent (output_dual)[2 ]
478491
479492@is_primitive Mooncake. DefaultCtx Tuple{typeof (svd_trunc), Any, MatrixAlgebraKit. AbstractAlgorithm}
480493function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc)} , A_dA:: CoDual , alg_dalg:: CoDual )
481- A_ = Mooncake. primal (A_dA)
482- dA_ = Mooncake. tangent (A_dA)
483- A, dA = arrayify (A_, dA_)
494+ A, dA = arrayify (A_dA)
484495 alg = Mooncake. primal (alg_dalg)
485496 output = svd_trunc (A, alg)
486497 # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
@@ -531,12 +542,10 @@ function Mooncake.frule!!(::Dual{typeof(svd_trunc)}, A_dA::Dual, alg_dalg::Dual)
531542end
532543
533544
534- @is_primitive Mooncake. DefaultCtx Mooncake . ReverseMode Tuple{typeof (svd_trunc_no_error), Any, MatrixAlgebraKit. AbstractAlgorithm}
545+ @is_primitive Mooncake. DefaultCtx Tuple{typeof (svd_trunc_no_error), Any, MatrixAlgebraKit. AbstractAlgorithm}
535546function Mooncake. rrule!! (:: CoDual{typeof(svd_trunc_no_error)} , A_dA:: CoDual , alg_dalg:: CoDual )
536547 # compute primal
537- A_ = Mooncake. primal (A_dA)
538- dA_ = Mooncake. tangent (A_dA)
539- A, dA = arrayify (A_, dA_)
548+ A, dA = arrayify (A_dA)
540549 alg = Mooncake. primal (alg_dalg)
541550 output = svd_trunc_no_error (A, alg)
542551 # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
@@ -559,4 +568,29 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
559568 return output_codual, svd_trunc_adjoint
560569end
561570
571+ function Mooncake. frule!! (:: Dual{typeof(svd_trunc_no_error)} , A_dA:: Dual , alg_dalg:: Dual )
572+ # compute primal
573+ A, dA = arrayify (A_dA)
574+ alg = Mooncake. primal (alg_dalg)
575+ USVᴴ = svd_compact (A, alg. alg)
576+ U, S, Vᴴ = USVᴴ
577+ dUfull = zeros (eltype (U), size (U))
578+ dSfull = Diagonal (zeros (eltype (S), length (diagview (S))))
579+ dVᴴfull = zeros (eltype (Vᴴ), size (Vᴴ))
580+ svd_pushforward! (dA, A, (U, S, Vᴴ), (dUfull, dSfull, dVᴴfull))
581+
582+ USVᴴtrunc, ind = truncate (svd_trunc!, USVᴴ, alg. trunc)
583+ output = USVᴴtrunc
584+ output_dual = Mooncake. zero_dual (output)
585+ Utrunc, Strunc, Vᴴtrunc = output
586+ dU_, dS_, dVᴴ_ = Mooncake. tangent (output_dual)
587+ Utrunc, dU = arrayify (Utrunc, dU_)
588+ Strunc, dS = arrayify (Strunc, dS_)
589+ Vᴴtrunc, dVᴴ = arrayify (Vᴴtrunc, dVᴴ_)
590+ dU .= view (dUfull, :, ind)
591+ diagview (dS) .= view (diagview (dSfull), ind)
592+ dVᴴ .= view (dVᴴfull, ind, :)
593+ return output_dual
594+ end
595+
562596end
0 commit comments