-
Notifications
You must be signed in to change notification settings - Fork 15
/
testers.jl
286 lines (249 loc) · 11.7 KB
/
testers.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
"""
    test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), check_inferred=true, kwargs...)

Given a function `f` with scalar input and scalar output, perform finite differencing checks,
at input point `z` to confirm that there are correct `frule` and `rrule`s provided.

# Arguments
- `f`: Function for which the `frule` and `rrule` should be tested.
- `z`: input at which to evaluate `f` (should generally be set to an arbitrary point in the domain).

`fkwargs` are passed to `f` as keyword arguments.
If `check_inferred=true`, then the type-stability of the `frule` and `rrule` are checked.
All remaining keyword arguments are passed to `isapprox`.
"""
function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), check_inferred=true, kwargs...)
    # Checks the `frule` and `rrule` of the scalar function `f` at `z` by delegating to
    # `test_frule` / `test_rrule` once per unit (co)tangent: `one(·)` always, and
    # `one(·) * im` as well when the input (resp. output) is `Complex`.

    # To simplify some of the calls we make later lets group the kwargs for reuse
    rule_test_kwargs = (; rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, check_inferred=check_inferred, kwargs...)
    isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

    @testset "test_scalar: $f at $z" begin
        # z = x + im * y
        # Ω = u(x, y) + im * v(x, y)
        Ω = f(z; fkwargs...)

        # test jacobian using forward mode
        Δx = one(z)
        @testset "with tangent $Δx" begin
            # check ∂u_∂x and (if Ω is complex) ∂v_∂x via forward mode
            test_frule(f, z ⊢ Δx; rule_test_kwargs...)
            if z isa Complex
                # check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im
                ḟ = rand_tangent(f)
                _, real_tangent = frule((ḟ, real(Δx)), f, z; fkwargs...)
                _, embedded_tangent = frule((ḟ, Δx), f, z; fkwargs...)
                test_approx(real_tangent, embedded_tangent; isapprox_kwargs...)
            end
        end
        if z isa Complex
            Δy = one(z) * im
            @testset "with tangent $Δy" begin
                # check ∂u_∂y and (if Ω is complex) ∂v_∂y via forward mode
                test_frule(f, z ⊢ Δy; rule_test_kwargs...)
            end
        end

        # test jacobian transpose using reverse mode
        Δu = one(Ω)
        @testset "with cotangent $Δu" begin
            # check ∂u_∂x and (if z is complex) ∂u_∂y via reverse mode
            test_rrule(f, z ⊢ Δx; output_tangent=Δu, rule_test_kwargs...)
            if Ω isa Complex
                # check that same cotangent is produced for cotangent 1.0 and 1.0 + 0.0im
                # `fkwargs` must be forwarded here as well, so that the pullback under
                # test belongs to the same call `f(z; fkwargs...)` as every other check.
                _, back = rrule(f, z; fkwargs...)
                _, real_cotangent = back(real(Δu))
                _, embedded_cotangent = back(Δu)
                test_approx(real_cotangent, embedded_cotangent; isapprox_kwargs...)
            end
        end
        if Ω isa Complex
            Δv = one(Ω) * im
            @testset "with cotangent $Δv" begin
                # check ∂v_∂x and (if z is complex) ∂v_∂y via reverse mode
                test_rrule(f, z ⊢ Δx; output_tangent=Δv, rule_test_kwargs...)
            end
        end
    end # top-level testset
end
"""
    test_frule([config::RuleConfig,] f, args...; kwargs...)

# Arguments
- `config`: defaults to `ChainRulesTestUtils.ADviaRuleConfig`.
- `f`: Function for which the `frule` should be tested. Can also provide `f ⊢ ḟ`.
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ`
    - `x`: input at which to evaluate `f` (should generally be set to an arbitrary point in the domain).
    - `ẋ`: differential w.r.t. `x`, will be generated automatically if not provided
  Non-differentiable arguments, such as indices, should have `ẋ` set as `NoTangent()`.

# Keyword Arguments
- `output_tangent` tangent to test accumulation of derivatives against
  should be a differential for the output of `f`. Is set automatically if not provided.
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
- `frule_f=frule`: Function with an `frule`-like API that is tested (defaults to
  `frule`). Used for testing gradients from AD systems.
- If `check_inferred=true`, then the inferrability of the `frule` is checked,
  as long as `f` is itself inferrable.
- `fkwargs` are passed to `f` as keyword arguments.
- All remaining keyword arguments are passed to `isapprox`.
"""
function test_frule(args...; kwargs...)
    # No rule config supplied: forward everything unchanged to the config-taking
    # method, using a freshly constructed `ADviaRuleConfig`.
    return test_frule(ChainRulesTestUtils.ADviaRuleConfig(), args...; kwargs...)
