Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions botorch/models/kernels/positive_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from gpytorch.constraints import GreaterThan, Interval, Positive
from gpytorch.kernels import IndexKernel, Kernel
from gpytorch.priors import Prior


class PositiveIndexKernel(IndexKernel):
    r"""
    A kernel for discrete indices with strictly positive correlations. This is
    enforced by a positivity constraint on the decomposed covariance matrix.

    Similar to IndexKernel, but the low-rank factor is constrained to be greater
    than zero, so every entry of the resulting covariance matrix is positive.

    .. math::
        K = F F^\top + \mathrm{diag}(v), \qquad
        k(i, j) = \frac{K_{i,j}}{K_{t,t}}

    where F is the ``num_tasks x rank`` factor with positive elements, v is the
    vector of per-task variances, and t is the target_task_index. The division
    by :math:`K_{t,t}` is only applied when ``unit_scale_for_target`` is True.
    """

    def __init__(
        self,
        num_tasks: int,
        rank: int = 1,
        task_prior: Prior | None = None,
        diag_prior: Prior | None = None,
        normalize_covar_matrix: bool = False,
        var_constraint: Interval | None = None,
        target_task_index: int = 0,
        unit_scale_for_target: bool = True,
        **kwargs,
    ):
        r"""A kernel for discrete indices with strictly positive correlations.

        Args:
            num_tasks (int): Total number of indices.
            rank (int): Rank of the covariance matrix parameterization.
            task_prior (Prior, optional): Prior placed on the strictly lower
                triangular entries of the task correlation matrix.
            diag_prior (Prior, optional): Prior for the diagonal elements of
                the covariance matrix.
            normalize_covar_matrix (bool): Stored on the kernel as
                ``normalize_covar_matrix``. It is not consulted when this class
                evaluates its covariance matrix; normalization is controlled by
                ``unit_scale_for_target``.
            var_constraint (Interval, optional): Constraint applied to the raw
                per-task variance parameter. Defaults to ``Positive()``.
            target_task_index (int): Index of the task whose diagonal element
                should be normalized to 1. Defaults to 0 (first task).
            unit_scale_for_target (bool): Whether to ensure the target task has
                unit outputscale.
            **kwargs: Additional arguments passed to ``Kernel.__init__``.

        Raises:
            RuntimeError: If ``rank`` exceeds ``num_tasks``.
            ValueError: If ``target_task_index`` is outside
                ``[0, num_tasks - 1]``.
            TypeError: If ``task_prior`` is given but is not a ``Prior``.
        """
        if rank > num_tasks:
            raise RuntimeError(
                "Cannot create a task covariance matrix larger than the number of tasks"
            )
        if not (0 <= target_task_index < num_tasks):
            raise ValueError(
                f"target_task_index must be between 0 and {num_tasks - 1}, "
                f"got {target_task_index}"
            )
        # NOTE(review): Kernel.__init__ is called directly instead of
        # IndexKernel.__init__, presumably so that the parameters below can be
        # registered with this class's own constraints -- confirm.
        Kernel.__init__(self, **kwargs)

        if var_constraint is None:
            var_constraint = Positive()

        self.normalize_covar_matrix = normalize_covar_matrix
        self.num_tasks = num_tasks
        self.target_task_index = target_task_index
        self.unit_scale_for_target = unit_scale_for_target

        # Per-task variance that is added to the diagonal of the covariance.
        self.register_parameter(
            name="raw_var",
            parameter=torch.nn.Parameter(torch.randn(*self.batch_shape, num_tasks)),
        )
        self.register_constraint("raw_var", var_constraint)

        # Low-rank factor; the GreaterThan(0) constraint keeps every entry of
        # the transformed factor positive.
        self.register_parameter(
            name="raw_covar_factor",
            parameter=torch.nn.Parameter(
                torch.rand(*self.batch_shape, num_tasks, rank)
            ),
        )
        self.register_constraint("raw_covar_factor", GreaterThan(0.0))

        if task_prior is not None:
            if not isinstance(task_prior, Prior):
                raise TypeError(
                    f"Expected gpytorch.priors.Prior but got "
                    f"{type(task_prior).__name__}"
                )
            self.register_prior(
                "IndexKernelPrior", task_prior, lambda m: m._lower_triangle_corr
            )
        if diag_prior is not None:
            self.register_prior("ScalePrior", diag_prior, lambda m: m._diagonal)

    def _covar_factor_params(self, m):
        """Return the constrained covariance factor of module ``m``."""
        return m.covar_factor

    def _covar_factor_closure(self, m, v):
        """Set the constrained covariance factor of module ``m`` to ``v``."""
        m._set_covar_factor(v)

    @property
    def covar_factor(self) -> torch.Tensor:
        """The covariance factor after applying its positivity constraint."""
        return self.raw_covar_factor_constraint.transform(self.raw_covar_factor)

    @covar_factor.setter
    def covar_factor(self, value) -> None:
        self._set_covar_factor(value)

    def _set_covar_factor(self, value) -> None:
        """Initialize the raw factor so that its constrained value is ``value``."""
        # The inverse transform needs a tensor; convert scalars / sequences and
        # match the dtype and device of the existing parameter.
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_covar_factor)
        self.initialize(
            raw_covar_factor=self.raw_covar_factor_constraint.inverse_transform(value)
        )

    @property
    def _lower_triangle_corr(self) -> torch.Tensor:
        """Strictly lower triangular entries of the task correlation matrix."""
        covar = self.covar_matrix
        # Build the index tensor on the same device as the matrix it indexes.
        lower_row, lower_col = torch.tril_indices(
            self.num_tasks, self.num_tasks, offset=-1, device=covar.device
        )
        # Divide each entry by the product of the two corresponding standard
        # deviations to turn covariances into correlations.
        norm_factor = covar.diagonal(dim1=-1, dim2=-2).sqrt()
        corr = covar / (norm_factor.unsqueeze(-1) * norm_factor.unsqueeze(-2))
        return corr[..., lower_row, lower_col]

    @property
    def _diagonal(self) -> torch.Tensor:
        """Diagonal of the (possibly normalized) covariance matrix."""
        return torch.diagonal(self.covar_matrix, dim1=-2, dim2=-1)

    def _eval_covar_matrix(self) -> torch.Tensor:
        """Compute the task covariance matrix from the current parameters."""
        cf = self.covar_factor
        # diag_embed creates one diagonal matrix per batch element. Multiplying
        # ``var`` by an identity matrix instead would broadcast the batch
        # dimension of ``var`` against the rows of the identity, which is wrong
        # (or fails outright) for batched kernels.
        covar = cf @ cf.transpose(-1, -2) + torch.diag_embed(self.var)
        # Normalize by the target task's diagonal element
        if self.unit_scale_for_target:
            target = self.target_task_index
            norm_factor = covar[..., target, target]
            covar = covar / norm_factor.unsqueeze(-1).unsqueeze(-1)
        return covar

    @property
    def covar_matrix(self) -> torch.Tensor:
        """The task covariance matrix, recomputed on every access."""
        return self._eval_covar_matrix()
