
Represent Gaussian using square root of precision #2019

Open · wants to merge 12 commits into dev
Conversation

@fehiepsi (Member) commented Aug 19, 2019

This is an effort to resolve the Cholesky issue as reported in #2017.

There will be some work during the transition to this representation, but I hope it will be smooth.

TODO:

  • Convert the marginalize method to use prec_sqrt
  • Convert matrix_and_mvn_to_gaussian, mvn_to_gaussian to this new representation (see the sketch after this list). It is best to have an MVN distribution parameterized by prec_sqrt instead of using torch.inverse(scale_tril), but we can defer this to future improvements.
  • Add tests for this new representation.
  • Implement an unbroadcasted gaussian_tensordot (instead of using pad + add + marginalize). Edit: this is not necessary; pad + add + marginalize has the same performance.
  • [ ] Update to match recent changes in Gaussian.
  • Add logic to deal with the shape blow-up pointed out by @fritzo
  • [ ] Add tests to see if this works for a high-dimensional latent HMM.
  • Merge with the Gaussian class, dispatching as necessary
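For reference, a minimal sketch of the conversion mentioned in the matrix_and_mvn_to_gaussian item above, assuming prec_sqrt is defined so that precision == prec_sqrt @ prec_sqrt.T. The helper name is hypothetical, and the longer-term plan in that item is to avoid the inverse entirely by parameterizing the MVN with prec_sqrt directly:

import torch

def mvn_scale_tril_to_prec_sqrt(scale_tril):
    # If covariance = L @ L.T with L = scale_tril, then precision = inv(L).T @ inv(L),
    # so inv(L).T is a valid square root of the precision matrix.
    return torch.inverse(scale_tril).transpose(-2, -1)

L = torch.randn(3, 3).tril() + 3 * torch.eye(3)  # a well-conditioned scale_tril
prec_sqrt = mvn_scale_tril_to_prec_sqrt(L)
assert torch.allclose(prec_sqrt @ prec_sqrt.transpose(-2, -1), torch.inverse(L @ L.T), atol=1e-5)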

@fehiepsi (Member Author) commented Aug 19, 2019

@fritzo The square root representation works smoothly, and I believe it will be more stable than the current representation. To make it easier to compare and test, I separated the implementation into a new file. In addition, I added an S suffix to the class and functions in this file to distinguish this implementation from the current one. Do you have other suggestions for naming this new class and the corresponding functions?

Here are several properties of this representation (a minimal sketch follows the list):

  • event_permute will permute the row vectors of the prec_sqrt matrix
  • adding two Gaussians will concatenate the column vectors of their prec_sqrt matrices
  • event_pad will add zero rows to the prec_sqrt matrix
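A minimal sketch of these three properties on toy tensors, assuming a prec_sqrt of shape (dim, rank) with precision == prec_sqrt @ prec_sqrt.T:

import torch
import torch.nn.functional as F

p1 = torch.randn(3, 2)  # (dim, rank) square root of the first Gaussian's precision
p2 = torch.randn(3, 1)  # (dim, rank) square root of the second Gaussian's precision

# add: precisions add, so the square roots concatenate along the column (rank) axis.
p_sum = torch.cat([p1, p2], dim=-1)
assert torch.allclose(p_sum @ p_sum.T, p1 @ p1.T + p2 @ p2.T)

# event_permute: permuting event dims permutes the rows.
perm = torch.tensor([2, 0, 1])
assert torch.allclose(p1[perm] @ p1[perm].T, (p1 @ p1.T)[perm][:, perm])

# event_pad: padded event dims get zero rows (zero precision).
p_pad = torch.cat([p1, p1.new_zeros(2, p1.size(-1))], dim=0)
assert torch.allclose(p_pad @ p_pad.T, F.pad(p1 @ p1.T, (0, 2, 0, 2)))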

I believe this version will be a bit faster (in particular, the mvn_to_gaussianS function is faster) and more stable. It would be even faster and more stable if we used pyro.param('prec_sqrt', constraint=lower_cholesky) and used prec_sqrt instead of scale_tril in the MVN implementation.

@fehiepsi (Member Author) commented:
Btw, just realized that this is PR #2019! :D

@fritzo (Member) left a comment

Hi @fehiepsi, I took a first pass through the code and it looks plausible. I have one major concern.

IIUC the sqrt form replaces Cholesky solves with concatenation, which is numerically stable but not bounded in space. For example, consider the basic operation of the GaussianHMM over T time steps: we recursively call Gaussian.__add__() and Gaussian.marginalize(). Each op is called O(log2(T)) times. Consider the following identities:

# Rank identities for GaussianS:
(g1 + g2).rank() == g1.rank() + g2.rank()  # rank is additive.
g.marginalize(...).rank() == g.rank()      # .marginalize() preserves rank.

Now, in the released Gaussian implementation of GaussianHMM with an n-dimensional state, the rank is always n, so total memory is O(T * n^2). However, in the proposed GaussianS implementation of GaussianHMM, the rank at recursive stage s is n * 2^s, leading to total memory O(T * log2(T) * n^2), i.e. the proposed square root form would cost an extra factor of O(log2(T)) in space. Does this analysis appear correct to you?
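For illustration, a toy sketch of this growth, tracking only the prec_sqrt column count through a binary-tree reduction over T factors (the real GaussianHMM scan also pads and marginalizes, but per the identities above those preserve rank):

import torch

n, T = 4, 64
factors = [torch.randn(2 * n, n) for _ in range(T)]  # one (dim, rank) factor per time step

stage = 0
while len(factors) > 1:
    # __add__ concatenates prec_sqrt columns, so the rank doubles at every stage
    # unless the result is re-triangularized.
    factors = [torch.cat([a, b], dim=-1) for a, b in zip(factors[0::2], factors[1::2])]
    stage += 1
    print(f"stage {stage}: {len(factors)} factors, rank {factors[0].size(-1)}")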

if self.rank() == 0:
    empty = self.prec_sqrt.new_tensor([])
    return empty.reshape(self.prec_sqrt.shape[:-1] + (self.dim(),))
return self.prec_sqrt.matmul(self.prec_sqrt.transpose(-2, -1))
@fritzo (Member) commented:

Why not simply

return self.prec_sqrt.new_zeros(self.batch_shape + (self.dim(), self.dim()))

@fehiepsi (Member Author) commented:

Here we need to return an empty tensor, not a zeros tensor. I think it is needed to pass tests for the constant Gaussian.

@fritzo (Member) commented:

.precision should be the zeros tensor, not the empty tensor.

@fehiepsi (Member Author) commented:

You are right. I thought we were multiplying two empty tensors, which would return an empty tensor. :D

"""
# TODO: benchmark two versions in real application, using 'qr' will be more stable but slower
if method == 'cholesky':
return A.matmul(A.transpose(-1, -2)).cholesky()
@fehiepsi (Member Author) commented:

@fritzo The operation A.matmul(A.T) collapses the expanded rank at each step. Assume that A has shape n x 4n (which is the case when using tensordot); then A.matmul(A.T) will have shape n x n. I will add a test to guarantee that the square root's shape does not expand in the tensordot op.

@fehiepsi (Member Author) commented:

Similarly, the R factor from a QR of A.T, where A has shape n x 4n, has shape n x n.

@fritzo (Member) commented:

But you're not currently triangularizing Psqrt_a in .marginalize(). I think to achieve bounded rank, you'll need to manually triangularize at the end of .marginalize():

if prec_sqrt.size(-1) > prec_sqrt.size(-2):
    prec_sqrt = triangularize(prec_sqrt)
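For reference, a minimal sketch of such a triangularize helper using a thin QR (the name follows the suggestion above; the modern torch.linalg.qr API is assumed, and the PR may implement this differently):

import torch

def triangularize(prec_sqrt):
    # Reduce a (dim, rank) square root with rank > dim to an equivalent (dim, dim) one:
    # if prec_sqrt.T = Q @ R (thin QR), then prec_sqrt @ prec_sqrt.T == R.T @ R,
    # so R.T is a lower-triangular square root of the same precision.
    R = torch.linalg.qr(prec_sqrt.transpose(-2, -1), mode="reduced").R
    return R.transpose(-2, -1)

W = torch.randn(3, 12)  # rank expanded to 4x dim, as after a tensordot
W_tri = triangularize(W)
assert W_tri.shape == (3, 3) and torch.allclose(W_tri @ W_tri.T, W @ W.T, atol=1e-5)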

@fehiepsi (Member Author) commented:

You are right! I missed this point. Let me see if there is an easy way to achieve this. :(

@fehiepsi (Member Author) commented:

@fritzo Applying it at the end of .marginalize() is correct, but it seems we have to use QR there (which is more expensive than Cholesky) because there is no guarantee that the returned Gaussian has full rank. I want to consider other possibilities, so I'll revisit this issue after the NumPyro 0.2 release and the GaussianGamma PR. Thanks for pointing out the rank issue!

@fehiepsi (Member Author) commented Oct 17, 2023

@fritzo It seems that to marginalize out x in an (x, y) Gaussian pair, we can simply triangularize the prec_sqrt matrix and keep the bottom-right block. This way, there is no need for Cholesky during marginalization, and the rank is also bounded. More precisely, given a square root matrix

[[A 0]
 [B C]]

the marginalized precision is (B@Bt + C@Ct) - B@At@(A@At)^-1@A@Bt, which reduces to C@Ct if we assume that A is invertible (which must hold if we want to marginalize x out).

This is the main trick used in square root information filters, where only QR is needed.

I can complete this work if this is needed by someone. Maybe we can also incorporate this fact into funsor.
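A minimal numerical check of this identity on toy shapes, using torch.linalg.qr for the triangularization:

import torch

torch.manual_seed(0)
dx, dy = 2, 3
W = torch.randn(dx + dy, dx + dy + 2)  # (dim, rank) square root of the joint precision over (x, y)
P = W @ W.T

# Direct marginalization of x: P_yy - P_yx @ inv(P_xx) @ P_xy.
marg = P[dx:, dx:] - P[dx:, :dx] @ torch.linalg.solve(P[:dx, :dx], P[:dx, dx:])

# SRIF-style marginalization: triangularize W into [[A, 0], [B, C]], then keep C @ C.T.
R = torch.linalg.qr(W.T, mode="reduced").R  # upper triangular, so R.T is [[A, 0], [B, C]]
C = R.T[dx:, dx:]
assert torch.allclose(marg, C @ C.T, atol=1e-5)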

:param torch.Tensor prec_sqrt: square root of the precision matrix of this Gaussian. We use
    this representation to preserve symmetry and reduce condition numbers.
"""
def __init__(self, log_normalizer, info_vec, prec_sqrt):
@fehiepsi (Member Author) commented:

TODO: switch to whiten_vec as in funsor. Probably rename the attributes to sri_matrix, sri_factor (to match the SRIF literature), and log_factor.
