Skip to content

Commit fe026b5

Browse files
committed
Fix / on 1.9
1 parent 55a48c6 commit fe026b5

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

src/rulesets/Base/arraymath.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -342,20 +342,24 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
342342
project_B = ProjectTo(B)
343343

344344
Y = A \ B
345+
# Ever since https://github.com/JuliaLang/julia/pull/44358
346+
# we need to use `pinv` rather than `/` to support both the cases of Y being scalar and array
347+
# See also https://github.com/JuliaLang/julia/issues/28827 which would improve this
345348
function backslash_pullback(ȳ)
346349
= unthunk(ȳ)
350+
Ati = pinv(A')
347351
∂A = @thunk begin
348-
= A' \
352+
353+
= Ati *
349354
= -* Y'
350-
= add!!(Ā, (B - A * Y) *' / A')
351-
= add!!(Ā, A' \ Y * (Ȳ' -'A))
355+
= add!!(Ā, ((B - A * Y) *') * Ati)
356+
= add!!(Ā, Ati * Y * (Ȳ' -'A))
352357
project_A(Ā)
353358
end
354-
∂B = @thunk project_B(A' \ Ȳ)
359+
∂B = @thunk project_B(Ati * Ȳ)
355360
return NoTangent(), ∂A, ∂B
356361
end
357362
return Y, backslash_pullback
358-
359363
end
360364

361365
#####

0 commit comments

Comments
 (0)