Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add jitter to Cholesky factorization in Gaussian ops #3151

Merged
merged 11 commits into from
Oct 30, 2022
Merged

Conversation

fritzo
Copy link
Member

@fritzo fritzo commented Oct 28, 2022

This numerically stabilizes Gaussian parallel-scan operations for use in very long chains, say > 10000 items. This use case is important for heterogeneous-length batching, where a batch of sequences can be concatenated and operated on as a single chain. In this setting I was seeing Cholesky errors in both single and double precision.

The fix in this PR is to modify pyro.ops.tensor_utils.cholesky() to add a small adaptive amount of jitter to all computations. This appears to both pass all existing tests (which are quite strong, it doesn't work to add a non-adaptive jitter), and to make Gaussian parallel-scan filtering work for very long sequences. 🎉

Tested

  • added tests for very long sequences
  • tried on a real-world example of length up to 1e6

return x.sqrt()

# Add adaptive jitter.
x = x.clone()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are you sure you need/want clones? fwiw i do this in millipede, which is similar to what's done in gpytorch

Copy link
Member Author

@fritzo fritzo Oct 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup this .clone() is needed because we're mutating the matrix. Nice, I'll rename this to safe_cholesky() as in millipede.

Thanks, I did try various gpytorch-style lazy tactics that avoid adding jitter until a failure has occurred. I found that since the Cholesky error happens only late in the filtering process, by that point the filter state had already been corrupted by nearly-singular matrices that just barely didn't trigger an error. The best solution I've found so far is to add a tiny amount of noise to all matrices so that error doesn't build up during filtering. Other solutions include using svd or pinv or ldl_factor, but they were more expensive.

Note the core piece of linear algebra is in Gaussian.marginalize() which is repeatedly called in the filter pass of sequential_gaussian_filter_sample(). It's just the blockwise symmetric matrix inverse formula:

# in Gaussian.marginalize():
P_aa = self.precision[..., a, a]
P_ba = self.precision[..., b, a]
P_bb = self.precision[..., b, b]
P_b = safe_cholesky(P_bb)                       # Note if we add a little jitter here...
P_a = triangular_solve(P_ba, P_b, upper=False)  # ...then this is smaller...
P_at = P_a.transpose(-1, -2)
precision = P_aa - matmul(P_at, P_a)            # ...so this is even better conditioned.

This code has the nice property that if we add a little bit of jitter before Cholesky factorizing, the next precision matrix becomes only better-conditioned. Empirically this allowed me to get away with much smaller jitter than was needed if I waited for an error to occur.

BTW it looks like you could speed up millipede by switching from try: c = cholesky() to the faster c, info = cholesky_ex(); if not info.any(): return c, which is used in gpytorch. The only reason I'm not using cholesky_ex() here is that I found the decision-based version was too unstable.

Copy link
Collaborator

@martinjankowiak martinjankowiak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm. but CHOLESKY_JITTER isn't actually easily toggle-able is it?

fehiepsi
fehiepsi previously approved these changes Oct 28, 2022
pyro/ops/tensor_utils.py Outdated Show resolved Hide resolved
if x.size(-1) == 1:
if CHOLESKY_JITTER:
Copy link
Member

@fehiepsi fehiepsi Oct 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you intend to clamp by (CHOLESKY_JITTER * finfo(x.dtype).eps) ** 2 here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My intention was to scale by about finfo(x.dtype).eps * x.max() so that the jitter was just barely detectable by the largest matrix entry before Cholesky factorizing. That way if we set RELATIVE_CHOLESKY_JITTER = 1/2 then jitter will only affect matrix entries less than half the size of the max. And it kindof makes sense to me that each additional bit of precision would mean we would need to add half as much jitter, thus jitter would be proportional to finfo(x.dtype).eps. Mostly the proportional think helps us keep a constant RELATIVE_CHOLESKY_JITTER across float32 and float64.

What's your intuition behind the square here, is that to keep constant error post-Cholesky-factorization?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed the x_max term in the above comment. Using square seems to be more consistent w.r.t. to cases x.size(-1) > 1 - but I like your clamp by tiny better.

Re x_max: using global max makes sense, but I feel that it might be better to use max of rows instead, e.g. considering the diagonal matrix [0.0001, 10000], the global jitter is large w.r.t. the first diagonal term.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice idea! I've switched to using a row-wise max. This required increasing CHOLESKY_RELATIVE_JITTER from 1.0 to 4.0, but this way still seems better 👍

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this really preferred? this changes the eigenvalues and eigenvectors as opposed to jitter that is proportional to the identity (which only changes eigenvalues)

Copy link
Member Author

@fritzo fritzo Oct 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@martinjankowiak that's a good point, I didn't know about eigenvector preservation. I'd be ok with either version.

One thing I like about @fehiepsi's solution is that users can on their side rotate the system before performing Gaussian ops, e.g. I'm approximately diagonalizing via QR

evals, evecs = torch.linalg.eig(transition_matrix)
Q, R = torch.linalg.qr(evecs.real)
transition_matrix = Q.T @ transition_matrix @ Q

which shrinks my diagonal perturbations

- [5.45, 1.81, 0.86, 0.76] * eps * CHOLESKY_RELATIVE_JITTER
+ [4.91, 1.67, 0.21, 0.17] * eps * CHOLESKY_RELATIVE_JITTER

@fritzo
Copy link
Member Author

fritzo commented Oct 30, 2022

@martinjankowiak CHOLESKY_JITTER isn't actually easily toggle-able is it?

That's correct, I'm hoping we won't actually need to toggle it and I'm intending the global variable as an emergency switch in case I need to change something post-release in a production model and want to avoid monkey patching.

Actually we ought to have some sort of standard interface for all of Pyro's global settings, similar to GPyTorch's settings. Mind if I attempt that in a follow-up PR #3152? Whichever of these PRs merges first, I'll be sure to add to the second PR a registration

@settings.register("cholesky_relative_jitter", __name__, "CHOLESKY_RELATIVE_JITTER")
def _validate_jitter(value):
    assert isinstance(value, (int, float))
    assert 0 <= value

EDIT done in #3152.

@fritzo fritzo dismissed stale reviews from fehiepsi and martinjankowiak via 152b913 October 30, 2022 02:20
@fehiepsi fehiepsi merged commit ed54fe8 into dev Oct 30, 2022
@fritzo fritzo deleted the gaussian-jitter branch October 30, 2022 16:39
@fritzo
Copy link
Member Author

fritzo commented Oct 30, 2022

Thanks for reviewing!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants