-
-
Notifications
You must be signed in to change notification settings - Fork 984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add jitter to Cholesky factorization in Gaussian ops #3151
Conversation
pyro/ops/tensor_utils.py
Outdated
return x.sqrt() | ||
|
||
# Add adaptive jitter. | ||
x = x.clone() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are you sure you need/want clones? fwiw i do this in millipede, which is similar to what's done in gpytorch
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup this .clone()
is needed because we're mutating the matrix. Nice, I'll rename this to safe_cholesky()
as in millipede.
Thanks, I did try various gpytorch-style lazy tactics that avoid adding jitter until a failure has occurred. I found that since the Cholesky error happens only late in the filtering process, by that point the filter state had already been corrupted by nearly-singular matrices that just barely didn't trigger an error. The best solution I've found so far is to add a tiny amount of noise to all matrices so that error doesn't build up during filtering. Other solutions include using svd
or pinv
or ldl_factor
, but they were more expensive.
Note the core piece of linear algebra is in Gaussian.marginalize() which is repeatedly called in the filter pass of sequential_gaussian_filter_sample(). It's just the blockwise symmetric matrix inverse formula:
# in Gaussian.marginalize():
P_aa = self.precision[..., a, a]
P_ba = self.precision[..., b, a]
P_bb = self.precision[..., b, b]
P_b = safe_cholesky(P_bb) # Note if we add a little jitter here...
P_a = triangular_solve(P_ba, P_b, upper=False) # ...then this is smaller...
P_at = P_a.transpose(-1, -2)
precision = P_aa - matmul(P_at, P_a) # ...so this is even better conditioned.
This code has the nice property that if we add a little bit of jitter before Cholesky factorizing, the next precision matrix becomes only better-conditioned. Empirically this allowed me to get away with much smaller jitter than was needed if I waited for an error to occur.
BTW it looks like you could speed up millipede by switching from try: c = cholesky()
to the faster c, info = cholesky_ex(); if not info.any(): return c
, which is used in gpytorch. The only reason I'm not using cholesky_ex()
here is that I found the decision-based version was too unstable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm. but CHOLESKY_JITTER
isn't actually easily toggle-able is it?
pyro/ops/tensor_utils.py
Outdated
if x.size(-1) == 1: | ||
if CHOLESKY_JITTER: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you intend to clamp by (CHOLESKY_JITTER * finfo(x.dtype).eps) ** 2
here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My intention was to scale by about finfo(x.dtype).eps * x.max()
so that the jitter was just barely detectable by the largest matrix entry before Cholesky factorizing. That way if we set RELATIVE_CHOLESKY_JITTER = 1/2
then jitter will only affect matrix entries less than half the size of the max. And it kindof makes sense to me that each additional bit of precision would mean we would need to add half as much jitter, thus jitter would be proportional to finfo(x.dtype).eps
. Mostly the proportional think helps us keep a constant RELATIVE_CHOLESKY_JITTER
across float32 and float64.
What's your intuition behind the square here, is that to keep constant error post-Cholesky-factorization?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I missed the x_max
term in the above comment. Using square seems to be more consistent w.r.t. to cases x.size(-1) > 1
- but I like your clamp by tiny
better.
Re x_max
: using global max
makes sense, but I feel that it might be better to use max
of rows instead, e.g. considering the diagonal matrix [0.0001, 10000]
, the global jitter is large w.r.t. the first diagonal term.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice idea! I've switched to using a row-wise max. This required increasing CHOLESKY_RELATIVE_JITTER
from 1.0 to 4.0, but this way still seems better 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this really preferred? this changes the eigenvalues and eigenvectors as opposed to jitter that is proportional to the identity (which only changes eigenvalues)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@martinjankowiak that's a good point, I didn't know about eigenvector preservation. I'd be ok with either version.
One thing I like about @fehiepsi's solution is that users can on their side rotate the system before performing Gaussian ops, e.g. I'm approximately diagonalizing via QR
evals, evecs = torch.linalg.eig(transition_matrix)
Q, R = torch.linalg.qr(evecs.real)
transition_matrix = Q.T @ transition_matrix @ Q
which shrinks my diagonal perturbations
- [5.45, 1.81, 0.86, 0.76] * eps * CHOLESKY_RELATIVE_JITTER
+ [4.91, 1.67, 0.21, 0.17] * eps * CHOLESKY_RELATIVE_JITTER
That's correct, I'm hoping we won't actually need to toggle it and I'm intending the global variable as an emergency switch in case I need to change something post-release in a production model and want to avoid monkey patching. Actually we ought to have some sort of standard interface for all of Pyro's global settings, similar to GPyTorch's settings. Mind if I attempt that in a follow-up PR #3152? Whichever of these PRs merges first, I'll be sure to add to the second PR a registration @settings.register("cholesky_relative_jitter", __name__, "CHOLESKY_RELATIVE_JITTER")
def _validate_jitter(value):
assert isinstance(value, (int, float))
assert 0 <= value EDIT done in #3152. |
Thanks for reviewing! |
This numerically stabilizes Gaussian parallel-scan operations for use in very long chains, say > 10000 items. This use case is important for heterogeneous-length batching, where a batch of sequences can be concatenated and operated on as a single chain. In this setting I was seeing Cholesky errors in both single and double precision.
The fix in this PR is to modify
pyro.ops.tensor_utils.cholesky()
to add a small adaptive amount of jitter to all computations. This appears to both pass all existing tests (which are quite strong, it doesn't work to add a non-adaptive jitter), and to make Gaussian parallel-scan filtering work for very long sequences. 🎉Tested