@@ -539,4 +539,54 @@ function LinearAlgebra.generic_mattridiv!(
539
539
return C
540
540
end
541
541
542
+ # LU Factorization
543
+ # # Generalized because we allow for batched LU
544
+ struct GeneralizedLU{T,S<: AbstractArray ,P<: AbstractArray ,I<: Union{AbstractArray,Number} } < :
545
+ Factorization{T}
546
+ factors:: S
547
+ ipiv:: P
548
+ perm:: P
549
+ info:: I
550
+ end
551
+
552
+ function GeneralizedLU (factors:: S , ipiv:: P , perm:: P , info:: I ) where {S,P,I}
553
+ return GeneralizedLU {eltype(factors),S,P,I} (factors, ipiv, perm, info)
554
+ end
555
+
556
+ function LinearAlgebra. lu! (
557
+ A:: AnyTracedRArray{T,N} , :: RowMaximum ; check:: Bool = false , allowsingular:: Bool = false
558
+ ) where {T,N}
559
+ return _lu_overload (A, RowMaximum (); check, allowsingular)
560
+ end
561
+
562
+ function LinearAlgebra. lu! (
563
+ A:: AnyTracedRArray{T,2} , :: RowMaximum ; check:: Bool = false , allowsingular:: Bool = false
564
+ ) where {T}
565
+ return _lu_overload (A, RowMaximum (); check, allowsingular)
566
+ end
567
+
568
+ function _lu_overload (
569
+ A:: AnyTracedRArray{T,N} , :: RowMaximum ; check:: Bool = false , allowsingular:: Bool = false
570
+ ) where {T,N}
571
+ # TODO : don't ignore the check and allowsingular flags
572
+ factors, ipiv, perm, info = Reactant. Ops. lu (materialize_traced_array (A))
573
+ return GeneralizedLU (factors, ipiv, perm, info)
574
+ end
575
+
576
+ # TODO : generalize for higher dimensions of B
577
+ function LinearAlgebra. ldiv! (
578
+ lu:: GeneralizedLU{T,<:AbstractMatrix,P,I} , B:: AbstractVector
579
+ ) where {T,P,I}
580
+ ldiv! (lu, reshape (B, :, 1 ))
581
+ return B
582
+ end
583
+
584
+ function LinearAlgebra. ldiv! (
585
+ lu:: GeneralizedLU{T,<:AbstractMatrix,P,I} , B:: AbstractMatrix
586
+ ) where {T,P,I}
587
+ ldiv! (B, UnitLowerTriangular (lu. factors), B[Int64 .(lu. perm), :])
588
+ ldiv! (B, UpperTriangular (lu. factors), B)
589
+ return B
590
+ end
591
+
542
592
end
0 commit comments