|
52 | 52 | test_scalar(acotd, 1/x) |
53 | 53 | end |
54 | 54 | @testset "Multivariate" begin |
55 | | - x, y = rand(2) |
56 | 55 | @testset "atan2" begin |
57 | 56 | # https://en.wikipedia.org/wiki/Atan2 |
| 57 | + x, y = rand(2) |
58 | 58 | ratan = atan(x, y) |
59 | 59 | u = x^2 + y^2 |
60 | 60 | datan = y/u - 2x/u |
|
71 | 71 | end |
72 | 72 |
|
73 | 73 | @testset "sincos" begin |
74 | | - rsincos = sincos(x) |
75 | | - dsincos = cos(x) - 2sin(x) |
76 | | - |
77 | | - r, pushforward = frule(sincos, x) |
78 | | - @test r === rsincos |
79 | | - df1, df2 = pushforward(NamedTuple(), 1) |
80 | | - @test df1 + 2df2 === dsincos |
81 | | - |
82 | | - r, pullback = rrule(sincos, x) |
83 | | - @test r === rsincos |
84 | | - ds, df = pullback(1, 2) |
85 | | - @test df === dsincos |
86 | | - @test ds === NO_FIELDS |
| 74 | + x, Δx, x̄ = randn(3) |
| 75 | + Δz = (randn(), randn()) |
| 76 | + |
| 77 | + frule_test(sincos, (x, Δx)) |
| 78 | + rrule_test(sincos, Δz, (x, x̄)) |
87 | 79 | end |
88 | 80 | end |
89 | 81 | end # Trig |
|
114 | 106 | end |
115 | 107 |
|
116 | 108 | @testset "Unary complex functions" begin |
117 | | - for x in (-6, rand.((Float32, Float64, Complex{Float32}, Complex{Float64}))...) |
118 | | - rtol = x isa Complex{Float32} ? 1e-6 : 1e-9 |
119 | | - test_scalar(real, x; rtol=rtol) |
120 | | - test_scalar(imag, x; rtol=rtol) |
| 109 | + for x in (-4.1, 6.4, 1.0+0.5im, -10.0+1.5im) |
| 110 | + test_scalar(real, x) |
| 111 | + test_scalar(imag, x) |
121 | 112 |
|
122 | | - test_scalar(abs, x; rtol=rtol) |
123 | | - test_scalar(hypot, x; rtol=rtol) |
| 113 | + test_scalar(abs, x) |
| 114 | + test_scalar(hypot, x) |
124 | 115 |
|
125 | | - test_scalar(angle, x; rtol=rtol) |
126 | | - test_scalar(abs2, x; rtol=rtol) |
127 | | - test_scalar(conj, x; rtol=rtol) |
| 116 | + test_scalar(angle, x) |
| 117 | + test_scalar(abs2, x) |
| 118 | + test_scalar(conj, x) |
128 | 119 | end |
129 | 120 | end |
130 | 121 |
|
|
146 | 137 | test_accumulation(rand(2, 5), dy) |
147 | 138 | end |
148 | 139 |
|
149 | | - @testset "hypot(x, y)" begin |
| 140 | + @testset "binary trig ($f)" for f in (hypot, atan) |
150 | 141 | rng = MersenneTwister(123456) |
151 | | - x, Δx, x̄ = randn(rng, 3) |
| 142 | + x, Δx, x̄ = 10randn(rng, 3) |
152 | 143 | y, Δy, ȳ = randn(rng, 3) |
153 | 144 | Δz = randn(rng) |
154 | 145 |
|
155 | | - frule_test(hypot, (x, Δx), (y, Δy)) |
156 | | - rrule_test(hypot, Δz, (x, x̄), (y, ȳ)) |
| 146 | + frule_test(f, (x, Δx), (y, Δy)) |
| 147 | + rrule_test(f, Δz, (x, x̄), (y, ȳ)) |
157 | 148 | end |
158 | 149 |
|
159 | 150 | @testset "identity" begin |
|
166 | 157 | test_scalar(one, x) |
167 | 158 | test_scalar(zero, x) |
168 | 159 | end |
| 160 | + |
| 161 | + @testset "sign" begin |
| 162 | + @testset "at points" for x in (-1.1, -1.1, 0.5, 100) |
| 163 | + test_scalar(sign, x) |
| 164 | + end |
| 165 | + |
| 166 | + @testset "Zero over the point discontinuity" begin |
| 167 | + # Can't do finite differencing because we are lying |
| 168 | + # following the subgradient convention. |
| 169 | + |
| 170 | + _, pb = rrule(sign, 0.0) |
| 171 | + _, x̄ = pb(10.5) |
| 172 | + @test extern(x̄) == 0 |
| 173 | + |
| 174 | + _, pf = frule(sign, 0.0) |
| 175 | + ẏ = pf(NamedTuple(), 10.5) |
| 176 | + @test extern(ẏ) == 0 |
| 177 | + end |
| 178 | + end |
169 | 179 | end |
0 commit comments