-
Notifications
You must be signed in to change notification settings - Fork 15
/
testers.jl
330 lines (289 loc) · 13.4 KB
/
testers.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
"""
    test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), check_inferred=true, kwargs...)

Given a function `f` with scalar input and scalar output, perform finite differencing checks
at input point `z` to confirm that correct `frule` and `rrule` methods are provided.

If `z` is complex, the forward mode is probed along both the real and the imaginary
direction; if `f(z)` is complex, the reverse mode is probed with both a real and an
imaginary cotangent. The rules are also checked to give the same result for a real
tangent/cotangent and its complex embedding (e.g. `1.0` vs `1.0 + 0.0im`).

# Arguments
- `f`: function for which the `frule` and `rrule` should be tested.
- `z`: input at which to evaluate `f` (should generally be set to an arbitrary point in the domain).

# Keyword Arguments
- `fdm`: the finite differencing method to use.
- `fkwargs` are passed to `f` (and to its `frule`/`rrule`) as keyword arguments.
- If `check_inferred=true`, then the inferrability (type-stability) of the `frule` and `rrule` are checked.
- All remaining keyword arguments are passed to `isapprox`.
"""
function test_scalar(
    f, z;
    rtol=1e-9,
    atol=1e-9,
    fdm=_fdm,
    fkwargs=NamedTuple(),
    check_inferred=true,
    kwargs...,
)
    # To simplify some of the calls we make later, group the kwargs for reuse.
    rule_test_kwargs = (;
        rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, check_inferred=check_inferred, kwargs...
    )
    isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

    @testset "test_scalar: $f at $z" begin
        # z = x + im * y
        # Ω = u(x, y) + im * v(x, y)
        Ω = f(z; fkwargs...)

        # test jacobian using forward mode
        Δx = one(z)
        @testset "with tangent $Δx" begin
            # check ∂u_∂x and (if Ω is complex) ∂v_∂x via forward mode
            test_frule(f, z ⊢ Δx; rule_test_kwargs...)
            if z isa Complex
                # check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im
                ḟ = rand_tangent(f)
                _, real_tangent = frule((ḟ, real(Δx)), f, z; fkwargs...)
                _, embedded_tangent = frule((ḟ, Δx), f, z; fkwargs...)
                test_approx(real_tangent, embedded_tangent; isapprox_kwargs...)
            end
        end
        if z isa Complex
            Δy = one(z) * im
            @testset "with tangent $Δy" begin
                # check ∂u_∂y and (if Ω is complex) ∂v_∂y via forward mode
                test_frule(f, z ⊢ Δy; rule_test_kwargs...)
            end
        end

        # test jacobian transpose using reverse mode
        Δu = one(Ω)
        @testset "with cotangent $Δu" begin
            # check ∂u_∂x and (if z is complex) ∂u_∂y via reverse mode
            test_rrule(f, z ⊢ Δx; output_tangent=Δu, rule_test_kwargs...)
            if Ω isa Complex
                # check that same cotangent is produced for cotangent 1.0 and 1.0 + 0.0im.
                # `fkwargs` must be forwarded here as well (as for the `frule` calls above),
                # otherwise this pulls back through a different call than `f(z; fkwargs...)`.
                _, back = rrule(f, z; fkwargs...)
                _, real_cotangent = back(real(Δu))
                _, embedded_cotangent = back(Δu)
                test_approx(real_cotangent, embedded_cotangent; isapprox_kwargs...)
            end
        end
        if Ω isa Complex
            Δv = one(Ω) * im
            @testset "with cotangent $Δv" begin
                # check ∂v_∂x and (if z is complex) ∂v_∂y via reverse mode
                test_rrule(f, z ⊢ Δx; output_tangent=Δv, rule_test_kwargs...)
            end
        end
    end  # top-level testset
