Skip to content

Commit

Permalink
Implement sensitivities for BLAS.gemm
Browse files Browse the repository at this point in the history
These are ported from Nabla.
  • Loading branch information
ararslan committed Apr 29, 2019
1 parent c5fd246 commit f7a81d5
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
40 changes: 40 additions & 0 deletions src/rules/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ These implementations were ported from the wonderful DiffLinearAlgebra
package (https://github.com/invenia/DiffLinearAlgebra.jl).
=#

using LinearAlgebra: BlasFloat
using LinearAlgebra.BLAS: gemm

_zeros(x) = fill!(similar(x), zero(eltype(x)))

_rule_via(∂) = Rule(ΔΩ -> isa(ΔΩ, Zero) ? ΔΩ : (extern(ΔΩ)))
Expand Down Expand Up @@ -72,3 +75,40 @@ function rrule(f::typeof(BLAS.gemv), tA, A, x)
Ω, (dtA, dα, dA, dx) = rrule(f, tA, one(eltype(A)), A, x)
return Ω, (dtA, dA, dx)
end

#####
##### `BLAS.gemm`
#####

function rrule(::typeof(gemm), tA::Char, tB::Char, α::T,
A::AbstractMatrix{T}, B::AbstractMatrix{T}) where T<:BlasFloat
C = gemm(tA, tB, α, A, B)
# Note that one can actually differentiate w.r.t. `α` but it seems to be tricky to
# define the sensitivity `ᾱ` correctly. For now we'll just lie and say that the
# derivative doesn't exist.
#∂α = C̄ -> sum(C̄ .* C)
if uppercase(tA) === 'N'
if uppercase(tB) === 'N'
∂A =-> gemm('N', 'T', α, C̄, B)
∂B =-> gemm('T', 'N', α, A, C̄)
else
∂A =-> gemm('N', 'N', α, C̄, B)
∂B =-> gemm('T', 'N', α, C̄, A)
end
else
if uppercase(tB) === 'N'
∂A =-> gemm('N', 'T', α, B, C̄)
∂B =-> gemm('N', 'N', α, A, C̄)
else
∂A =-> gemm('T', 'T', α, B, C̄)
∂B =-> gemm('T', 'T', α, C̄, A)
end
end
return C, (DNERule(), DNERule(), DNERule(), _rule_via(∂A), _rule_via(∂B))
end

function rrule(::typeof(gemm), tA::Char, tB::Char,
A::AbstractMatrix{T}, B::AbstractMatrix{T}) where T<:BlasFloat
C, (dtA, dtB, _, dA, dB) = rrule(gemm, tA, tB, one(T), A, B)
return C, (dtA, dtB, dA, dB)
end
26 changes: 26 additions & 0 deletions test/rules/blas.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using LinearAlgebra.BLAS: gemm

@testset "BLAS" begin
@testset "gemm" begin
rng = MersenneTwister(123456)
n = 10
α = randn(rng)
A = randn(rng, n, n)
B = randn(rng, n, n)
for tA in ('N', 'T'), tB in ('N', 'T')
C = gemm(tA, tB, α, A, B)
fAB, (dtA, dtB, dα, dA, dB) = rrule(gemm, tA, tB, α, A, B)
@test C fAB
@test dtA isa ChainRules.DNERule
@test dtB isa ChainRules.DNERule
for (f, x, dx) in [#=(X->gemm(tA, tB, X, A, B), α, dα),=#
(X->gemm(tA, tB, α, X, B), A, dA),
(X->gemm(tA, tB, α, A, X), B, dB)]
= randn(rng, size(x)...)
x̄_ad = dx(ȳ)
x̄_fd = j′vp(central_fdm(5, 1), f, ȳ, x)
@test x̄_ad x̄_fd rtol=1e-9 atol=1e-9
end
end
end
end

0 comments on commit f7a81d5

Please sign in to comment.