diff --git a/src/rules/blas.jl b/src/rules/blas.jl
index 9c2a64787..d5cfc700d 100644
--- a/src/rules/blas.jl
+++ b/src/rules/blas.jl
@@ -3,6 +3,9 @@ These implementations were ported from the wonderful DiffLinearAlgebra
 package (https://github.com/invenia/DiffLinearAlgebra.jl).
 =#
 
+using LinearAlgebra: BlasFloat
+using LinearAlgebra.BLAS: gemm
+
 _zeros(x) = fill!(similar(x), zero(eltype(x)))
 
 _rule_via(∂) = Rule(ΔΩ -> isa(ΔΩ, Zero) ? ΔΩ : ∂(extern(ΔΩ)))
@@ -72,3 +75,62 @@ function rrule(f::typeof(BLAS.gemv), tA, A, x)
     Ω, (dtA, dα, dA, dx) = rrule(f, tA, one(eltype(A)), A, x)
     return Ω, (dtA, dA, dx)
 end
+
+#####
+##### `BLAS.gemm`
+#####
+
+"""
+    rrule(::typeof(gemm), tA::Char, tB::Char, α::T, A, B) where T<:BlasFloat
+
+Reverse-mode rule for `C = gemm(tA, tB, α, A, B)`.
+
+Returns `C` together with one rule per argument: `DNERule()` for the two transpose
+flags, and rules built with `_rule_via` that map the output adjoint `C̄` to the
+adjoints of `α`, `A` and `B` respectively.
+
+Any flag other than `'N'` (case-insensitive) takes the transposed branch.
+NOTE(review): that includes `'C'`; whether conjugate-transposed complex inputs are
+handled as intended has not been verified here -- confirm before relying on it.
+"""
+function rrule(::typeof(gemm), tA::Char, tB::Char, α::T,
+               A::AbstractMatrix{T}, B::AbstractMatrix{T}) where T<:BlasFloat
+    C = gemm(tA, tB, α, A, B)
+    # `C` is linear in `α`, so `∂C/∂α = C / α` lets us reuse `C`. That shortcut is
+    # `0 / 0 = NaN` when `α == 0`, so recompute the unscaled product in that case.
+    ∂α = C̄ -> iszero(α) ? sum(C̄ .* gemm(tA, tB, one(T), A, B)) : sum(C̄ .* C) / α
+    # Each branch below writes the adjoints of `A` and `B` as a single `gemm` call
+    # for the corresponding combination of transpose flags.
+    if uppercase(tA) === 'N'
+        if uppercase(tB) === 'N'
+            ∂A = C̄ -> gemm('N', 'T', α, C̄, B)
+            ∂B = C̄ -> gemm('T', 'N', α, A, C̄)
+        else
+            ∂A = C̄ -> gemm('N', 'N', α, C̄, B)
+            ∂B = C̄ -> gemm('T', 'N', α, C̄, A)
+        end
+    else
+        if uppercase(tB) === 'N'
+            ∂A = C̄ -> gemm('N', 'T', α, B, C̄)
+            ∂B = C̄ -> gemm('N', 'N', α, A, C̄)
+        else
+            ∂A = C̄ -> gemm('T', 'T', α, B, C̄)
+            ∂B = C̄ -> gemm('T', 'T', α, C̄, A)
+        end
+    end
+    return C, (DNERule(), DNERule(), _rule_via(∂α), _rule_via(∂A), _rule_via(∂B))
+end
+
+"""
+    rrule(::typeof(gemm), tA::Char, tB::Char, A, B)
+
+Reverse-mode rule for `gemm` called without an explicit `α`.
+
+Delegates to the method taking `α`, passing `one(T)`, and drops the rule for `α` from
+the returned tuple.
+"""
+function rrule(::typeof(gemm), tA::Char, tB::Char,
+               A::AbstractMatrix{T}, B::AbstractMatrix{T}) where T<:BlasFloat
+    C, (dtA, dtB, _, dA, dB) = rrule(gemm, tA, tB, one(T), A, B)
+    return C, (dtA, dtB, dA, dB)
+end
diff --git a/test/rules/blas.jl b/test/rules/blas.jl
index e69de29bb..5eb6c5123 100644
--- a/test/rules/blas.jl
+++ b/test/rules/blas.jl
@@ -0,0 +1,31 @@
+using LinearAlgebra.BLAS: gemm
+
+@testset "BLAS" begin
+    @testset "gemm" begin
+        rng = MersenneTwister(1)
+        dims = 3:5
+        for m in dims, n in dims, p in dims, tA in ('N', 'T'), tB in ('N', 'T')
+            α = randn(rng)
+            A = randn(rng, tA === 'N' ? (m, n) : (n, m))
+            B = randn(rng, tB === 'N' ? (n, p) : (p, n))
+            C = gemm(tA, tB, α, A, B)
+            fAB, (dtA, dtB, dα, dA, dB) = rrule(gemm, tA, tB, α, A, B)
+            @test C ≈ fAB
+            @test dtA isa ChainRules.DNERule
+            @test dtB isa ChainRules.DNERule
+            for (f, x, dx) in [(X->gemm(tA, tB, X, A, B), α, dα),
+                               (X->gemm(tA, tB, α, X, B), A, dA),
+                               (X->gemm(tA, tB, α, A, X), B, dB)]
+                ȳ = randn(rng, size(C)...)
+                x̄_ad = dx(ȳ)
+                x̄_fd = j′vp(central_fdm(5, 1), f, ȳ, x)
+                @test x̄_ad ≈ x̄_fd rtol=1e-9 atol=1e-9
+            end
+        end
+        # Regression: with α == 0 the `sum(C̄ .* C) / α` shortcut is 0 / 0 = NaN.
+        A0, B0 = randn(rng, 3, 4), randn(rng, 4, 5)
+        _, (_, _, dα0, _, _) = rrule(gemm, 'N', 'N', 0.0, A0, B0)
+        ȳ0 = randn(rng, 3, 5)
+        @test dα0(ȳ0) ≈ sum(ȳ0 .* (A0 * B0))
+    end
+end