end
"""
    test_frule([config::RuleConfig,] f, args...; kwargs...)

Test the `frule` of `f` at the given arguments. The primal returned by the rule is compared
with a direct call of `f`, and its output tangent with a finite-differencing estimate.

# Arguments
- `config`: defaults to `ChainRulesTestUtils.ADviaRuleConfig`.
- `f`: function for which the `frule` should be tested. Its tangent can be provided using `f ⊢ ḟ`.
  (You can enter `⊢` via `\\vdash` + tab in the Julia REPL and supporting editors.)
- `args...`: either the primal args `x`, or primals and their tangents: `x ⊢ ẋ`
   - `x`: input at which to evaluate `f` (should generally be set to an arbitrary point in the domain).
   - `ẋ`: differential w.r.t. `x`; will be generated automatically if not provided.
   Non-differentiable arguments, such as indices, should have `ẋ` set as `NoTangent()`.

# Keyword Arguments
- `output_tangent`: tangent against which to test accumulation of derivatives.
  Should be a differential for the output of `f`. Is set automatically if not provided.
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
- `frule_f=frule`: function with an `frule`-like API that is tested (defaults to
  `frule`). Used for testing gradients from AD systems.
- If `check_inferred=true`, then the inferrability (type-stability) of the `frule` is checked,
  as long as `f` is itself inferrable.
- `fkwargs` are passed to `f` as keyword arguments.
- All remaining keyword arguments are passed to `isapprox`.
"""
test_frule(args...; kwargs...) =
    test_frule(ChainRulesTestUtils.ADviaRuleConfig(), args...; kwargs...)
# Forward-mode checker: compares `frule_f` against the primal call of `f` and against
# a finite-differencing estimate of the output tangent.
function test_frule(
    config::RuleConfig,
    f,
    args...;
    output_tangent=Auto(),
    fdm=_fdm,
    frule_f=ChainRulesCore.frule,
    check_inferred::Bool=true,
    fkwargs::NamedTuple=NamedTuple(),
    rtol::Real=1e-9,
    atol::Real=1e-9,
    kwargs...,
)
    # Options forwarded to every `isapprox`-based comparison below.
    approx_opts = (; rtol=rtol, atol=atol, kwargs...)
    # Evaluate `g(xs...; fkwargs...)` on deep copies, so that no call below can observe
    # mutations made by an earlier one.
    on_copies(g, xs...) = deepcopy(g)(deepcopy(xs)...; deepcopy(fkwargs)...)

    @testset "test_frule: $f on $(_string_typeof(args))" begin
        bundled = auto_primal_and_tangent.((f, args...))
        xs = primal.(bundled)
        ẋs = tangent.(bundled)

        # The rule is only required to infer when the primal call itself does.
        inferrable = check_inferred && _is_inferrable(deepcopy(xs)...; deepcopy(fkwargs)...)
        if inferrable
            _test_inferred(
                frule_f, deepcopy(config), deepcopy(ẋs), deepcopy(xs)...; deepcopy(fkwargs)...
            )
        end

        res = on_copies(frule_f, config, ẋs, xs...)
        isnothing(res) && throw(MethodError(frule_f, typeof(xs)))
        @test_msg "The frule should return (y, ∂y), not $res." res isa Tuple{Any,Any}
        y_ad, ẏ_ad = res

        # The primal computed by the rule must match the plain call.
        y = on_copies(xs...)
        test_approx(y_ad, y; approx_opts...)

        # Correctness testing via finite differencing; `NoTangent` inputs are skipped.
        ignored = isa.(ẋs, NoTangent)
        ẏ_fd = _make_jvp_call(fdm, on_copies, y, xs, ẋs, ignored)
        test_approx(ẏ_ad, ẏ_fd; approx_opts...)

        # Accumulation check against `output_tangent` (random if not supplied).
        acc = output_tangent isa Auto ? rand_tangent(y) : output_tangent
        _test_add!!_behaviour(acc, ẏ_ad; approx_opts...)
    end  # top-level testset
