Skip to content

Commit

Permalink
Add rules for sqrt(::Diagonal) (#509)
Browse files Browse the repository at this point in the history
* WIP add diag sqrt rule

* Adding test with different rand_tangent()

* michaels less allocating suggestion

* spaces and unthunk

* rand the tests

* fix CI

* prevent stochastic failures

Co-authored-by: Miha Zgubic <miha.zgubic@invenialabs.co.uk>
  • Loading branch information
thomasgudjonwright and Miha Zgubic authored Aug 24, 2021
1 parent 26e4608 commit 649bfbb
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.9.0"
version = "1.10.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
13 changes: 13 additions & 0 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,19 @@ end
##### `Diagonal`
#####

_diagview(x::Diagonal) = x.diag
_diagview(x::AbstractMatrix) = view(x, diagind(x))
_diagview(x::Tangent{<:Diagonal}) = x.diag
function ChainRulesCore.rrule(::typeof(sqrt), d::Diagonal)
y = sqrt(d)
@assert y isa Diagonal
function sqrt_pullback(Δ)
Δ_diag = _diagview(unthunk(Δ))
return NoTangent(), Diagonal(Δ_diag ./ (2 .* y.diag))
end
return y, sqrt_pullback
end

# these functions are defined outside the rrule because otherwise type inference breaks
# see https://github.com/JuliaLang/julia/issues/40990
_Diagonal_pullback(ȳ::AbstractMatrix) = return (NoTangent(), diag(ȳ)) # should we emit a warning here? this shouldn't be called if project works right
Expand Down
4 changes: 4 additions & 0 deletions test/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@
end
end
end
@testset "sqrt(::Diagonal)" begin
test_rrule(sqrt, Diagonal(rand(3)))
test_rrule(sqrt, Diagonal([1.0, 2]); output_tangent=[1.2 3.4; 1.2 4.3])
end
@testset "$f, $T" for
f in (Adjoint, adjoint, Transpose, transpose),
T in (Float64, ComplexF64)
Expand Down

2 comments on commit 649bfbb

@thomasgudjonwright
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/43448

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.10.0 -m "<description of version>" 649bfbbfb27b0d40310bf0f5591b84c4d6091675
git push origin v1.10.0

Please sign in to comment.