The conversation that started in https://github.com/TuringLang/DynamicPPL.jl/pull/885/files/048178b7a8946d17fceace9b37c7e40846d50b51#r2069657078 led @penelopeysm and me to reflect on `unflatten`. Our conclusion is that if an angel were to read our code in `unflatten`, even after the changes in #885, they would cover their eyes and cry silently, for that code is Wrong. It is Wrong because it couples the element type of the random variables given to us with the type of the log probs. These are philosophically distinct: random variables can take values in the set of sea birds, and `logpdf` would still return a float. More practically, there's no reason why we shouldn't be able to have variables that are `Float64`s but accumulate log probs as `Float32` (casting when necessary, e.g. when `logpdf(dist, x)` returns a `Float64`).
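A minimal sketch of the decoupling described above, assuming the log-prob type is chosen up front and every term is cast to it before being added. `accumulate_logprob` is a hypothetical helper, not DynamicPPL API, and an unnormalized standard-normal log density stands in for `logpdf(dist, x)` so the snippet needs no packages.

```julia
# Hypothetical helper (not DynamicPPL API): sum log-probability terms in a
# fixed type L, converting each term, whatever the terms' own type is.
function accumulate_logprob(::Type{L}, terms) where {L<:Real}
    acc = zero(L)
    for t in terms
        acc += convert(L, t)  # e.g. a Float64 logpdf value cast down to Float32
    end
    return acc
end

# The variables are Float64, but the log prob is accumulated as Float32.
x = [0.1, 0.2]
terms = [-0.5 * xi^2 for xi in x]  # stand-in for logpdf(dist, xi), which returns Float64
lp = accumulate_logprob(Float32, terms)  # lp isa Float32
```

With dual-number inputs, this unconditional cast would either fail or throw away the derivative information, which is the complication the next paragraph addresses.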
The reason we do the Wrong thing is AD trace/dual number types. When the element type of our variables is `Dual{Float64}`, our log probs need to become `Dual{Float64}` as well. We thought about this in various ways and tried to find arguments for doing something other than special-casing `Dual`, but in the end we failed. `Dual` just is special.
Thus, what we should do is change the `convert_eltype` call in `unflatten` so that the conversion is only done for a few special types, namely AD trace/dual number types. Every AD package that introduces its own `Number` or `Real` subtype that needs to be passed around like a float would then need to define a method for the function that makes this decision, marking its type as one for which log probs do need to be converted.
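One way the opt-in could look, as a sketch only: a trait function that defaults to "no conversion" and that AD packages extend for their own types. The names `needs_logprob_conversion` and `logprob_type` are hypothetical, and `MockDual` is a stand-in for a real dual type such as `ForwardDiff.Dual`, defined here so the snippet runs without any AD package.

```julia
# Hypothetical trait (illustrative name, not DynamicPPL API): must the log
# probs take on this value type too? The fallback is on `Type`, so even
# non-numeric values (sea birds) get the default answer: no.
needs_logprob_conversion(::Type) = false

# Stand-in for an AD dual number type, defined only so the sketch runs.
struct MockDual{T<:Real} <: Real
    value::T
    partial::T
end

# The one-line opt-in an AD package would provide for its own Real subtype.
needs_logprob_conversion(::Type{<:MockDual}) = true

# Pick the log-prob type from the variables' element type V and the
# requested log-prob type L. Only opted-in types override L.
function logprob_type(::Type{V}, ::Type{L}) where {V,L<:Real}
    return needs_logprob_conversion(V) ? V : L
end

logprob_type(Float64, Float32)            # returns Float32
logprob_type(MockDual{Float64}, Float64)  # returns MockDual{Float64}
logprob_type(Symbol, Float64)             # returns Float64
```

Because the default is "don't convert", plain floats and arbitrary value types keep whatever log-prob type was requested, and the burden of opting in falls only on the AD packages. Returning `V` unchanged is a simplification; a fuller version might combine `V` and `L`.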
@penelopeysm, please expand if I didn't cover our full conclusion.