end
"""
    test_rrule([config::RuleConfig,] f, args...; kwargs...)

Test the `rrule` of `f` at the given arguments. The primal returned by the rule is compared
with a direct call of `f`, and the cotangents returned by its pullback with
finite-differencing estimates.

# Arguments
- `config`: defaults to `ChainRulesTestUtils.ADviaRuleConfig`.
- `f`: function for which the `rrule` should be tested. Its tangent can be provided using `f ⊢ f̄`.
  (You can enter `⊢` via `\\vdash` + tab in the Julia REPL and supporting editors.)
- `args...`: either the primal args `x`, or primals and their tangents: `x ⊢ x̄`
   - `x`: input at which to evaluate `f` (should generally be set to an arbitrary point in the domain).
   - `x̄`: currently accumulated cotangent; will be generated automatically if not provided.
   Non-differentiable arguments, such as indices, should have `x̄` set as `NoTangent()`.

# Keyword Arguments
- `output_tangent`: the seed to propagate backward for testing (technically a cotangent).
  Should be a differential for the output of `f`. Is set automatically if not provided.
- `check_thunked_output_tangent=true`: also checks that passing a thunked version of the
  output tangent to the pullback returns the same result.
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
- `rrule_f=rrule`: function with an `rrule`-like API that is tested (defaults to `rrule`).
  Used for testing gradients from AD systems.
- If `check_inferred=true`, then the inferrability (type-stability) of the `rrule` is checked
  — if `f` is itself inferrable — along with the inferrability of the pullback it returns.
- `fkwargs` are passed to `f` as keyword arguments.
- All remaining keyword arguments are passed to `isapprox`.
"""
test_rrule(args...; kwargs...) =
    test_rrule(ChainRulesTestUtils.ADviaRuleConfig(), args...; kwargs...)
# Reverse-mode checker: compares the primal returned by `rrule_f` with a direct call of
# `f`, and the pullback's cotangents with finite-differencing estimates.
function test_rrule(
    config::RuleConfig,
    f,
    args...;
    output_tangent=Auto(),
    check_thunked_output_tangent=true,
    fdm=_fdm,
    rrule_f=ChainRulesCore.rrule,
    check_inferred::Bool=true,
    fkwargs::NamedTuple=NamedTuple(),
    rtol::Real=1e-9,
    atol::Real=1e-9,
    kwargs...,
)
    # To simplify some of the calls we make later, group the kwargs for reuse
    isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)
    # and define a helper closure over `fkwargs`.
    call(f, xs...) = f(xs...; fkwargs...)

    @testset "test_rrule: $f on $(_string_typeof(args))" begin
        # Check correctness of evaluation.
        primals_and_tangents = auto_primal_and_tangent.((f, args...))
        primals = primal.(primals_and_tangents)
        accum_cotangents = tangent.(primals_and_tangents)

        # The rule is only required to infer when the primal call itself does.
        if check_inferred && _is_inferrable(primals...; fkwargs...)
            _test_inferred(rrule_f, config, primals...; fkwargs...)
        end
        res = rrule_f(config, primals...; fkwargs...)
        # A `nothing` result is reported as a missing method for these argument types.
        res === nothing && throw(MethodError(rrule_f, typeof(primals)))
        y_ad, pullback = res
        y = call(primals...)
        test_approx(y_ad, y; isapprox_kwargs...)  # make sure primal is correct

        ȳ = output_tangent isa Auto ? rand_tangent(y) : output_tangent

        check_inferred && _test_inferred(pullback, ȳ)
        ad_cotangents = pullback(ȳ)
        @test_msg(
            "The pullback must return a Tuple (∂self, ∂args...)",
            ad_cotangents isa Tuple
        )
        @test_msg(
            "The pullback should return 1 cotangent for the primal and each primal input.",
            length(ad_cotangents) == length(primals)
        )

        # Correctness testing via finite differencing.
        is_ignored = isa.(accum_cotangents, NoTangent)
        fd_cotangents = _make_j′vp_call(fdm, call, ȳ, primals, is_ignored)
        foreach(accum_cotangents, ad_cotangents, fd_cotangents) do args...
            _test_cotangent(args...; check_inferred=check_inferred, isapprox_kwargs...)
        end

        if check_thunked_output_tangent
            # Compare with the same tolerances/options as every other check; without them
            # e.g. a user-supplied `nans=true` is ignored and NaN cotangents fail here.
            test_approx(
                ad_cotangents, pullback(@thunk(ȳ)), "pulling back a thunk:"; isapprox_kwargs...
            )
            check_inferred && _test_inferred(pullback, @thunk(ȳ))
        end
    end  # top-level testset
end
"""
    @maybe_inferred [Type] f(...)

Like `@inferred`, but does not check the return type if tests are run as part of PkgEval or
if the environment variable `CHAINRULES_TEST_INFERRED` is set to `false`.

The decision is made at run time from `TEST_INFERRED[]`. When it is `true`, the call is
wrapped in `Test.@inferred` (forwarding the optional `Type` argument). Otherwise the call
`f(...)` is simply evaluated.
"""
macro maybe_inferred(ex...)
    plain = last(ex)
    checked = Expr(:macrocall, GlobalRef(Test, Symbol("@inferred")), __source__, ex...)
    # Same AST as the ternary `TEST_INFERRED[] ? checked : plain`.
    return Expr(:if, :(TEST_INFERRED[]), esc(checked), esc(plain))
