Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct chainrules for abs2, abs, conj and angle #196

Merged
merged 34 commits into from
Jun 28, 2020
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
6101dfa
restrict abs2 to ::Real
MasonProtter May 20, 2020
70f3f94
add Seth's frules
MasonProtter May 23, 2020
40ccb2b
Update src/rulesets/Base/fastmath_able.jl
MasonProtter May 23, 2020
b653379
Update src/rulesets/Base/fastmath_able.jl
MasonProtter May 23, 2020
4f7bd85
updates
MasonProtter May 27, 2020
be7c008
Update src/rulesets/Base/fastmath_able.jl
MasonProtter Jun 1, 2020
78578f3
Update src/rulesets/Base/fastmath_able.jl
MasonProtter Jun 1, 2020
abb15e0
don't test conj and angle on complex inputs (yet)
MasonProtter Jun 1, 2020
9cf260c
put in missing factor of two
MasonProtter Jun 23, 2020
98f0c73
drop incorrect conj
MasonProtter Jun 23, 2020
33f7b1f
Update src/rulesets/Base/fastmath_able.jl
MasonProtter Jun 23, 2020
9fe148d
Update src/rulesets/Base/fastmath_able.jl
MasonProtter Jun 23, 2020
9f4344e
Update src/rulesets/Base/fastmath_able.jl
MasonProtter Jun 23, 2020
a83cd0f
Update src/rulesets/Base/fastmath_able.jl
MasonProtter Jun 23, 2020
81862c0
Merge github.com:JuliaDiff/ChainRules.jl into patch-1
MasonProtter Jun 23, 2020
f8de55a
fix testing for abs and abs2 (still needs a refactor)
MasonProtter Jun 23, 2020
5c1b65b
test fastmath versions as well
MasonProtter Jun 23, 2020
250c22e
add rules for conj and angle
MasonProtter Jun 23, 2020
2f74b6a
fixes and cleanup
MasonProtter Jun 24, 2020
538c798
testing fixes
MasonProtter Jun 24, 2020
8bfefd2
testing fixes again
MasonProtter Jun 24, 2020
66e312d
Update src/rulesets/Base/fastmath_able.jl
MasonProtter Jun 25, 2020
49cf303
Update src/rulesets/Base/fastmath_able.jl
MasonProtter Jun 25, 2020
9706f91
use subgradient convention and special-case angle for reals
MasonProtter Jun 26, 2020
546ad0b
remove extra code that wasn't needed
MasonProtter Jun 26, 2020
761fbd1
consolidate abs2 rrules
MasonProtter Jun 26, 2020
de13870
Update src/rulesets/Base/fastmath_able.jl
MasonProtter Jun 26, 2020
16518ec
Update src/rulesets/Base/fastmath_able.jl
MasonProtter Jun 26, 2020
be2bf11
Update src/rulesets/Base/fastmath_able.jl
MasonProtter Jun 26, 2020
f2d5f86
Update src/rulesets/Base/fastmath_able.jl
MasonProtter Jun 26, 2020
d531bb9
Update src/rulesets/Base/fastmath_able.jl
MasonProtter Jun 26, 2020
860d7f8
Update src/rulesets/Base/fastmath_able.jl
MasonProtter Jun 26, 2020
ec18a77
increment version
MasonProtter Jun 27, 2020
50cc4ce
Merge branch 'patch-1' of github.com:MasonProtter/ChainRules.jl into …
MasonProtter Jun 27, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 82 additions & 4 deletions src/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,88 @@ let


# Unary complex functions
@scalar_rule abs(x::Real) sign(x)
@scalar_rule abs2(x) 2x
@scalar_rule angle(x::Real) Zero()
@scalar_rule conj(x::Real) One()
## abs
function frule((_, Δx), ::typeof(abs), x::Real)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
return abs(x), sign(x) * real(Δx)
end
function frule((_, Δz), ::typeof(abs), z::Complex)
Ω = abs(z)
return Ω, (real(z) * real(Δz) + imag(z) * imag(Δz)) / Ω
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
end

function rrule(::typeof(abs), x::Real)
function abs_pullback(ΔΩ)
return (NO_FIELDS, real(ΔΩ)*sign(x))
end
return abs(x), abs_pullback
end
function rrule(::typeof(abs), z::Complex)
Ω = abs(z)
function abs_pullback(ΔΩ)
Δu = real(ΔΩ)
return (NO_FIELDS, Δu*z/Ω)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
end
return Ω, abs_pullback
end

