Skip to content

Commit 0947252

Browse files
Carl Hvarfner authored and facebook-github-bot committed
PositiveIndexKernel (#3047)
Summary: PositiveIndexKernel - a MultiTaskGP kernel that enforces positive correlation. Should probably be upstreamed into GPyTorch at some point. Also introduces separate priors on the diagonal and the off-diagonal elements, so that priors can be set on task correlation in a more intuitive fashion. Differential Revision: D84878629
1 parent 1625a23 commit 0947252

File tree

3 files changed

+351
-0
lines changed

3 files changed

+351
-0
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from gpytorch.constraints import GreaterThan, Positive
9+
from gpytorch.kernels import IndexKernel, Kernel
10+
from gpytorch.priors import Prior
11+
12+
13+
class PositiveIndexKernel(IndexKernel):
    r"""
    A kernel for discrete indices with strictly positive correlations.

    Similar to IndexKernel, but every entry of the low-rank factor is
    constrained to be greater than zero, so all off-diagonal entries of the
    task covariance are positive.

    .. math::
        K = FF^T + \mathrm{diag}(v), \qquad k(i, j) = \frac{K_{i,j}}{K_{t,t}}

    where F is the ``num_tasks x rank`` factor with entries greater than zero,
    v is the per-task variance and t is the ``target_task_index``. The division
    by :math:`K_{t,t}` is only applied when ``unit_scale_for_target`` is True.
    """

    def __init__(
        self,
        num_tasks: int,
        rank: int = 1,
        task_prior: Prior | None = None,
        diag_prior: Prior | None = None,
        normalize_covar_matrix: bool = False,
        target_task_index: int = 0,
        unit_scale_for_target: bool = True,
        var_constraint=None,
        **kwargs,
    ):
        r"""A kernel for discrete indices with strictly positive correlations.

        Args:
            num_tasks (int): Total number of indices.
            rank (int): Rank of the covariance matrix parameterization.
            task_prior (Prior, optional): Prior on the off-diagonal (strictly
                lower triangular) entries of the task correlation matrix.
            diag_prior (Prior, optional): Prior for the diagonal elements of the
                task covariance matrix.
            normalize_covar_matrix (bool): Stored on the instance as
                ``normalize_covar_matrix``; it is not consulted by this class's
                ``_eval_covar_matrix``.
            target_task_index (int): Index of the task whose diagonal element
                should be normalized to 1. Defaults to 0 (first task).
            unit_scale_for_target (bool): Whether to rescale the covariance so
                that the target task has unit outputscale.
            var_constraint (optional): GPyTorch constraint applied to the
                per-task variance ``raw_var``. Defaults to ``Positive()``.
            **kwargs: Additional arguments passed to ``Kernel.__init__``
                (e.g. ``batch_shape``).

        Raises:
            RuntimeError: If ``rank`` exceeds ``num_tasks``.
            ValueError: If ``target_task_index`` is outside
                ``[0, num_tasks - 1]``.
            TypeError: If ``task_prior`` or ``diag_prior`` is given but is not a
                ``gpytorch.priors.Prior``.
        """
        if rank > num_tasks:
            raise RuntimeError(
                "Cannot create a task covariance matrix larger than the number of tasks"
            )
        if not (0 <= target_task_index < num_tasks):
            raise ValueError(
                f"target_task_index must be between 0 and {num_tasks - 1}, "
                f"got {target_task_index}"
            )
        # Validate both priors up front so a bad argument fails before any
        # module state is created.
        for prior in (task_prior, diag_prior):
            if prior is not None and not isinstance(prior, Prior):
                raise TypeError(
                    f"Expected gpytorch.priors.Prior but got "
                    f"{type(prior).__name__}"
                )

        # Initialize through the base Kernel rather than IndexKernel so that
        # every parameter of this kernel is registered explicitly below. This
        # is an unbound call, so ``self`` has to be passed explicitly.
        Kernel.__init__(self, **kwargs)

        self.num_tasks = num_tasks
        self.target_task_index = target_task_index
        self.normalize_covar_matrix = normalize_covar_matrix
        self.unit_scale_for_target = unit_scale_for_target

        # Per-task variance added to the diagonal of the task covariance. The
        # parameter must exist before a constraint can be attached to it.
        self.register_parameter(
            name="raw_var",
            parameter=torch.nn.Parameter(torch.randn(*self.batch_shape, num_tasks)),
        )
        self.register_constraint(
            "raw_var", Positive() if var_constraint is None else var_constraint
        )

        # Low-rank factor; the GreaterThan(0) constraint is what keeps all
        # task covariances positive.
        self.register_parameter(
            name="raw_covar_factor",
            parameter=torch.nn.Parameter(
                torch.rand(*self.batch_shape, num_tasks, rank)
            ),
        )
        self.register_constraint("raw_covar_factor", GreaterThan(0.0))

        if task_prior is not None:
            self.register_prior(
                "IndexKernelPrior", task_prior, lambda m: m._lower_triangle
            )
        if diag_prior is not None:
            self.register_prior("ScalePrior", diag_prior, lambda m: m._diagonal)

    def _covar_factor_params(self, m):
        """Return the constrained covariance factor of module ``m``."""
        return m.covar_factor

    def _covar_factor_closure(self, m, v):
        """Set the constrained covariance factor of module ``m`` to ``v``."""
        m._set_covar_factor(v)

    @property
    def covar_factor(self):
        """The covariance factor after applying its constraint transform."""
        return self.raw_covar_factor_constraint.transform(self.raw_covar_factor)

    @covar_factor.setter
    def covar_factor(self, value):
        self._set_covar_factor(value)

    def _set_covar_factor(self, value):
        """Store ``value`` as the constrained factor via the inverse transform."""
        # This must be a tensor
        self.initialize(
            raw_covar_factor=self.raw_covar_factor_constraint.inverse_transform(value)
        )

    @property
    def _lower_triangle(self):
        """Strictly lower-triangular entries of the task correlation matrix.

        The covariance is converted to a correlation matrix by dividing by the
        outer product of the per-task standard deviations; the entries below
        the diagonal are returned flattened along the last dimension.
        """
        covar = self.covar_matrix
        # Build the indices on the covariance's device so no cross-device
        # indexing is needed.
        lower_row, lower_col = torch.tril_indices(
            self.num_tasks, self.num_tasks, offset=-1, device=covar.device
        )
        norm_factor = covar.diagonal(dim1=-1, dim2=-2).sqrt()
        corr = covar / (norm_factor.unsqueeze(-1) * norm_factor.unsqueeze(-2))
        return corr[..., lower_row, lower_col]

    @property
    def _diagonal(self):
        """Diagonal of the task covariance matrix."""
        return torch.diagonal(self.covar_matrix, dim1=-2, dim2=-1)

    def _eval_covar_matrix(self):
        """Build the task covariance ``F F^T + diag(var)``.

        When ``unit_scale_for_target`` is True, the result is divided by the
        diagonal entry of the target task.
        """
        cf = self.covar_factor
        # ``self.var`` is inherited from IndexKernel (assumed to apply
        # ``raw_var_constraint`` to ``raw_var``). diag_embed places it on the
        # diagonal and keeps any leading batch dimensions, which a plain
        # ``var * eye(num_tasks)`` product would fail to broadcast.
        covar = cf @ cf.transpose(-1, -2) + torch.diag_embed(self.var)
        # Normalize by the target task's diagonal element
        if self.unit_scale_for_target:
            norm_factor = covar[..., self.target_task_index, self.target_task_index]
            covar = covar / norm_factor.unsqueeze(-1).unsqueeze(-1)
        return covar

    @property
    def covar_matrix(self):
        """The task covariance matrix, re-evaluated on every access."""
        return self._eval_covar_matrix()

sphinx/source/models.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ Kernels
146146
.. automodule:: botorch.models.kernels.orthogonal_additive_kernel
147147
.. autoclass:: OrthogonalAdditiveKernel
148148

149+
.. automodule:: botorch.models.kernels.positive_index
150+
.. autoclass:: PositiveIndexKernel
151+
149152
Likelihoods
150153
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
151154
.. automodule:: botorch.models.likelihoods.pairwise
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from botorch.models.kernels.positive_index import PositiveIndexKernel
9+
from botorch.utils.testing import BotorchTestCase
10+
from gpytorch.priors import NormalPrior
11+
12+
13+
class TestPositiveIndexKernel(BotorchTestCase):
    """Unit tests for ``PositiveIndexKernel``."""

    def test_positive_index_kernel(self):
        """Exercise construction, validation, covariance and accessors."""
        # Construction with default options.
        with self.subTest("basic_initialization"):
            n_tasks, n_rank = 4, 2
            k = PositiveIndexKernel(num_tasks=n_tasks, rank=n_rank)

            self.assertEqual(k.num_tasks, n_tasks)
            self.assertEqual(k.raw_covar_factor.shape, (n_tasks, n_rank))
            self.assertEqual(k.normalize_covar_matrix, False)

        # Construction with a batch shape.
        with self.subTest("initialization_with_batch_shape"):
            n_tasks, n_rank = 3, 2
            k = PositiveIndexKernel(
                num_tasks=n_tasks, rank=n_rank, batch_shape=torch.Size([2])
            )

            self.assertEqual(k.raw_covar_factor.shape, (2, n_tasks, n_rank))

        # A rank larger than the number of tasks is rejected.
        with self.subTest("rank_validation"):
            with self.assertRaises(RuntimeError):
                PositiveIndexKernel(num_tasks=3, rank=5)

        # target_task_index must lie in [0, num_tasks - 1].
        with self.subTest("target_task_index_validation"):
            n_tasks = 4
            # Negative index first, then an index equal to num_tasks.
            for bad_index in (-1, 4):
                with self.assertRaises(ValueError):
                    PositiveIndexKernel(
                        num_tasks=n_tasks, rank=2, target_task_index=bad_index
                    )
            # Both boundary indices are valid and must not raise.
            for good_index in (0, 3):
                PositiveIndexKernel(
                    num_tasks=n_tasks, rank=2, target_task_index=good_index
                )

        # The constrained factor and the resulting covariance are positive.
        with self.subTest("positive_correlations"):
            k = PositiveIndexKernel(num_tasks=5, rank=3)
            factor = k.covar_factor

            self.assertTrue(torch.all(factor > 0))
            self.assertTrue(torch.all(k.covar_matrix >= 0))

        # Default target (index 0) is normalized to unit scale.
        with self.subTest("covar_matrix_normalization_default"):
            k = PositiveIndexKernel(num_tasks=4, rank=2)
            task_covar = k.covar_matrix

            self.assertAllClose(task_covar[0, 0], torch.tensor(1.0), atol=1e-4)

        # A custom target index is the one normalized to unit scale.
        with self.subTest("covar_matrix_normalization_custom_target"):
            k = PositiveIndexKernel(num_tasks=4, rank=2, target_task_index=2)
            task_covar = k.covar_matrix

            self.assertAllClose(task_covar[2, 2], torch.tensor(1.0), atol=1e-4)
            # Non-target diagonal entries are not expected to equal 1.0.
            self.assertNotEqual(task_covar[0, 0].item(), 1.0)

        # Forward pass: output shape, then agreement with covar_matrix.
        with self.subTest("forward"):
            n_tasks = 4
            k = PositiveIndexKernel(num_tasks=n_tasks, rank=2)
            k.eval()

            left = torch.tensor([[0, 1], [2, 3]], dtype=torch.long)
            right = torch.tensor([[1, 2]], dtype=torch.long)

            out = k(left, right)
            self.assertEqual(out.shape, torch.Size([2, 1]))

            n_tasks = 3
            k = PositiveIndexKernel(num_tasks=n_tasks, rank=1)
            k.eval()

            k.initialize(raw_covar_factor=torch.ones(n_tasks, 1))
            left = torch.tensor([[0]], dtype=torch.long)
            right = torch.tensor([[1]], dtype=torch.long)

            out = k(left, right).to_dense()
            want = k.covar_matrix[0, 1]

            self.assertAllClose(out.squeeze(), want)

        # Both priors are registered under their expected names.
        with self.subTest("with_priors"):
            k = PositiveIndexKernel(
                num_tasks=4,
                rank=2,
                task_prior=NormalPrior(0, 1),
                diag_prior=NormalPrior(1, 0.1),
                initialize_to_mode=False,
            )
            registered = [entry[0] for entry in k.named_priors()]
            for expected_name in ("IndexKernelPrior", "ScalePrior"):
                self.assertIn(expected_name, registered)

        # Batched kernels keep their batch dimension in the output.
        with self.subTest("batch_forward"):
            k = PositiveIndexKernel(
                num_tasks=3, rank=2, batch_shape=torch.Size([2])
            )
            k.eval()

            left = torch.tensor([[[0], [1]]], dtype=torch.long)
            right = torch.tensor([[[1], [2]]], dtype=torch.long)

            out = k(left, right)

            self.assertEqual(out.shape[0], 2)

        # _diagonal: default target first, then a custom target index.
        with self.subTest("diagonal"):
            for extra_kwargs, unit_index in (({}, 0), ({"target_task_index": 1}, 1)):
                k = PositiveIndexKernel(num_tasks=4, rank=2, **extra_kwargs)
                diag_vals = k._diagonal

                self.assertEqual(diag_vals.shape, torch.Size([4]))
                # The target task's diagonal entry should be 1.0.
                self.assertAllClose(
                    diag_vals[unit_index], torch.tensor(1.0), atol=1e-4
                )

        # _lower_triangle: size and sign of the off-diagonal correlations.
        with self.subTest("lower_triangle"):
            n_tasks = 5
            k = PositiveIndexKernel(num_tasks=n_tasks, rank=2)
            off_diag = k._lower_triangle

            # n * (n - 1) / 2 entries lie strictly below the diagonal.
            self.assertEqual(off_diag.shape[-1], n_tasks * (n_tasks - 1) // 2)
            self.assertTrue(torch.all(off_diag >= 0))

        # A task_prior that is not a Prior is rejected.
        with self.subTest("invalid_prior_type"):
            with self.assertRaises(TypeError):
                PositiveIndexKernel(num_tasks=4, rank=2, task_prior="not_a_prior")

        # Structural properties of the covariance matrix.
        with self.subTest("covar_matrix"):
            k = PositiveIndexKernel(num_tasks=5, rank=4)
            task_covar = k.covar_matrix

            # Square.
            self.assertEqual(task_covar.shape[-2], task_covar.shape[-1])

            # Positive definite: every eigenvalue is strictly positive.
            spectrum = torch.linalg.eigvalsh(task_covar)
            self.assertTrue(torch.all(spectrum > 0))

            # Symmetric.
            self.assertAllClose(task_covar, task_covar.transpose(-2, -1), atol=1e-5)

        # covar_factor property, plus the params/closure helpers.
        with self.subTest("covar_factor"):
            k = PositiveIndexKernel(num_tasks=3, rank=2)
            assigned = 2.0 * torch.ones(3, 2)
            k.covar_factor = assigned
            self.assertAllClose(k.covar_factor, assigned, atol=1e-5)

            k = PositiveIndexKernel(num_tasks=3, rank=2)
            factor = k._covar_factor_params(k)
            self.assertEqual(factor.shape, torch.Size([3, 2]))
            self.assertTrue(torch.all(factor > 0))

            k = PositiveIndexKernel(num_tasks=3, rank=2)
            assigned = 3.0 * torch.ones(3, 2)
            k._covar_factor_closure(k, assigned)
            self.assertAllClose(k.covar_factor, assigned, atol=1e-5)

0 commit comments

Comments
 (0)