end
"""
    _test_inferred(f, args...; kwargs...)

Simple wrapper for [`@maybe_inferred f(args...; kwargs...)`](@ref `@maybe_inferred`),
avoiding the type-instability in not knowing how many `kwargs` there are.
"""
function _test_inferred(f, args...; kwargs...)
    # Separate call sites: with no keywords, the checked expression is a plain positional call.
    isempty(kwargs) && return @maybe_inferred f(args...)
    return @maybe_inferred f(args...; kwargs...)
end
"""
    _is_inferrable(f, args...; kwargs...) -> Bool

Return whether the return type of `f(args...; kwargs...)` is inferrable.

Only an `ErrorException` (which is how `@inferred` reports a mismatch between the actual
and the inferred return type) is turned into `false`. Any other exception, such as an
`InterruptException` or a genuine error thrown by `f`, is rethrown.
"""
function _is_inferrable(f, args...; kwargs...)
    try
        _test_inferred(f, args...; kwargs...)
        return true
    catch err
        # NB: the previous `catch ErrorException` bound the exception to a *variable*
        # named `ErrorException`, so it silently swallowed every exception type.
        err isa ErrorException || rethrow()
        return false
    end
end
"""
    _test_cotangent(accum_cotangent, ad_cotangent, fd_cotangent; kwargs...)

Check that the cotangent `ad_cotangent` from an `rrule` pullback is approximately equal to
the cotangent `fd_cotangent` obtained with finite differencing. Also check that it is
consistent with the accumulated cotangent `accum_cotangent`.

If `accum_cotangent` is `NoTangent()` (i.e. the argument was marked as non-differentiable),
then `ad_cotangent` and `fd_cotangent` should be `NoTangent()` as well.

# Keyword arguments
- `check_inferred=true`: if set and `ad_cotangent` is a thunk, also check that its content
  can be inferred when unthunked.
- All remaining keyword arguments are passed to `isapprox`.
"""
function _test_cotangent(
    accum_cotangent,
    ad_cotangent,
    fd_cotangent;
    check_inferred=true,
    kwargs...,
)
    if check_inferred && ad_cotangent isa AbstractThunk
        _test_inferred(unthunk, ad_cotangent)
    end
    # The main test of the actual derivative being correct:
    test_approx(ad_cotangent, fd_cotangent; kwargs...)
    # ...followed by the accumulation check against `accum_cotangent`.
    return _test_add!!_behaviour(accum_cotangent, ad_cotangent; kwargs...)
end
# The argument was marked as non-differentiable and finite differencing also returned
# `NoTangent()`: the pullback must report `NoTangent()` for it as well.
_test_cotangent(::NoTangent, ad_cotangent, ::NoTangent; kwargs...) =
    @test ad_cotangent isa NoTangent
# A `ZeroTangent()` cotangent for an argument marked as non-differentiable is rejected
# outright: such arguments must get `NoTangent()`.
function _test_cotangent(::NoTangent, ::ZeroTangent, ::NoTangent; kwargs...)
    msg = "The pullback in the rrule should use NoTangent() rather than ZeroTangent() for non-perturbable arguments."
    throw(ErrorException(msg))
end
function _test_cotangent(
    ::NoTangent, ad_cotangent::ChainRulesCore.NotImplemented, ::NoTangent; kwargs...
)
    # Reached when a cotangent is not implemented while the default `rand_tangent` is
    # `NoTangent` (e.g. for a type with no fields). Recording the specific expectation as
    # broken (instead of `@test_broken false`) tells the rule author that returning
    # `NoTangent()` here is an easy implementation.
    # https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/217
    return @test_broken ad_cotangent isa NoTangent
end
# Fallback for an argument marked as non-differentiable whose finite-differencing
# cotangent is not `NoTangent()`.
_test_cotangent(::NoTangent, ad_cotangent, fd_cotangent; kwargs...) =
    error("cotangent obtained with finite differencing has to be NoTangent()")