## abs2
function frule((_, Δx), ::typeof(abs2), x::Real)
return abs2(x), 2x * real(Δx)
end
function frule((_, Δz), ::typeof(abs2), z::Complex)
return abs2(z), 2 * (real(z) * real(Δz) + imag(z) * imag(Δz))
end
MasonProtter marked this conversation as resolved.
Show resolved Hide resolved

function rrule(::typeof(abs2), x::Real)
function abs2_pullback(Δx)
return (NO_FIELDS, 2real(Δx)*x)
end
return abs2(x), abs2_pullback
end
function rrule(::typeof(abs2), z::Complex)
function abs2_pullback(Δf)
Δu = real(Δf)
return (NO_FIELDS, 2real(Δu)*z)
end
return abs2(z), abs2_pullback
end

## conj
function frule((_, Δz), ::typeof(conj), z::Union{Real, Complex})
return conj(z), conj(Δz)
end
function rrule(::typeof(conj), z::Union{Real, Complex})
function conj_pullback(Δf)
return (NO_FIELDS, conj(Δf))
end
return conj(z), conj_pullback
end

## angle
function frule((_, Δz), ::typeof(angle), x::Real)
Δx, Δy = reim(Δz)
return angle(x), Δy/x
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
end
function frule((_, Δz), ::typeof(angle), z::Complex)
x, y = reim(z)
Δx, Δy = reim(Δz)
return angle(z), (-y*Δx + x*Δy)/abs2(z)
end
function rrule(::typeof(angle), x::Real)
function angle_pullback(Δf)
Δu, Δv = reim(Δf)
return (NO_FIELDS, im*Δu/x)
end
MasonProtter marked this conversation as resolved.
Show resolved Hide resolved
return angle(x), angle_pullback
end
function rrule(::typeof(angle), z::Complex)
function angle_pullback(Δf)
x, y = reim(z)
Δu, Δv = reim(Δf)
return (NO_FIELDS, (-y + im*x)*Δu/abs2(z))
end
return angle(z), angle_pullback
end

# Binary functions
@scalar_rule hypot(x::Real, y::Real) (x / Ω, y / Ω)
Expand Down
51 changes: 43 additions & 8 deletions test/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,42 @@
# Add tests to the quote for functions with FastMath varients.
function jacobian_via_frule(f,z)
du_dx, dv_dx = reim(frule((Zero(), 1),f,z)[2])
du_dy, dv_dy = reim(frule((Zero(),im),f,z)[2])
return [
du_dx du_dy
dv_dx dv_dy
]
end
function jacobian_via_rrule(f,z)
_, pullback = rrule(f,z)
du_dx, du_dy = reim(pullback( 1)[2])
dv_dx, dv_dy = reim(pullback(im)[2])
return [
du_dx du_dy
dv_dx dv_dy
]
end

function jacobian_via_fdm(f, z::Union{Real, Complex})
fR2((x, y)) = (collect ∘ reim ∘ f)(x + im*y)
v = float([real(z)
imag(z)])
j = jacobian(central_fdm(5,1), fR2, v)[1]
if size(j) == (2,2)
j
elseif size(j) == (1, 2)
[j
false false]
else
error("Invalid Jacobian size $(size(j))")
end
end

function complex_jacobian_test(f, z)
@test jacobian_via_fdm(f, z) ≈ jacobian_via_frule(f, z)
@test jacobian_via_fdm(f, z) ≈ jacobian_via_rrule(f, z)
end

const FASTABLE_AST = quote
@testset "Trig" begin
@testset "Basics" for x = (Float64(π)-0.01, Complex(π, π/2))
Expand Down Expand Up @@ -47,13 +85,12 @@ const FASTABLE_AST = quote
end
end
end

@testset "Unary complex functions" begin
for x in (-4.1, 6.4)
test_scalar(abs, x)
test_scalar(angle, x)
test_scalar(abs2, x)
test_scalar(conj, x)
for f ∈ (abs, abs2, angle, conj), z ∈ (-4.1-0.02im, 6.4, 3 + im)
@testset "Unary complex functions f = $f, z = $z" begin
complex_jacobian_test(f, z)
end
end
end

Expand All @@ -73,8 +110,6 @@ const FASTABLE_AST = quote
rrule_test(f, Δz, (x, x̄), (y, ȳ))
end



@testset "sign" begin
@testset "at points" for x in (-1.1, -1.1, 0.5, 100)
test_scalar(sign, x)
Expand Down