Skip to content

Commit 74396b4

Browse files
committed
Test ChainRules integration directly
Update test/chainrules.jl Update test/chainrules.jl
1 parent c89c5d0 commit 74396b4

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

test/chainrules.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
using Zygote, Test, ChainRules
2+
3+
const cr_inner_demo_rrule_hitcount = Ref(0)
4+
const cr_inner_demo_pullback_hitcount = Ref(0)
5+
cr_inner_demo(x) = 5x
6+
function ChainRules.rrule(::typeof(cr_inner_demo), x)
7+
cr_inner_demo_rrule_hitcount[] += 1
8+
function cr_inner_demo_pullback(Δx)
9+
cr_inner_demo_pullback_hitcount[] += 1
10+
return ChainRules.NO_FIELDS, 5.0*Δx
11+
end
12+
return cr_inner_demo(x), cr_inner_demo_pullback
13+
end
14+
15+
function cr_outer_demo(x)
16+
2 + 10cr_inner_demo(x)
17+
end
18+
19+
@testset "ChainRules Integration" begin
20+
@testset "gradient inner" begin
21+
cr_inner_demo_rrule_hitcount[] = 0
22+
cr_inner_demo_pullback_hitcount[] = 0
23+
@test (5.0,) == gradient(cr_inner_demo, 11)
24+
@test cr_inner_demo_rrule_hitcount[] == 1
25+
@test cr_inner_demo_pullback_hitcount[] == 1
26+
end
27+
28+
@testset "gradient outer" begin
29+
cr_inner_demo_rrule_hitcount[] = 0
30+
cr_inner_demo_pullback_hitcount[] = 0
31+
@test (50.0,) == gradient(cr_outer_demo, 11)
32+
@test cr_inner_demo_rrule_hitcount[] == 1
33+
@test cr_inner_demo_pullback_hitcount[] == 1
34+
end
35+
36+
@testset "pullback inner" begin
37+
cr_inner_demo_rrule_hitcount[] = 0
38+
cr_inner_demo_pullback_hitcount[] = 0
39+
y, pb = pullback(cr_inner_demo, 11)
40+
@test y == 55
41+
@test cr_inner_demo_rrule_hitcount[] == 1
42+
@test cr_inner_demo_pullback_hitcount[] == 0
43+
@test pb(1)==(5.0,);
44+
@test pb(2)==(10.0,);
45+
@test pb(3)==(15.0,);
46+
@test cr_inner_demo_pullback_hitcount[] == 3
47+
@test cr_inner_demo_rrule_hitcount[] == 1
48+
end
49+
end

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ end
1515
include("structures.jl")
1616
end
1717

18+
@info "Testing ChainRules integration"
19+
20+
@testset "ChainRules" begin
21+
include("chainrules.jl")
22+
end
23+
1824
@info "Running Gradient Checks"
1925

2026
@testset "Gradients" begin

0 commit comments

Comments
 (0)