
Avoid dense matrices in chol_cap_mat for ConstantDiagLinearOperator A (#120)

Open
Osburg wants to merge 1 commit into cornellius-gp:main from Osburg:main

Conversation


Osburg commented Jan 13, 2026

Hey :)
LowRankRootAddedDiagLinearOperator.chol_cap_mat() performs the matrix operation $C + V^T D^{-1} V$ where $D$ is a diagonal matrix. In a setting as described here this can lead to the formation of a large dense representation of $D$. In the special case of $D$ being a ConstantDiagLinearOperator $D=\sigma I$ we can avoid this by using the equivalent expression $C + \sigma^{-1} V^T V$. This PR implements this special case.

Cheers
Aaron :)

```python
    sigma_inv = A_inv.diag_values[0]
    cap_mat = to_dense(C + sigma_inv * V.matmul(U))
else:
    cap_mat = to_dense(C + V.matmul(A_inv.matmul(U)))
```

Note that A_inv here is a DiagLinearOperator. Thus, A_inv.matmul(U) simply scales each row of U by the corresponding entry of A_inv's diagonal; no dense diagonal matrix is formed.
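For illustration (plain PyTorch, hypothetical names): multiplying by a diagonal matrix from the left is equivalent to broadcasted row scaling, which is what a diagonal linear operator's matmul amounts to:

```python
import torch

d = torch.tensor([1.0, 2.0, 3.0])      # diagonal entries of A_inv
U = torch.arange(6.0).reshape(3, 2)

# Dense diagonal matmul vs. broadcasted row scaling
dense_result = torch.diag(d) @ U       # forms a 3 x 3 dense matrix
scaled_result = d.unsqueeze(-1) * U    # scales each row, no dense matrix

assert torch.allclose(dense_result, scaled_result)
```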


kayween commented Jan 13, 2026

@Osburg I don't think chol_cap_mat creates a dense representation of the diagonal matrix $D$; see the inline comments. So I feel like this PR would only bring a negligible speed-up/memory reduction.


Osburg commented Jan 14, 2026

Hi @kayween,
Thanks for the quick answer and also for the help with my previous question!! I just double checked: you are right as long as I am not using a custom linear operator class.

However, if I run operations involving my own SparseLinearOperator class for a very large but sparse tensor $L \in \mathbb{R}^{N \times n}$, $N \gg n$ (which I must represent as a sparse tensor since it otherwise won't fit into memory), I do observe the described behavior. (Maybe I just implemented the SparseLinearOperator badly and that is where the problem lies; in that case this PR can be closed. Sorry, I am still relatively new to GPs.)

The setting is that I want to train a GP with covariance $LKL^\top + \sigma I$, where $L \in \mathbb{R}^{N \times n}$ is a sparse matrix and $K \in \mathbb{R}^{n \times n}$ is the covariance matrix of a spatial kernel.
I can factorize $K$ to end up with $CC^\top + \sigma I$, where $C = LK^{1/2}$ is a MatmulLinearOperator. Then I can make use of the efficient implementations of LowRankRootAddedDiagLinearOperator. Here is a small example that leads to the creation of a large dense diagonal matrix:

```python
from linear_operator import LinearOperator
from linear_operator.operators import LowRankRootLinearOperator
import torch
import gpytorch
from jaxtyping import Float
from typing import Union
from torch import Tensor

torch.set_default_device("cuda:0")


class SparseLinearOperator(LinearOperator):

    def __init__(self, tsr):
        super().__init__(tsr)
        self.tensor = tsr

    def _matmul(
        self: Float[LinearOperator, "*batch M N"],
        rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]],
    ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]:
        return torch.sparse.mm(self.tensor, rhs)

    def _size(self) -> torch.Size:
        return self.tensor.size()

    def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]:
        return SparseLinearOperator(self.tensor.transpose(-2, -1))

    def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]:
        # not actually returning a dense but a sparse tensor, since a dense one would not fit in memory
        return self.tensor


m = 50
M = m * m * m

# create coordinates
grid_bounds = [(-5, 5), (-5, 5), (-5, 5)]
grid_size = m
grid = torch.zeros(grid_size, len(grid_bounds))
for i in range(len(grid_bounds)):
    grid[:, i] = torch.linspace(grid_bounds[i][0], grid_bounds[i][1], grid_size)

x = gpytorch.utils.grid.create_data_from_grid(grid)
y = torch.randn_like(x[:, 0])


class GPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        # note: RBFKernel takes no nu argument (nu belongs to MaternKernel)
        self.covar_module = gpytorch.kernels.GridKernel(gpytorch.kernels.RBFKernel(), grid=grid)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)

        covar = covar_x.root_decomposition().root
        L = SparseLinearOperator(
            torch.sparse_coo_tensor(
                indices=torch.stack([torch.arange(0, M), torch.arange(0, M)], dim=0),
                values=torch.ones(M, dtype=torch.get_default_dtype()),
                size=(M, M), dtype=torch.get_default_dtype()
            )
        )
        covar = L @ covar
        covar_x = LowRankRootLinearOperator(covar)

        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = GPModel(x, y, likelihood)

training_iterations = 2

model.train()
likelihood.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(training_iterations):
    optimizer.zero_grad()
    output = model(x)
    loss = -mll(output, y)
    loss.backward()
    optimizer.step()
```

This exits with the following error:

```
Traceback (most recent call last):
  File ... in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-6-c04c6cc31321>", line 83, in <module>
    loss = -mll(output, y)
            ^^^^^^^^^^^^^^
  File ... in __call__
    outputs = self.forward(*inputs, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ... in forward
    res = output.log_prob(target)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File ... in log_prob
    inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ... in inv_quad_logdet
    self_inv_rhs = self._solve(inv_quad_rhs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^
  File ... in _solve
    chol_cap_mat = self.chol_cap_mat
                   ^^^^^^^^^^^^^^^^^
  File ... in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ... in chol_cap_mat
    cap_mat = to_dense(C + V.matmul(A_inv.matmul(U)))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ... in to_dense
    return obj.to_dense()
           ^^^^^^^^^^^^^^
  File ... in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ... in to_dense
    return (sum(linear_op.to_dense() for linear_op in self.linear_ops)).contiguous()
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ... in <genexpr>
    return (sum(linear_op.to_dense() for linear_op in self.linear_ops)).contiguous()
                ^^^^^^^^^^^^^^^^^^^^
  File ... in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ... in to_dense
    return torch.matmul(self.left_linear_op.to_dense(), self.right_linear_op.to_dense())
                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ... in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ... line 133, in to_dense
    return torch.matmul(self.left_linear_op.to_dense(), self.right_linear_op.to_dense())
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ... in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ... in to_dense
    return torch.diag_embed(self._diag)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ... in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 58.21 GiB. GPU 
```

Cheers
Aaron :)

kayween commented Jan 14, 2026

Your implementation makes sense to me. This is a tricky use case as it involves sparsity. I think the main challenge here is that L is so large that we cannot even form the product $L K^{\frac12}$ explicitly.

For simplicity, let's say we want to invert $I + \big(LK^{\frac12}\big) \big(K^{\frac12} L^\top\big)$. The chol_cap_mat method will do a Cholesky decomposition on the matrix $I + \big(K^{\frac12} L^\top\big) \big(LK^{\frac12}\big)$. In particular, linear_operator will make $LK^{\frac12}$ dense, which seems to trigger the OOM.

I do notice that your sparse linear operator's to_dense returns a sparse tensor. But it does not solve the problem here. Note that $LK^{\frac12}$ is a MatmulLinearOperator. When we call MatmulLinearOperator.to_dense, it's going to call to_dense on $L$ and $K^{\frac12}$ and then multiply them together. It's the multiplication that triggers OOM.

So I think there are two ways going forward.

  1. Hack into chol_cap_mat and regroup the multiplication. In particular, you will need to group the two $L$'s together like this $I + K^{\frac12} \big(L^\top L\big) K^{\frac12}$. I am not sure how well PyTorch supports matmul between two sparse COO tensors though.
  2. The other way to work around this is to use conjugate gradients for GP training/inference. That way, chol_cap_mat should not be called at all. In fact, I am a bit surprised that CG is not invoked in your code, because I thought CG is dispatched automatically for large kernel matrices. Maybe you had some settings that disabled CG?
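To illustrate option 2 (a minimal sketch, not GPyTorch's actual CG implementation; all names below are made up): conjugate gradients only needs matrix-vector products with $\sigma I + CC^\top$, so the large matrix never has to be formed densely:

```python
import torch

def cg(matvec, b, tol=1e-8, max_iter=100):
    # Standard conjugate gradient: solves M x = b using only matvec products,
    # so the (possibly huge) matrix M is never materialized.
    x = torch.zeros_like(b)
    r = b - matvec(x)
    p = r.clone()
    rs = r @ r
    for _ in range(max_iter):
        Mp = matvec(p)
        alpha = rs / (p @ Mp)
        x = x + alpha * p
        r = r - alpha * Mp
        rs_new = r @ r
        if rs_new.sqrt() < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

# Solve (sigma * I + C C^T) x = b for a tall-skinny root factor C,
# using only products with C and C^T (never the dense N x N matrix).
torch.manual_seed(0)
N, n, sigma = 500, 10, 0.1
C = torch.randn(N, n, dtype=torch.float64)
b = torch.randn(N, dtype=torch.float64)
matvec = lambda v: sigma * v + C @ (C.t() @ v)
x = cg(matvec, b)
```

Since $\sigma I + CC^\top$ has at most $n + 1$ distinct eigenvalues, CG converges in roughly $n + 1$ iterations here.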
