Feature request: Conditional sampling #640
Comments
Thanks for raising the issue! I agree completely - I have been discussing with some other users who see a similar need (e.g., https://twitter.com/ML_deep/status/1188387178507694081?s=20) |
I'd love it. Should work with all flavors of JD, ideally. I know it's just sugar around a lambda, but it spells out intent and makes code read more nicely.
|
Fwiw, I strongly urge against this sugar. (I'll even beg if it helps.) To me, the connection between closures and unnormalized densities is the second most elegant part of TFP. I also disagree that it makes code more readable: while the lambda spells out clearly how the thing works, the cond_lp seems to me only to obfuscate. Also, the created object is not a distribution, since the density is unnormalized. Finally, the simplified version isn't really a fair comparison, since the cast and expand_dims would need to be done there too. A better comparison would be:

I'm happy to go down a long list of other reasons, but the tl;dr is that I claim this sugar only feels like it solves a problem: it actually adds cognitive burden (yet another thing to learn), runs the risk of making a user think it's required, and obfuscates what is otherwise a one-liner. If we applied KonMari to software design, I claim we'd be quite happy with lambdas.
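To make the comparison point concrete, here is a minimal, made-up sketch (not the snippet referenced above) of the plain-closure style being defended, with the dtype handling done inside user code:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# Made-up model: a global location z and 3 conditionally iid observations x.
jd = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1.),                          # z
    lambda z: tfd.Sample(tfd.Normal(loc=z, scale=1.), 3),  # x (observed)
])

raw_data = np.array([1.0, 2.0, 3.0])  # float64 on purpose

# The unnormalized posterior over z is just a closure over log_prob; the
# dtype cast that any sugar would hide has to happen somewhere anyway.
observed = tf.cast(raw_data, tf.float32)
target_log_prob = lambda z: jd.log_prob([z, observed])

target_log_prob(tf.constant(0.))

|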
If the only thing this sugar does is construct a conditioned unnormalized log-probability function, it is not super useful. However, if it simultaneously handles other properties of the distribution (e.g. the shapes, dtypes, etc.), then it becomes more compelling. I've had good experience using something I call a JointDistributionPosterior:

jd = JointDistribution(...)
jd.event_shape == [(1, 1), (2,), (3,)]
jd.dtype == [tf.int32, tf.float32, tf.float64]

jdp = JointDistributionPosterior(jd, conditioning=(None, tf.zeros([2]), None))
jdp.event_shape == [(1, 1), (3,)]
jdp.dtype == [tf.int32, tf.float64]
jdp.unnormalized_log_prob(
    tf.nest.map_structure(lambda s, d: tf.zeros(s, dtype=d),
                          jdp.event_shape, jdp.dtype)) == tf.Tensor

It's very easy to write something like that yourself even if we never add it to TFP.
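For instance, a bare-bones version of that wrapper could be written in a few lines (just a sketch of the idea, assuming the wrapped JD exposes list-valued event_shape/dtype and accepts a list of values in log_prob):

class JointDistributionPosterior(object):
  """Sketch: pins some components of a joint distribution to observed values.

  `conditioning` has one entry per component of the joint; None marks a free
  (latent) component, anything else is taken as that component's observed value.
  """

  def __init__(self, joint, conditioning):
    self._joint = joint
    self._conditioning = list(conditioning)
    self._free_idx = [i for i, v in enumerate(self._conditioning) if v is None]

  @property
  def event_shape(self):
    return [self._joint.event_shape[i] for i in self._free_idx]

  @property
  def dtype(self):
    return [self._joint.dtype[i] for i in self._free_idx]

  def unnormalized_log_prob(self, free_values):
    # Interleave the free values with the pinned (observed) ones, then
    # evaluate the joint density at the full assignment.
    full = list(self._conditioning)
    for i, v in zip(self._free_idx, free_values):
      full[i] = v
    return self._joint.log_prob(full)

|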
I think the problem currently is that, to construct an (unnormalized) conditional posterior for inference, the APIs are quite inconsistent across the different JD* flavors, and they also require the user to understand the call signature:

init_state = [var1, var2, ...]  # <== a list of tensors

# JointDistributionSequential
lp = lambda x: mdl_jdseq.log_prob(
    x + [observed])
lp(init_state)  # <== this works, but not when you plug it into mcmc.sample,
                #     which means the user will get an error (much) later downstream.

lp = lambda *x: mdl_jdseq.log_prob(
    x + (observed,))

lp = lambda *x: mdl_jdseq.log_prob(
    list(x) + [observed])  # <== Another alternative, and arguably the "right"
                           #     version, as mdl_jdseq.sample([...]) returns a list.
                           #     So by the contract of jd.log_prob(jd.sample([...])),
                           #     the input to jd.log_prob should also be a list.
lp(*init_state)

# JointDistributionNamed
# Not sure what the best practice is here, as there are many ways to
# construct a dict-like object for jd_named.log_prob - nonetheless, additional
# user input is needed here.
import collections
Model = collections.namedtuple('Model', [...])
lp = lambda *x: mdl_jdname.log_prob(
    Model(*x, observed))
lp(*init_state)

# JointDistributionCoroutine
lp = lambda x: mdl_jdcoroutine.log_prob(
    x + [observed])
lp(init_state)  # <== this works, but not when you plug it into mcmc.sample,
                #     which means the user will get an error (much) later downstream.

lp = lambda *x: mdl_jdcoroutine.log_prob(
    x + (observed,))  # <== the canonical version, as coroutine JD samples are tuples.

lp = lambda *x: mdl_jdcoroutine.log_prob(
    list(x) + [observed])
lp(*init_state)

Syntactic sugar would make this consistent for all JD* flavors. If we don't introduce an additional API for this, we should definitely make it clearer in docstrings and documentation.
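To illustrate what such sugar could hide, here is purely a sketch (none of these helpers exist in TFP, and it keeps the same last-node-is-observed assumption):

def make_conditioned_log_prob(jd, observed):
  """Sketch: uniform unnormalized log-prob over the leading (latent) nodes."""
  structure = jd.sample()  # one draw, only to discover the container type
  def target_log_prob(*latents):
    if isinstance(structure, dict):
      # JointDistributionNamed: rebuild the dict, filling in the observed node.
      names = list(structure.keys())
      values = dict(zip(names[:-1], latents))
      values[names[-1]] = observed
      return jd.log_prob(values)
    if isinstance(structure, tuple):
      # JointDistributionCoroutine: samples are (named) tuples.
      return jd.log_prob(tuple(latents) + (observed,))
    # JointDistributionSequential: samples are lists.
    return jd.log_prob(list(latents) + [observed])
  return target_log_prob

lp = make_conditioned_log_prob(mdl_jdseq, observed)
lp(*init_state)  # same call style for every JD* flavor

|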
BTW, all the code above is based on the assumption that the last node(s) of the model are the observed ones. |
+1 to better docstrings. As for the different call styles, I see this difference as one of the key points of having different JD* flavors. The reason for the current style is that we wanted to preserve the |
In many TFP Bayesian use cases, it's very helpful to specify a joint distribution using a JointDistribution* object - it makes sampling (for prior-predictive checks) straightforward (sorry), and exposes a log_prob function necessary for the sampler. However, since many (most?) of these cases involve some sort of conditioning, we end up writing a function closure, which is very confusing and possibly error prone (mostly shape errors, but also type errors):

Exposing some sort of conditioning method, instead, could be amazing.
For example, for a JointDistributionSequential (which is represented by a list), perhaps something along these lines?
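As a rough sketch of the kind of API in mind (the method name and signature below are just placeholders, not an existing TFP API):

# Placeholder API sketch - not real TFP. Condition the joint on its last
# component and get back an unnormalized target density over the rest.
target_log_prob = jd_sequential.condition_on(observed=y_data)

# ...which could then be handed straight to a sampler, e.g.:
samples = tfp.mcmc.sample_chain(
    num_results=1000,
    current_state=initial_state,
    kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob,
        step_size=0.1,
        num_leapfrog_steps=3))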
Thanks in advance!