Skip to content

Commit

Permalink
Implement sensitivities for BLAS.gemm
Browse files Browse the repository at this point in the history
These are ported from Nabla.
  • Loading branch information
ararslan committed May 11, 2019
1 parent 6d2be82 commit 71c018d
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
37 changes: 37 additions & 0 deletions src/rules/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ These implementations were ported from the wonderful DiffLinearAlgebra
package (https://github.com/invenia/DiffLinearAlgebra.jl).
=#

using LinearAlgebra: BlasFloat
using LinearAlgebra.BLAS: gemm

# Build a fresh container via `similar(x)` and overwrite every entry with the zero of
# `eltype(x)`.
function _zeros(x)
    buffer = similar(x)
    return fill!(buffer, zero(eltype(x)))
end

# Wrap the pullback `∂` in a `Rule`. A `Zero` differential is returned untouched;
# any other differential is converted with `extern` and then passed to `∂`.
# (Without the call to `∂` the wrapped rule would ignore the pullback it was given.)
_rule_via(∂) = Rule(ΔΩ -> isa(ΔΩ, Zero) ? ΔΩ : ∂(extern(ΔΩ)))
Expand Down Expand Up @@ -72,3 +75,37 @@ function rrule(f::typeof(BLAS.gemv), tA, A, x)
Ω, (dtA, dα, dA, dx) = rrule(f, tA, one(eltype(A)), A, x)
return Ω, (dtA, dA, dx)
end

#####
##### `BLAS.gemm`
#####

# Reverse-mode rule for `gemm(tA, tB, α, A, B)`.
#
# Returns the primal result `C` together with one rule per argument, in argument order:
# `DNERule()` for each of the two flags `tA` and `tB`, followed by rules built with
# `_rule_via` from the pullbacks for `α`, `A` and `B`. Each pullback takes the output
# sensitivity `C̄` and returns the sensitivity of the corresponding argument.
#
# A flag is treated as "no transpose" when `uppercase(flag) === 'N'`; every other flag
# value takes the transpose branch.
# NOTE(review): for complex `T` with a 'C' flag these formulas presumably need
# conjugation -- confirm before relying on them for complex inputs.
function rrule(::typeof(gemm), tA::Char, tB::Char, α::T,
               A::AbstractMatrix{T}, B::AbstractMatrix{T}) where T<:BlasFloat
    C = gemm(tA, tB, α, A, B)
    # `C / α` is the unscaled product, so `sum(C̄ .* C) / α` avoids a second `gemm`.
    # For `α == 0` that shortcut evaluates 0 / 0, so the unscaled product is recomputed.
    ∂α = C̄ -> iszero(α) ? sum(C̄ .* gemm(tA, tB, one(T), A, B)) : sum(C̄ .* C) / α
    if uppercase(tA) === 'N'
        if uppercase(tB) === 'N'
            # C = α * A * B
            ∂A = C̄ -> gemm('N', 'T', α, C̄, B)
            ∂B = C̄ -> gemm('T', 'N', α, A, C̄)
        else
            # C = α * A * Bᵀ
            ∂A = C̄ -> gemm('N', 'N', α, C̄, B)
            ∂B = C̄ -> gemm('T', 'N', α, C̄, A)
        end
    else
        if uppercase(tB) === 'N'
            # C = α * Aᵀ * B
            ∂A = C̄ -> gemm('N', 'T', α, B, C̄)
            ∂B = C̄ -> gemm('N', 'N', α, A, C̄)
        else
            # C = α * Aᵀ * Bᵀ
            ∂A = C̄ -> gemm('T', 'T', α, B, C̄)
            ∂B = C̄ -> gemm('T', 'T', α, C̄, A)
        end
    end
    return C, (DNERule(), DNERule(), _rule_via(∂α), _rule_via(∂A), _rule_via(∂B))
end

# `gemm` called without a scaling factor: delegate to the rule that takes `α`, passing
# `one(T)`, and drop the entry belonging to `α` from the returned tuple of rules.
function rrule(::typeof(gemm), tA::Char, tB::Char,
               A::AbstractMatrix{T}, B::AbstractMatrix{T}) where T<:BlasFloat
    Ω, scaled_rules = rrule(gemm, tA, tB, one(T), A, B)
    ∂tA, ∂tB, _, ∂A, ∂B = scaled_rules
    return Ω, (∂tA, ∂tB, ∂A, ∂B)
end
26 changes: 26 additions & 0 deletions test/rules/blas.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using LinearAlgebra.BLAS: gemm

@testset "BLAS" begin
    @testset "gemm" begin
        rng = MersenneTwister(1)
        dims = 3:5
        # Cover every combination of operand sizes and transposition flags.
        for m in dims, n in dims, p in dims, tA in ('N', 'T'), tB in ('N', 'T')
            α = randn(rng)
            # Shape each operand so that op(A) is m×n and op(B) is n×p.
            A = randn(rng, tA === 'N' ? (m, n) : (n, m))
            B = randn(rng, tB === 'N' ? (n, p) : (p, n))
            C = gemm(tA, tB, α, A, B)
            fAB, (dtA, dtB, dα, dA, dB) = rrule(gemm, tA, tB, α, A, B)
            @test C ≈ fAB
            @test dtA isa ChainRules.DNERule
            @test dtB isa ChainRules.DNERule
            # Compare each rule against `j′vp` evaluated with a central finite
            # difference method, varying one argument of `gemm` at a time.
            for (f, x, dx) in [(X->gemm(tA, tB, X, A, B), α, dα),
                               (X->gemm(tA, tB, α, X, B), A, dA),
                               (X->gemm(tA, tB, α, A, X), B, dB)]
                ȳ = randn(rng, size(C)...)
                x̄_ad = dx(ȳ)
                x̄_fd = j′vp(central_fdm(5, 1), f, ȳ, x)
                @test x̄_ad ≈ x̄_fd rtol=1e-9 atol=1e-9
            end
        end
    end
end

0 comments on commit 71c018d

Please sign in to comment.