3 changes: 3 additions & 0 deletions sphinx/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ Kernels
.. automodule:: botorch.models.kernels.orthogonal_additive_kernel
.. autoclass:: OrthogonalAdditiveKernel

.. automodule:: botorch.models.kernels.positive_index
.. autoclass:: PositiveIndexKernel

Likelihoods
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.models.likelihoods.pairwise
Expand Down
224 changes: 224 additions & 0 deletions test/models/kernels/test_positive_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from botorch.models.kernels.positive_index import PositiveIndexKernel
from botorch.utils.testing import BotorchTestCase
from gpytorch.priors import NormalPrior


class TestPositiveIndexKernel(BotorchTestCase):
    """Checks construction, validation and evaluation of PositiveIndexKernel."""

    @staticmethod
    def _make_kernel(dtype, **kernel_kwargs):
        # Build a kernel and cast its parameters to the dtype under test.
        return PositiveIndexKernel(**kernel_kwargs).to(dtype=dtype)

    def test_positive_index_kernel(self):
        """Run every sub-check once for float32 and once for float64."""
        for dtype in (torch.float32, torch.float64):
            one = torch.tensor(1.0, dtype=dtype)

            # Construction without a batch shape
            with self.subTest("basic_initialization", dtype=dtype):
                n_tasks, n_rank = 4, 2
                kern = self._make_kernel(dtype, num_tasks=n_tasks, rank=n_rank)

                self.assertEqual(kern.num_tasks, n_tasks)
                self.assertEqual(kern.raw_covar_factor.shape, (n_tasks, n_rank))
                self.assertEqual(kern.normalize_covar_matrix, False)

            # Construction with a batch shape
            with self.subTest("initialization_with_batch_shape", dtype=dtype):
                n_tasks, n_rank = 3, 2
                kern = self._make_kernel(
                    dtype,
                    num_tasks=n_tasks,
                    rank=n_rank,
                    batch_shape=torch.Size([2]),
                )

                self.assertEqual(kern.raw_covar_factor.shape, (2, n_tasks, n_rank))

            # A rank larger than the number of tasks is rejected
            with self.subTest("rank_validation", dtype=dtype):
                with self.assertRaises(RuntimeError):
                    PositiveIndexKernel(num_tasks=3, rank=5)

            # The target index has to address an existing task
            with self.subTest("target_task_index_validation", dtype=dtype):
                n_tasks = 4
                # A negative index and an index >= num_tasks are both invalid
                for bad_index in (-1, n_tasks):
                    with self.assertRaises(ValueError):
                        PositiveIndexKernel(
                            num_tasks=n_tasks, rank=2, target_task_index=bad_index
                        )
                # The first and the last task are valid targets (must not raise)
                for good_index in (0, n_tasks - 1):
                    PositiveIndexKernel(
                        num_tasks=n_tasks, rank=2, target_task_index=good_index
                    )

            # The constrained factor and the covariance have no negative entries
            with self.subTest("positive_correlations", dtype=dtype):
                kern = self._make_kernel(dtype, num_tasks=5, rank=3)
                factor = kern.covar_factor

                self.assertTrue((factor > 0).all())

                self.assertTrue((kern.covar_matrix >= 0).all())

            # With the default target (index 0) the first variance is one
            with self.subTest("covar_matrix_normalization_default", dtype=dtype):
                kern = self._make_kernel(dtype, num_tasks=4, rank=2)
                cov = kern.covar_matrix

                self.assertAllClose(cov[0, 0], one, atol=1e-4)

            # With target_task_index=2 the third variance is one
            with self.subTest("covar_matrix_normalization_custom_target", dtype=dtype):
                kern = self._make_kernel(
                    dtype, num_tasks=4, rank=2, target_task_index=2
                )
                cov = kern.covar_matrix

                self.assertAllClose(cov[2, 2], one, atol=1e-4)

            # Shape and value of a forward call
            with self.subTest("forward", dtype=dtype):
                kern = self._make_kernel(dtype, num_tasks=4, rank=2)

                left = torch.tensor([[0, 1], [2, 3]], dtype=torch.long)
                right = torch.tensor([[1, 2]], dtype=torch.long)

                out = kern(left, right)
                self.assertEqual(out.shape, torch.Size([2, 1]))

                n_tasks = 3
                kern = self._make_kernel(dtype, num_tasks=n_tasks, rank=1)
                kern.initialize(raw_covar_factor=torch.ones(n_tasks, 1, dtype=dtype))

                left = torch.tensor([[0]], dtype=torch.long)
                right = torch.tensor([[1]], dtype=torch.long)

                dense = kern(left, right).to_dense()
                # The kernel value for tasks (0, 1) is that covariance entry
                reference = kern.covar_matrix[0, 1]

                self.assertAllClose(dense.squeeze(), reference)

            # Both priors are registered under their expected names
            with self.subTest("with_priors", dtype=dtype):
                kern = self._make_kernel(
                    dtype,
                    num_tasks=4,
                    rank=2,
                    task_prior=NormalPrior(0, 1),
                    diag_prior=NormalPrior(1, 0.1),
                    initialize_to_mode=False,
                )
                registered = [name for name, *_ in kern.named_priors()]
                self.assertIn("IndexKernelPrior", registered)
                self.assertIn("ScalePrior", registered)

            # A batched kernel keeps its batch dimension in the output shape
            with self.subTest("batch_forward", dtype=dtype):
                kern = self._make_kernel(
                    dtype, num_tasks=3, rank=2, batch_shape=torch.Size([2])
                )

                left = torch.tensor([[[0], [1]]], dtype=torch.long)
                right = torch.tensor([[[1], [2]]], dtype=torch.long)

                out = kern(left, right)

                self.assertEqual(out.shape[0], 2)

            # The diagonal has one entry per task and is one at the target index
            with self.subTest("diagonal", dtype=dtype):
                for target in (0, 1):
                    kern = self._make_kernel(
                        dtype, num_tasks=4, rank=2, target_task_index=target
                    )
                    diagonal = kern._diagonal

                    self.assertEqual(diagonal.shape, torch.Size([4]))
                    self.assertAllClose(diagonal[target], one, atol=1e-4)

            # The strictly lower triangle has n * (n - 1) / 2 non-negative entries
            with self.subTest("lower_triangle", dtype=dtype):
                n_tasks = 5
                kern = self._make_kernel(dtype, num_tasks=n_tasks, rank=2)
                tril_values = kern._lower_triangle_corr

                n_expected = n_tasks * (n_tasks - 1) // 2
                self.assertEqual(tril_values.shape[-1], n_expected)
                self.assertTrue((tril_values >= 0).all())

            # A task prior that is not a Prior instance is rejected
            with self.subTest("invalid_prior_type", dtype=dtype):
                with self.assertRaises(TypeError):
                    PositiveIndexKernel(num_tasks=4, rank=2, task_prior="not_a_prior")

            # The covariance is square, positive definite and symmetric
            with self.subTest("covar_matrix", dtype=dtype):
                kern = self._make_kernel(dtype, num_tasks=5, rank=4)
                cov = kern.covar_matrix

                self.assertEqual(cov.shape[-2], cov.shape[-1])

                spectrum = torch.linalg.eigvalsh(cov)
                self.assertTrue((spectrum > 0).all())

                self.assertAllClose(cov, cov.T, atol=1e-5)

            # Property setter/getter and the two closure helpers
            with self.subTest("covar_factor", dtype=dtype):
                kern = self._make_kernel(dtype, num_tasks=3, rank=2)
                twos = torch.ones(3, 2, dtype=dtype) * 2.0
                kern.covar_factor = twos
                self.assertAllClose(kern.covar_factor, twos, atol=1e-5)

                kern = self._make_kernel(dtype, num_tasks=3, rank=2)
                fetched = kern._covar_factor_params(kern)
                self.assertEqual(fetched.shape, torch.Size([3, 2]))
                self.assertTrue((fetched > 0).all())

                kern = self._make_kernel(dtype, num_tasks=3, rank=2)
                threes = torch.ones(3, 2, dtype=dtype) * 3.0
                kern._covar_factor_closure(kern, threes)
                self.assertAllClose(kern.covar_factor, threes, atol=1e-5)