Represent Gaussian using square root of precision #2019
base: dev
Conversation
@fritzo The square root representation works smoothly, and I believe it will be more numerically stable than the current representation. To make it easier to compare and test, I separated the implementation into a new file. In addition, I added the […] Here are several properties of this representation:
I believe this version will be a bit faster (faster […]).
Btw, I just realized that this is PR #2019! :D
Hi @fehiepsi, I took a first pass through the code and it looks plausible. I have one major concern.
IIUC the sqrt form replaces Cholesky solves with concatenation, which is numerically stable but which is not bounded in space. For example, consider the basic operation of the GaussianHMM over T time steps: we recursively call Gaussian.__add__() and Gaussian.marginalize(). Each op is called O(log2(T)) times. Consider the following identities:
# Rank identities for GaussianS:
(g1 + g2).rank() == g1.rank() + g2.rank() # rank is additive.
g.marginalize(...).rank() == g.rank() # .marginalize() preserves rank.
Now in the released Gaussian implementation of the GaussianHMM with n-dimensional state, the rank is always n, so total memory space is O(T * n^2). However, in the proposed GaussianS implementation of GaussianHMM, the rank at recursive stage s is n * 2^s, leading to total memory space O(T * log2(T) * n^2), i.e. the proposed square root form would cost an extra factor of O(log2(T)) more space. Does this analysis appear correct to you?
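For concreteness, here is a minimal numeric sketch (plain PyTorch, not the PR's GaussianS API) of why rank is additive under addition in the square-root form: adding two Gaussians adds their precisions, and in sqrt form that sum is realized by concatenating factor columns.

import torch

n = 3
S1 = torch.randn(n, 2)                  # sqrt factor of the first precision, rank 2
S2 = torch.randn(n, 3)                  # sqrt factor of the second precision, rank 3

P_sum = S1 @ S1.T + S2 @ S2.T           # dense-form addition of precisions
S_cat = torch.cat([S1, S2], dim=-1)     # sqrt-form addition: concatenate columns
assert torch.allclose(P_sum, S_cat @ S_cat.T, atol=1e-6)
print(S_cat.shape[-1])                  # 5 == 2 + 3: rank adds up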
if self.rank() == 0:
    empty = self.prec_sqrt.new_tensor([])
    return empty.reshape(self.prec_sqrt.shape[:-1] + (self.dim(),))
return self.prec_sqrt.matmul(self.prec_sqrt.transpose(-2, -1))
Why not simply
return self.prec_sqrt.new_zeros(self.batch_shape + (self.dim(), self.dim()))
Here we need to return an empty tensor, not a zeros tensor. I think it is used to pass tests for the constant Gaussian.
.precision should be the zeros tensor, not the empty tensor.
You are right. I thought we were multiplying two empty tensors, which would return an empty tensor. :D
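As a side note, a quick check (with hypothetical shapes, outside the PR's code) shows that in PyTorch a matmul whose inner dimension is zero already returns a zeros matrix of the expected shape, not an empty one:

import torch

dim = 3
prec_sqrt = torch.empty(dim, 0)                       # rank-0 factor: dim x 0
precision = prec_sqrt @ prec_sqrt.transpose(-2, -1)   # inner dimension is 0
print(precision.shape)                                # torch.Size([3, 3])
print(torch.equal(precision, torch.zeros(dim, dim)))  # True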
""" | ||
# TODO: benchmark two versions in real application, using 'qr' will be more stable but slower | ||
if method == 'cholesky': | ||
return A.matmul(A.transpose(-1, -2)).cholesky() |
@fritzo The operator A.matmul(A.T) will remove all the expanded rank at each step. Assume that A has shape n x 4n (which is the case when using tensordot); then A.matmul(A.T) will have shape n x n. I will add a test to guarantee that the square root's shape does not expand in the tensordot op.
Similarly, QR of A.T where A.shape = n x 4n has shape n x n.
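A small sketch of the point above (using the torch.linalg API rather than the PR's helper, so treat the exact calls as an assumption): both the Cholesky route and the QR route compress an n x 4n factor A down to an n x n square root of A @ A.T.

import torch

n = 4
A = torch.randn(n, 4 * n)

# Cholesky route: factor the dense product directly.
L = torch.linalg.cholesky(A @ A.T)       # n x n lower triangular

# QR route: reduced QR of A.T gives A.T = Q @ R with R of shape n x n,
# so A @ A.T = R.T @ R and R.T is an equivalent n x n factor.
Q, R = torch.linalg.qr(A.T)              # Q: 4n x n, R: n x n

assert torch.allclose(L @ L.T, A @ A.T, atol=1e-4)
assert torch.allclose(R.T @ R, A @ A.T, atol=1e-4)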
But you're not currently triangularizing Psqrt_a in .marginalize(). I think to achieve bounded rank, you'll need to manually triangularize at the end of .marginalize():
if prec_sqrt.size(-1) > prec_sqrt.size(-2):
    prec_sqrt = triangularize(prec_sqrt)
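One possible implementation of such a triangularize helper (a sketch only; the helper name comes from the comment above and the torch.linalg calls are an assumption) compresses a wide factor of shape (..., n, k) with k > n into an equivalent (..., n, n) factor via a reduced QR of its transpose:

import torch

def triangularize(prec_sqrt):
    # Reduced QR of S.T: S.T = Q @ R with R of shape (..., n, n), hence
    # S @ S.T = R.T @ R, and R.T is a square factor with the same precision.
    _, R = torch.linalg.qr(prec_sqrt.transpose(-2, -1))
    return R.transpose(-2, -1)

S = torch.randn(2, 3, 12)                # batch of wide factors, n=3, k=12
S_tri = triangularize(S)                 # shape (2, 3, 3)
assert torch.allclose(S_tri @ S_tri.transpose(-2, -1),
                      S @ S.transpose(-2, -1), atol=1e-4)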
You are right! I missed this point. Let me see if there is an easy way to achieve this. :(
@fritzo Applying at the end of […]
@fritzo It seems that to marginalize out a block of variables, the marginalized precision is (B@Bt + C@Ct) - B@At@(A@At)^-1@A@Bt, which is the same as C@Ct if we assume that A is invertible (which is required to be true if we want to marginalize that block). This is the main trick used in square root information filters, where only QR is needed. I can complete this work if it is needed by someone. Maybe we can also incorporate this fact into funsor.
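A numeric check of that identity (the block layout prec_sqrt = [[A, 0], [B, C]], with the first block row holding the variables being marginalized out, is a reading of the comment above rather than the PR's code): when A is square and invertible, the Schur complement collapses to C @ C.T.

import torch

na, nb = 3, 4
A = torch.randn(na, na) + 3 * torch.eye(na)   # square and (almost surely) invertible
B = torch.randn(nb, na)
C = torch.randn(nb, nb)

P_bb = B @ B.T + C @ C.T                      # precision block of the kept variables
P_ba = B @ A.T
P_aa = A @ A.T                                # precision block of the marginalized variables

schur = P_bb - P_ba @ torch.linalg.solve(P_aa, P_ba.T)
assert torch.allclose(schur, C @ C.T, atol=1e-4)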
:param torch.Tensor prec_sqrt: square root of the precision matrix of this Gaussian. We use this
    representation to preserve symmetry and reduce condition numbers.
"""
def __init__(self, log_normalizer, info_vec, prec_sqrt):
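For context, a sketch of the quantity such a parameterization encodes (this is not the PR's class; it assumes the same non-normalized log-density convention as the existing Gaussian, with precision = prec_sqrt @ prec_sqrt.T):

import torch

def log_density(x, log_normalizer, info_vec, prec_sqrt):
    # -0.5 * x' P x + x . info_vec + log_normalizer, with P = prec_sqrt @ prec_sqrt.T,
    # evaluated without ever forming the dense precision matrix P.
    Sx = prec_sqrt.transpose(-2, -1) @ x.unsqueeze(-1)   # (..., rank, 1)
    return (log_normalizer
            - 0.5 * Sx.pow(2).sum((-2, -1))
            + (x * info_vec).sum(-1))

x = torch.randn(5, 3)                      # batch of 5 points in dimension 3
S = torch.randn(3, 2)                      # rank-2 square-root factor
print(log_density(x, torch.tensor(0.0), torch.zeros(3), S).shape)   # torch.Size([5])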
TODO: switch to whiten_vec as in funsor. Probably renaming the attributes to sri_matrix, sri_factor (to match SRIF literature) and log_factor.
This is an effort to resolve the Cholesky issue as reported in #2017.
There is some work to do during the transition to this representation, but I hope that it will be smooth.
TODO:
- [ ] Modify the marginalize method to use prec_sqrt.
- [ ] Convert matrix_and_mvn_to_gaussian, mvn_to_gaussian to this new representation. It is best to have a MVN distribution parameterized by prec_sqrt instead of using torch.inverse(scale_tril), but we can defer this to future improvements.
- [ ] Implement unbroadcasted gaussian_tensordot (instead of using pad + add + marginalize). Edit: this is not necessary; pad + add + marginalize would have the same performance.
- [ ] Update to match recent changes in Gaussian.
- [ ] Add tests to see if this works for high dimensional latent hmm.