@@ -4,14 +4,7 @@ using ADTypes: AbstractADType, AutoForwardDiff
 using Chairmarks: @be
 import DifferentiationInterface as DI
 using DocStringExtensions
-using DynamicPPL:
-    Model,
-    LogDensityFunction,
-    VarInfo,
-    AbstractVarInfo,
-    link,
-    DefaultContext,
-    AbstractContext
+using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
 using LogDensityProblems: logdensity, logdensity_and_gradient
 using Random: Random, Xoshiro
 using Statistics: median
@@ -20,12 +13,48 @@ using Test: @test
 export ADResult, run_ad, ADIncorrectException
 
 """
-    REFERENCE_ADTYPE
+    AbstractADCorrectnessTestSetting
 
-Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
-it's the default AD backend used in Turing.jl.
+Different ways of testing the correctness of an AD backend.
 """
-const REFERENCE_ADTYPE = AutoForwardDiff()
+abstract type AbstractADCorrectnessTestSetting end
+
+"""
+    WithBackend(adtype::AbstractADType=AutoForwardDiff()) <: AbstractADCorrectnessTestSetting
+
+Test correctness by comparing the result against that obtained with `adtype`.
+
+`adtype` defaults to ForwardDiff.jl, since it's the default AD backend used in
+Turing.jl.
+"""
+struct WithBackend{AD<:AbstractADType} <: AbstractADCorrectnessTestSetting
+    adtype::AD
+end
+WithBackend() = WithBackend(AutoForwardDiff())
+
+"""
+    WithExpectedResult(
+        value::T,
+        grad::AbstractVector{T}
+    ) where {T<:AbstractFloat}
+    <: AbstractADCorrectnessTestSetting
+
+Test correctness by comparing against a known result (e.g. one obtained
+analytically, or one obtained previously with a different backend). Both the
+value of the primal (i.e. the log-density) and its gradient must be supplied.
46+ """
47+ struct WithExpectedResult{T<: AbstractFloat } <: AbstractADCorrectnessTestSetting
48+ value:: T
49+ grad:: AbstractVector{T}
50+ end
51+
52+ """
53+ NoTest() <: AbstractADCorrectnessTestSetting
54+
55+ Disable correctness testing.
56+ """
57+ struct NoTest <: AbstractADCorrectnessTestSetting end
 
 """
     ADIncorrectException{T<:AbstractFloat}
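For orientation, the three settings introduced above might be used as follows. This is an illustrative sketch, not part of the diff; it assumes the new types are in scope and that a second backend (here ReverseDiff, via its ADTypes entry) is available:

```julia
using ADTypes: AutoForwardDiff, AutoReverseDiff

WithBackend()                           # compare against ForwardDiff (the default)
WithBackend(AutoReverseDiff())          # compare against another backend
WithExpectedResult(-1.23, [0.5, 0.5])   # compare against a known value and gradient
NoTest()                                # disable correctness testing
```

Since all three are subtypes of `AbstractADCorrectnessTestSetting`, `run_ad` can branch on the concrete type of the setting, as the implementation further down does with `test isa NoTest` / `test isa WithExpectedResult` / `test isa WithBackend`.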
@@ -84,14 +113,12 @@
     run_ad(
         model::Model,
         adtype::ADTypes.AbstractADType;
-        test=true,
+        test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
         benchmark=false,
         value_atol=1e-6,
         grad_atol=1e-6,
         varinfo::AbstractVarInfo=link(VarInfo(model), model),
         params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
-        expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
         verbose=true,
     )::ADResult
@@ -143,22 +170,25 @@ Everything else is optional, and can be categorised into several groups:
    prep_params)`. You could then evaluate the gradient at a different set of
    parameters using the `params` keyword argument.
 
-3. _How to specify the results to compare against._ (Only if `test=true`.)
+3. _How to specify the results to compare against._
 
    Once logp and its gradient has been calculated with the specified `adtype`,
-   it must be tested for correctness.
+   it can optionally be tested for correctness. The exact way in which this
+   is tested is specified by the `test` keyword argument.
 
-   This can be done either by specifying `reference_adtype`, in which case logp
-   and its gradient will also be calculated with this reference in order to
-   obtain the ground truth; or by using `expected_value_and_grad`, which is a
-   tuple of `(logp, gradient)` that the calculated values must match. The
-   latter is useful if you are testing multiple AD backends and want to avoid
-   recalculating the ground truth multiple times.
+   There are several options for this:
 
-   The default reference backend is ForwardDiff. If none of these parameters are
-   specified, ForwardDiff will be used to calculate the ground truth.
+   - You can explicitly specify the correct value using
+     [`WithExpectedResult()`](@ref).
+   - You can compare against the result obtained with a different AD backend
+     using [`WithBackend(adtype)`](@ref).
+   - You can disable testing by passing [`NoTest()`](@ref).
+   - The default is to compare against the result obtained with ForwardDiff,
+     i.e. `WithBackend(AutoForwardDiff())`.
+   - `test=false` and `test=true` are synonyms for `NoTest()` and
+     `WithBackend(AutoForwardDiff())`, respectively.
 
-4. _How to specify the tolerances._ (Only if `test=true`.)
+4. _How to specify the tolerances._ (Only if testing is enabled.)
 
    The tolerances for the value and gradient can be set using `value_atol` and
    `grad_atol`. These default to 1e-6.
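To make the options in item 3 above concrete, here is a sketch of how the new `test` keyword might be exercised. The `@model` definition and the Distributions/ReverseDiff usage are illustrative assumptions, not taken from this diff:

```julia
using ADTypes: AutoForwardDiff, AutoReverseDiff
using Distributions: Normal
using DynamicPPL: @model

@model function demo()
    x ~ Normal()
end
model = demo()

# Default: correctness is checked against ForwardDiff.
run_ad(model, AutoReverseDiff())

# The same, spelled out explicitly, with the default tolerances:
run_ad(
    model, AutoReverseDiff();
    test=WithBackend(AutoForwardDiff()), value_atol=1e-6, grad_atol=1e-6
)

# Compare against an analytically known result, avoiding a second AD pass.
# At x = 0.0 the log-density is logpdf(Normal(), 0.0) = -log(2π)/2 and the
# gradient is [0.0].
run_ad(
    model, AutoReverseDiff();
    params=[0.0], test=WithExpectedResult(-0.5 * log(2π), [0.0])
)

# Disable correctness testing; `test=false` is a synonym.
run_ad(model, AutoReverseDiff(); test=NoTest())
```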
@@ -180,48 +210,57 @@ thrown as-is.
 function run_ad(
     model::Model,
     adtype::AbstractADType;
-    test::Bool=true,
+    test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
     benchmark::Bool=false,
     value_atol::AbstractFloat=1e-6,
     grad_atol::AbstractFloat=1e-6,
     varinfo::AbstractVarInfo=link(VarInfo(model), model),
     params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-    reference_adtype::AbstractADType=REFERENCE_ADTYPE,
-    expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
     verbose=true,
 )::ADResult
+    # Convert Boolean `test` to an AbstractADCorrectnessTestSetting
+    if test isa Bool
+        test = test ? WithBackend() : NoTest()
+    end
+
+    # Extract parameters
     if isnothing(params)
         params = varinfo[:]
     end
     params = map(identity, params)  # Concretise
 
+    # Calculate log-density and gradient with the backend of interest
     verbose && @info "Running AD on $(model.f) with $(adtype)\n"
     verbose && println("params : $(params)")
     ldf = LogDensityFunction(model, varinfo; adtype=adtype)
-
     value, grad = logdensity_and_gradient(ldf, params)
+    # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
     grad = collect(grad)
     verbose && println("actual : $((value, grad))")
 
-    if test
-        # Calculate ground truth to compare against
-        value_true, grad_true = if expected_value_and_grad === nothing
-            ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
-            logdensity_and_gradient(ldf_reference, params)
-        else
-            expected_value_and_grad
+    # Test correctness
+    if test isa NoTest
+        value_true = nothing
+        grad_true = nothing
+    else
+        # Get the correct result
+        if test isa WithExpectedResult
+            value_true = test.value
+            grad_true = test.grad
+        elseif test isa WithBackend
+            ldf_reference = LogDensityFunction(model, varinfo; adtype=test.adtype)
+            value_true, grad_true = logdensity_and_gradient(ldf_reference, params)
+            # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
+            grad_true = collect(grad_true)
         end
+        # Perform testing
         verbose && println("expected : $((value_true, grad_true))")
-        grad_true = collect(grad_true)
-
         exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
         isapprox(value, value_true; atol=value_atol) || exc()
         isapprox(grad, grad_true; atol=grad_atol) || exc()
-    else
-        value_true = nothing
-        grad_true = nothing
     end
 
+    # Benchmark
     time_vs_primal = if benchmark
         primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
         grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
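The hunk is truncated here, but the benchmarking pattern is already visible: Chairmarks' `@be setup expr` form, in which `_` is bound to the result of the setup expression. Below is a self-contained toy version of the same pattern; the final ratio of median times is an assumption inferred from the `median` import above rather than shown in the truncated hunk:

```julia
using Chairmarks: @be
using Statistics: median

# Stand-ins for the primal and gradient evaluations benchmarked above.
f(x) = sum(abs2, x)
fgrad(x) = (f(x), 2 .* x)

primal_benchmark = @be rand(100) f(_)
grad_benchmark = @be rand(100) fgrad(_)

# How many times slower the gradient is than a plain primal evaluation.
time_vs_primal = median(grad_benchmark).time / median(primal_benchmark).time
```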