
Logp type should be separate from variable type - this needs AD extensions #906

Open
@mhauru

The conversation that started in https://github.com/TuringLang/DynamicPPL.jl/pull/885/files/048178b7a8946d17fceace9b37c7e40846d50b51#r2069657078 led @penelopeysm and me to reflect on `unflatten`. Our conclusion is that if an angel were to read our code in `unflatten`, even after the changes in #885, they would cover their eyes and cry silently, for that code is Wrong. It is Wrong because it couples the element type of the random variables given to us with the type of the log probs. These are philosophically distinct: random variables can take values in the set of sea birds, and the logpdf would still be a float. More practically, there's no reason why we shouldn't be able to have variables that are `Float64`s but accumulate log prob as `Float32` (casting where necessary, e.g. when `logpdf(dist, x)` returns a `Float64`).
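A minimal sketch of that decoupling, using only Julia Base. `normal_logpdf` and `accumulate_logp` are hypothetical helpers for illustration (the former stands in for `logpdf(Normal(μ, σ), x)` from Distributions.jl), not DynamicPPL API; the only point is that the accumulator's type need not follow the variables' element type.

```julia
# Hypothetical stand-in for `logpdf(Normal(μ, σ), x)` from Distributions.jl;
# for Float64 inputs it returns a Float64.
normal_logpdf(x::Real, μ::Real=0.0, σ::Real=1.0) =
    -log(σ) - log(2π) / 2 - ((x - μ) / σ)^2 / 2

# Accumulate log probs in type `L`, independently of the variables' element type,
# casting each term to `L` as it is added.
function accumulate_logp(::Type{L}, xs::AbstractVector) where {L<:Real}
    logp = zero(L)
    for x in xs
        logp += convert(L, normal_logpdf(x))
    end
    return logp
end

xs = [0.5, -1.2, 2.0]                    # variables stored as Float64
logp32 = accumulate_logp(Float32, xs)    # log prob accumulated as Float32
logp64 = accumulate_logp(Float64, xs)    # same computation, Float64 accumulator
```

The two results agree up to `Float32` precision, while the variables themselves never change type.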

The reason we do the Wrong thing is AD trace/dual number types. When the element type of our variables is `Dual{Float64}`, our log probs need to become `Dual{Float64}` as well, otherwise the derivative information they carry would be lost. We thought about this in various ways and tried to come up with arguments for doing something other than special-casing `Dual`, but in the end we failed. `Dual` just is special.

Thus, what we should do is change the `convert_eltype` call in `unflatten` so that it is only done for a few special types, namely AD trace/dual number types. Every AD package that introduces its own `Number` or `Real` subtype that needs to be passed around like a float would then need to define a method for this new opt-in function (e.g. in a package extension), to mark its type as one for which log probs do need to be converted.
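A minimal, self-contained sketch of such an opt-in, assuming nothing beyond Julia Base. `ToyDual`, `needs_logp_conversion`, `logp_type` and `sum_logp` are all hypothetical names for illustration, not DynamicPPL or ForwardDiff API; `ToyDual` stands in for a real AD number type such as ForwardDiff's `Dual`, and the conversion is modelled as "use the variables' element type as the log-prob type", mirroring what `convert_eltype` does today.

```julia
# Toy forward-mode dual number (hypothetical), standing in for an AD package's type.
struct ToyDual{T<:Real} <: Real
    value::T
    partial::T
end
Base.:+(a::ToyDual, b::ToyDual) = ToyDual(a.value + b.value, a.partial + b.partial)
Base.:*(a::ToyDual, b::ToyDual) =
    ToyDual(a.value * b.value, a.value * b.partial + a.partial * b.value)
Base.:*(c::Real, a::ToyDual) = ToyDual(c * a.value, c * a.partial)
Base.zero(::Type{ToyDual{T}}) where {T} = ToyDual(zero(T), zero(T))

# Hypothetical opt-in trait: false by default, AD packages add a method for their type.
needs_logp_conversion(::Type) = false
needs_logp_conversion(::Type{<:ToyDual}) = true  # what an AD extension would define

# Log-prob type: the default `L`, unless the variable type `V` opted in.
logp_type(::Type{V}, ::Type{L}) where {V,L} = needs_logp_conversion(V) ? V : L

neg_half_sq(x) = (-0.5) * (x * x)  # toy log-density term, -x^2/2

function sum_logp(xs::AbstractVector{V}, ::Type{L}) where {V,L}
    T = logp_type(V, L)
    logp = zero(T)
    for x in xs
        logp += convert(T, neg_half_sq(x))
    end
    return logp
end

lp_plain = sum_logp([1.0, 2.0], Float32)                               # stays Float32
lp_dual = sum_logp([ToyDual(1.0, 0.0), ToyDual(2.0, 1.0)], Float32)   # becomes ToyDual
```

Only the AD type opts in: plain `Float64` variables keep the default `Float32` accumulator, while `ToyDual` variables get a `ToyDual` accumulator whose `partial` field carries the derivative (here `-2.0`, the derivative of `-x^2/2` at `x = 2`). Without the opt-in, the accumulator would stay `Float32` and that derivative would have nowhere to go.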

@penelopeysm, please expand if I didn't cover our full conclusion.