end
function test_frule(
    config::RuleConfig,
    f,
    args...;
    output_tangent=Auto(),
    fdm=_fdm,
    frule_f=ChainRulesCore.frule,
    check_inferred::Bool=true,
    fkwargs::NamedTuple=NamedTuple(),
    rtol::Real=1e-9,
    atol::Real=1e-9,
    kwargs...,
)
    # Options shared by every approximate-equality check made below.
    approx_opts = (; rtol=rtol, atol=atol, kwargs...)

    # All evaluations go through this helper: it calls a deep copy of `g` on deep copies
    # of the positional arguments and of `fkwargs`, so nothing a call mutates is shared
    # with any later call.
    call_on_copy(g, xs...) = deepcopy(g)(deepcopy(xs)...; deepcopy(fkwargs)...)

    @testset "test_frule: $f on $(_string_typeof(args))" begin
        # Split `(f, args...)` into the primal values and their tangents.
        paired = map(auto_primal_and_tangent, (f, args...))
        primals = map(primal, paired)
        tangents = map(tangent, paired)

        # The rule's inferrability is only checked when the primal call is inferrable.
        if check_inferred
            if _is_inferrable(deepcopy(primals)...; deepcopy(fkwargs)...)
                _test_inferred(
                    frule_f,
                    deepcopy(config),
                    deepcopy(tangents),
                    deepcopy(primals)...;
                    deepcopy(fkwargs)...,
                )
            end
        end

        res = call_on_copy(frule_f, config, tangents, primals...)
        if res === nothing
            # A `nothing` result is reported as a missing method for these primals.
            throw(MethodError(frule_f, typeof(primals)))
        end
        @test_msg "The frule should return (y, ∂y), not $res." res isa Tuple{Any,Any}
        y_ad, dy_ad = res

        # The primal returned by the rule must agree with calling the function directly.
        y_direct = call_on_copy(primals...)
        test_approx(y_ad, y_direct; approx_opts...)

        # Correctness testing via finite differencing; arguments whose tangent is
        # `NoTangent` are flagged so the finite-difference helper can skip them.
        skip_mask = map(t -> t isa NoTangent, tangents)
        dy_fd = _make_jvp_call(fdm, call_on_copy, y_direct, primals, tangents, skip_mask)
        test_approx(dy_ad, dy_fd; approx_opts...)

        # Check accumulation of the rule's tangent into an existing tangent.
        accumulator = output_tangent isa Auto ? rand_tangent(y_direct) : output_tangent
        _test_add!!_behaviour(accumulator, dy_ad; approx_opts...)
    end # top-level testset
end
"""
    test_rrule([config::RuleConfig,] f, args...; kwargs...)

# Arguments
- `config`: defaults to `ChainRulesTestUtils.ADviaRuleConfig`.
- `f`: Function to which rule should be applied. Can also provide `f ⊢ f̄`.
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ x̄`
    - `x`: input at which to evaluate `f` (should generally be set to an arbitrary point in the domain).
    - `x̄`: currently accumulated cotangent, will be generated automatically if not provided
  Non-differentiable arguments, such as indices, should have `x̄` set as `NoTangent()`.

# Keyword Arguments
- `output_tangent` the seed to propagate backward for testing (technically a cotangent).
  should be a differential for the output of `f`. Is set automatically if not provided.
- `check_thunked_output_tangent=true`: also checks that passing a thunked version of the
  output tangent to the pullback returns the same result.
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
- `rrule_f=rrule`: Function with an `rrule`-like API that is tested (defaults to `rrule`).
  Used for testing gradients from AD systems.
- If `check_inferred=true`, then the inferrability of the `rrule` is checked
  — if `f` is itself inferrable — along with the inferrability of the pullback it returns.
- `fkwargs` are passed to `f` as keyword arguments.
- All remaining keyword arguments are passed to `isapprox`.
"""
function test_rrule(args...; kwargs...)
    # No rule config supplied: forward everything unchanged to the config-taking
    # method, using a freshly constructed `ADviaRuleConfig`.
    return test_rrule(ChainRulesTestUtils.ADviaRuleConfig(), args...; kwargs...)
