
Feature request: Conditional sampling #640

Open

adamhaber opened this issue Nov 5, 2019 · 7 comments

@adamhaber

In many TFP Bayesian use cases, it's very helpful to specify a joint distribution using a JointDistribution* object - it makes sampling (for prior-predictive checks) straightforward (sorry), and exposes the log_prob function the sampler needs. However, since many (most?) of these cases involve some sort of conditioning, we end up writing a function closure which is very confusing and possibly error-prone (mostly shape errors, but also type errors):

lp = lambda *x: model.log_prob(list(x) + [tf.cast(df['y'],tf.float32)[tf.newaxis, ...]])

Exposing some sort of conditioning method, instead, could be amazing.

For example, for a JointDistributionSequential (which is represented by a list), perhaps something along these lines:

posterior_lp = model.cond_lp(conditioned_var_idx = ..., condition_on = ...)

?

Thanks in advance!
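For concreteness, a minimal sketch of the closure pattern described above, using a toy JointDistributionSequential (the model, variable names, and data here are illustrative, not from this issue):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Toy model: a scalar weight and five conditionally independent observations.
model = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),              # prior on the weight w
    lambda w: tfd.Independent(                 # likelihood p(y | w)
        tfd.Normal(loc=w * tf.ones([5]), scale=1.),
        reinterpreted_batch_ndims=1),
])

y_observed = tf.zeros([5])  # stand-in for the observed data

# The closure: pin the observed node and expose an unnormalized posterior
# log-density over the remaining free variable(s).
posterior_lp = lambda *free: model.log_prob(list(free) + [y_observed])

posterior_lp(0.5)  # scalar log-density, ready to hand to an MCMC kernel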

@junpenglao
Contributor

Thanks for raising the issue! I agree completely - I have been discussing this with some other users who see a similar need (e.g., https://twitter.com/ML_deep/status/1188387178507694081?s=20)
@csuter @brianwa84 @jvdillon

@brianwa84
Contributor

brianwa84 commented Nov 5, 2019 via email

@jvdillon
Contributor

jvdillon commented Nov 5, 2019

Fwiw, I strongly urge against this sugar. (I'll even beg if it helps.) To me, the connection between closures and unnormalized densities is the second most elegant part of TFP.

I also disagree that it makes code more readable. While the lambda spells out clearly how the thing works, cond_lp seems to me only to obfuscate. Also, the created object is not a distribution, as the density is unnormalized. Finally, the simplified version isn't really a fair comparison, since the cast and expand_dims would need to be done there. A better comparison would be:
lp = lambda *x: model.log_prob(x + (df['y'],))
and that reads purdy darn nicely to me!

I'm happy to go down a long list of other reasons, but the tl;dr is that I claim this sugar only feels like it solves a problem: it actually adds cognitive burden (yet another thing to learn), runs the risk of making users think it's required, and obfuscates what is otherwise a one-liner. If we applied KonMari to software design, I claim we'd be quite happy with lambdas.

@SiegeLordEx
Member

If the only thing this sugar does is construct a conditioned unnormalized log-probability function, it is not super useful. However, if it simultaneously handles other properties of the distribution (e.g. the shapes, dtypes, etc.), then it becomes more compelling. I've had good experience using something I call JointDistributionPosterior, which takes the conditioning as a constructor arg and produces a distribution-like object, e.g.:

jd = JointDistribution(...)
jd.event_shape == [(1, 1), (2,), (3,)]
jd.dtype == [tf.int32, tf.float32, tf.float64]

jdp = JointDistributionPosterior(jd, conditioning=(None, tf.zeros([2]), None))
jdp.event_shape == [(1, 1), (3,)]
jdp.dtype == [tf.int32, tf.float64]
jdp.unnormalized_log_prob(
  tf.nest.map_structure(lambda s, d: tf.zeros(s, dtype=d),
    jdp.event_shape, jdp.dtype)) == tf.Tensor

It's very easy to write something like that yourself even if we never add it to TFP.
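For reference, a rough sketch of such a wrapper for list-structured joints (e.g. JointDistributionSequential). JointDistributionPosterior is not an existing TFP class; the names and details below are assumptions for illustration only:

import tensorflow as tf

class JointDistributionPosterior(object):
  """Distribution-like view of a list-structured joint with some parts pinned."""

  def __init__(self, joint, conditioning):
    # `conditioning` aligns with the joint's components: None marks a free
    # variable, anything else is treated as an observed (pinned) value.
    self._joint = joint
    self._conditioning = list(conditioning)
    self._free_idx = [i for i, c in enumerate(self._conditioning) if c is None]

  @property
  def event_shape(self):
    return [self._joint.event_shape[i] for i in self._free_idx]

  @property
  def dtype(self):
    return [self._joint.dtype[i] for i in self._free_idx]

  def unnormalized_log_prob(self, free_values):
    # Splice the free values back into the full joint ordering and delegate.
    parts = list(self._conditioning)
    for i, v in zip(self._free_idx, free_values):
      parts[i] = v
    return self._joint.log_prob(parts)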

@junpenglao
Contributor

junpenglao commented Nov 6, 2019

I think the problem currently is that, to construct an (unnormalized) conditional posterior for inference, the APIs are quite inconsistent across the different JD* flavors, and they also require the user to understand the call signature:

init_state = [var1, var2, ...]  # <== a list of tensors

lp = lambda x: mdl_jdseq.log_prob(
    x + [observed])
lp(init_state)  # <== this works, but not when you plug it into mcmc.sample,
                # which means the user will get an error downstream (much) later.

lp = lambda *x: mdl_jdseq.log_prob(
    x + (observed, ))
lp = lambda *x: mdl_jdseq.log_prob(
    list(x) + [observed])  # <== Another alternative, which is arguably the "right"
                           # version, as mdl_jdseq.sample([...]) returns a list.
                           # So by the contract of jd.log_prob(jd.sample([...])),
                           # the input to jd.log_prob should also be a list.
lp(*init_state)

# Not sure what the best practice is here, as there are many ways to
# construct a dict-like object for mdl_jdname.log_prob - nonetheless, additional
# user input is needed here
import collections
Model = collections.namedtuple('Model', [...])
lp = lambda *x: mdl_jdname.log_prob(
    Model(*x, observed))
lp(*init_state)

lp = lambda x: mdl_jdcoroutine.log_prob(
    x + [observed])
lp(init_state)  # <== this works, but not when you plug it into mcmc.sample,
                # which means the user will get an error downstream (much) later.
lp = lambda *x: mdl_jdcoroutine.log_prob(
    x + (observed, ))  # <== the canonical version, as coroutine jd samples are tuples
lp = lambda *x: mdl_jdcoroutine.log_prob(
    list(x) + [observed])
lp(*init_state)

Syntactic sugar would make sure this is consistent across all JD* flavors.

If we don't introduce an additional API for this, we should definitely make this clearer in the docstrings and documentation.
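One possible shape for such sugar, sketched as a free-standing helper rather than TFP API (make_unnormalized_log_prob is a hypothetical name, and it inherits the "observed values come last" assumption of the snippets above):

import tensorflow as tf

def make_unnormalized_log_prob(jd, *observed):
  # One sample is drawn only to discover the flavor's container structure
  # (list for JDSequential, namedtuple for JDCoroutine/JDNamed).
  # Note: tf.nest flattens dict-backed JDNamed models in sorted-key order,
  # which may not match declaration order.
  structure = jd.sample()

  def log_prob(*free_state):
    parts = list(free_state) + list(observed)
    return jd.log_prob(tf.nest.pack_sequence_as(structure, parts))

  return log_prob

# Usage would then be the same for every flavor:
# lp = make_unnormalized_log_prob(mdl_jdseq, observed)
# lp(*init_state)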

@junpenglao
Contributor

BTW, all the code above is based on the assumption that the last node(s) are the observed ones.

@jvdillon
Contributor

jvdillon commented Nov 14, 2019

+1 to better docstrings.
I agree there's a learning curve here, but I feel this learning curve is "worth it" since the current approach ensures the unnormalized posterior is merely a thin accessor to the full joint (this being the inferential base). Furthermore, by not codifying this accessor we emphasize that all downstream inference logic is agnostic--any function will suffice.

As for the different call styles, I see this difference as one of the key points of having different JD* flavors. The reason for the current style is that we wanted to preserve the d.log_prob(d.sample()) pattern yet also have d.sample() be interpretable wrt the model as supplied to __init__. If it turns out this difference is more pain than benefit, I'd rather see us change the JointDistribution* than build new sugar on top.
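As an illustration of "all downstream inference logic is agnostic": any plain callable returning a log-density can serve as the MCMC target. The toy target, step size, and chain lengths below are placeholders, not from this thread:

import tensorflow as tf
import tensorflow_probability as tfp

# Any Python callable that returns a log-density works as the target;
# nothing downstream needs to know it was built from a JointDistribution*.
posterior_lp = lambda w: tfp.distributions.Normal(0., 1.).log_prob(w)

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=posterior_lp,
    step_size=0.1,
    num_leapfrog_steps=3)

samples = tfp.mcmc.sample_chain(
    num_results=100,
    current_state=tf.zeros([]),
    kernel=kernel,
    num_burnin_steps=50,
    trace_fn=None)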
