Skip to content

Commit eec2f07

Browse files
committed
feat: overload LinearAlgebra.lu
1 parent 58a33a2 commit eec2f07

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

src/stdlibs/LinearAlgebra.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,4 +539,54 @@ function LinearAlgebra.generic_mattridiv!(
539539
return C
540540
end
541541

542+
# LU Factorization
543+
## Generalized because we allow for batched LU
544+
struct GeneralizedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray,Number}} <:
545+
Factorization{T}
546+
factors::S
547+
ipiv::P
548+
perm::P
549+
info::I
550+
end
551+
552+
function GeneralizedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I}
553+
return GeneralizedLU{eltype(factors),S,P,I}(factors, ipiv, perm, info)
554+
end
555+
556+
function LinearAlgebra.lu!(
557+
A::AnyTracedRArray{T,N}, ::RowMaximum; check::Bool=false, allowsingular::Bool=false
558+
) where {T,N}
559+
return _lu_overload(A, RowMaximum(); check, allowsingular)
560+
end
561+
562+
function LinearAlgebra.lu!(
563+
A::AnyTracedRArray{T,2}, ::RowMaximum; check::Bool=false, allowsingular::Bool=false
564+
) where {T}
565+
return _lu_overload(A, RowMaximum(); check, allowsingular)
566+
end
567+
568+
function _lu_overload(
569+
A::AnyTracedRArray{T,N}, ::RowMaximum; check::Bool=false, allowsingular::Bool=false
570+
) where {T,N}
571+
# TODO: don't ignore the check and allowsingular flags
572+
factors, ipiv, perm, info = Reactant.Ops.lu(materialize_traced_array(A))
573+
return GeneralizedLU(factors, ipiv, perm, info)
574+
end
575+
576+
# TODO: generalize for higher dimensions of B
577+
function LinearAlgebra.ldiv!(
578+
lu::GeneralizedLU{T,<:AbstractMatrix,P,I}, B::AbstractVector
579+
) where {T,P,I}
580+
ldiv!(lu, reshape(B, :, 1))
581+
return B
582+
end
583+
584+
function LinearAlgebra.ldiv!(
585+
lu::GeneralizedLU{T,<:AbstractMatrix,P,I}, B::AbstractMatrix
586+
) where {T,P,I}
587+
ldiv!(B, UnitLowerTriangular(lu.factors), B[Int64.(lu.perm), :])
588+
ldiv!(B, UpperTriangular(lu.factors), B)
589+
return B
590+
end
591+
542592
end

0 commit comments

Comments
 (0)