end
function test_rrule(
    config::RuleConfig,
    f,
    args...;
    output_tangent=Auto(),
    check_thunked_output_tangent=true,
    fdm=_fdm,
    rrule_f=ChainRulesCore.rrule,
    check_inferred::Bool=true,
    fkwargs::NamedTuple=NamedTuple(),
    rtol::Real=1e-9,
    atol::Real=1e-9,
    kwargs...,
)
    # To simplify some of the calls we make later lets group the kwargs for reuse
    isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

    # and define helper closure over fkwargs
    call(f, xs...) = f(xs...; fkwargs...)

    @testset "test_rrule: $f on $(_string_typeof(args))" begin
        # Check correctness of evaluation.
        primals_and_tangents = auto_primal_and_tangent.((f, args...))
        primals = primal.(primals_and_tangents)
        accum_cotangents = tangent.(primals_and_tangents)

        # The rule's inferrability is only checked when the primal call is inferrable.
        if check_inferred && _is_inferrable(primals...; fkwargs...)
            _test_inferred(rrule_f, config, primals...; fkwargs...)
        end
        res = rrule_f(config, primals...; fkwargs...)
        # A `nothing` result is reported as a missing method for these primals.
        res === nothing && throw(MethodError(rrule_f, typeof(primals)))
        y_ad, pullback = res
        y = call(primals...)
        test_approx(y_ad, y; isapprox_kwargs...)  # make sure primal is correct

        ȳ = output_tangent isa Auto ? rand_tangent(y) : output_tangent

        check_inferred && _test_inferred(pullback, ȳ)
        ad_cotangents = pullback(ȳ)
        @test_msg(
            "The pullback must return a Tuple (∂self, ∂args...)",
            ad_cotangents isa Tuple
        )
        @test_msg(
            "The pullback should return 1 cotangent for the primal and each primal input.",
            length(ad_cotangents) == length(primals)
        )

        # Correctness testing via finite differencing.
        is_ignored = isa.(accum_cotangents, NoTangent)
        fd_cotangents = _make_j′vp_call(fdm, call, ȳ, primals, is_ignored)

        for (accum_cotangent, ad_cotangent, fd_cotangent) in zip(
            accum_cotangents, ad_cotangents, fd_cotangents
        )
            if accum_cotangent isa NoTangent  # then we marked this argument as not differentiable
                @assert fd_cotangent === nothing  # this is how `_make_j′vp_call` works
                ad_cotangent isa ZeroTangent && error(
                    "The pullback in the rrule should use NoTangent()" *
                    " rather than ZeroTangent() for non-perturbable arguments.",
                )
                @test ad_cotangent isa NoTangent  # we said it wasn't differentiable.
            else
                # A thunked cotangent is additionally checked for inferrable unthunking.
                ad_cotangent isa AbstractThunk && check_inferred && _test_inferred(unthunk, ad_cotangent)

                # The main test of the actual derivative being correct:
                test_approx(ad_cotangent, fd_cotangent; isapprox_kwargs...)
                _test_add!!_behaviour(accum_cotangent, ad_cotangent; isapprox_kwargs...)
            end
        end

        if check_thunked_output_tangent
            # Use the same comparison options as every other check, so that caller
            # settings (custom tolerances, `nans=true`, ...) also apply here.
            test_approx(
                ad_cotangents,
                pullback(@thunk(ȳ)),
                "pulling back a thunk:";
                isapprox_kwargs...,
            )
        end
    end # top-level testset
end
"""
@maybe_inferred [Type] f(...)
Like `@inferred`, but does not check the return type if tests are run as part of PkgEval or
if the environment variable `CHAINRULES_TEST_INFERRED` is set to `false`.
"""
macro maybe_inferred(ex...)
    # Expands to `TEST_INFERRED[] ? @inferred(ex...) : <call>`: the call is wrapped in
    # `Test.@inferred` only while the `TEST_INFERRED` flag is set; otherwise the bare
    # call (the last macro argument) is evaluated with no return-type check.
    checked_call = Expr(:macrocall, GlobalRef(Test, Symbol("@inferred")), __source__, ex...)
    plain_call = last(ex)
    # Only the user-supplied expressions are escaped; `TEST_INFERRED` stays unescaped
    # so it resolves in the module that defines this macro.
    return Expr(:if, :(TEST_INFERRED[]), esc(checked_call), esc(plain_call))
end
"""
    _test_inferred(f, args...; kwargs...)

Simple wrapper for [`@maybe_inferred f(args...; kwargs...)`](@ref `@maybe_inferred`), avoiding the type-instability in not
knowing how many `kwargs` there are.
"""
function _test_inferred(f, args...; kwargs...)
    # Keyword-carrying calls are checked with the keyword splat in place ...
    if !isempty(kwargs)
        return @maybe_inferred f(args...; kwargs...)
    end
    # ... while calls without keywords are checked as a plain positional call.
    return @maybe_inferred f(args...)
end
"""
_is_inferrable(f, args...; kwargs...) -> Bool
Return whether the return type of `f(args...; kwargs...)` is inferrable.
"""
function _is_inferrable(f, args...; kwargs...)
    # Runs the inference check on `f(args...; kwargs...)` and converts its outcome into
    # a `Bool`: `true` when the check passes, `false` when it raises an `ErrorException`.
    # Any other exception is propagated to the caller.
    try
        _test_inferred(f, args...; kwargs...)
        return true
    catch err
        # `catch ErrorException` does not filter by type: it binds *every* exception to
        # a local variable called `ErrorException`. Check the type explicitly so that
        # unrelated failures are not silently reported as "not inferrable".
        err isa ErrorException || rethrow()
        return false
    end
end