Skip to content

Commit 7dff851

Browse files
Carl Hvarfner and facebook-github-bot
authored and committed
PositiveIndexKernel (#3047)
Summary: PositiveIndexKernel - a MultiTaskGP kernel that enforces positive correlation. Should probably be upstreamed into GPyTorch at some point. Also introduces priors on diagonal and off-diagonals separately, so that priors can be set on task correlation in a more intuitive fashion. Differential Revision: D84878629
1 parent 09502f9 commit 7dff851

File tree

2 files changed

+340
-0
lines changed

2 files changed

+340
-0
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Optional
8+
9+
import torch
10+
from gpytorch.constraints import GreaterThan
11+
from gpytorch.kernels import IndexKernel
12+
from gpytorch.priors import Prior
13+
14+
15+
class PositiveIndexKernel(IndexKernel):
    r"""
    A kernel for discrete indices with strictly positive correlations.

    Similar to IndexKernel, but the covariance factor is constrained to be
    elementwise positive, so all off-diagonal entries of the resulting task
    covariance matrix are positive.

    .. math::
        K = FF^T + \mathrm{diag}(v), \qquad
        k(i, j) = \frac{K_{i,j}}{K_{t,t}}

    where F is the ``num_tasks x rank`` covariance factor with positive
    elements, v is the per-task variance (``self.var``) and t is the
    target_task_index. The division by :math:`K_{t,t}` is only applied when
    ``unit_scale_for_target`` is True.

    Args:
        num_tasks (int): Total number of indices.
        rank (int): Rank of the covariance matrix parameterization. Must not
            exceed ``num_tasks``.
        task_prior (Prior, optional): Prior placed on the off-diagonal task
            correlations (strict lower triangle of the correlation matrix).
        diag_prior (Prior, optional): Prior placed on the diagonal elements of
            the covariance matrix.
        normalize_covar_matrix (bool): Stored on the instance; it is not
            consulted when the covariance matrix is evaluated in this class.
        target_task_index (int): Index of the task whose diagonal element should
            be normalized to 1. Defaults to 0 (first task).
        unit_scale_for_target (bool): Whether to ensure the target task has unit
            outputscale.
        **kwargs: Additional arguments passed to IndexKernel. ``var_constraint``
            is always forwarded as ``None`` and therefore must not be supplied
            here.

    Raises:
        RuntimeError: If ``rank`` is larger than ``num_tasks``.
        ValueError: If ``target_task_index`` is outside ``[0, num_tasks - 1]``.
        TypeError: If ``task_prior`` is given but is not a gpytorch Prior.
    """

    def __init__(
        self,
        num_tasks: int,
        rank: Optional[int] = 1,
        task_prior: Optional[Prior] = None,
        diag_prior: Optional[Prior] = None,
        normalize_covar_matrix: bool = False,
        target_task_index: int = 0,
        unit_scale_for_target: bool = True,
        **kwargs,
    ):
        if rank > num_tasks:
            raise RuntimeError(
                "Cannot create a task covariance matrix larger than the number of tasks"
            )
        if not (0 <= target_task_index < num_tasks):
            raise ValueError(
                f"target_task_index must be between 0 and {num_tasks - 1}, "
                f"got {target_task_index}"
            )
        super().__init__(
            num_tasks=num_tasks,
            rank=rank,
            prior=task_prior,
            var_constraint=None,
            **kwargs,
        )
        self.normalize_covar_matrix = normalize_covar_matrix
        self.num_tasks = num_tasks
        self.target_task_index = target_task_index
        # Uniform [0, 1) initial values for the raw (pre-constraint) factor.
        self.register_parameter(
            name="raw_covar_factor",
            parameter=torch.nn.Parameter(
                torch.rand(*self.batch_shape, num_tasks, rank)
            ),
        )
        self.unit_scale_for_target = unit_scale_for_target
        if task_prior is not None:
            if not isinstance(task_prior, Prior):
                raise TypeError(
                    f"Expected gpytorch.priors.Prior but got "
                    f"{type(task_prior).__name__}"
                )
            # The task prior acts on the off-diagonal correlations only.
            self.register_prior(
                "IndexKernelPrior", task_prior, lambda m: m._lower_triangle
            )
        if diag_prior is not None:
            # The diagonal prior acts on the (possibly normalized) variances.
            self.register_prior("ScalePrior", diag_prior, lambda m: m._diagonal)

        # Keep every element of the covariance factor above zero.
        self.register_constraint("raw_covar_factor", GreaterThan(0.0))

    def _covar_factor_params(self, m):
        """Return the constrained covariance factor of module ``m``."""
        return m.covar_factor

    def _covar_factor_closure(self, m, v):
        """Set the constrained covariance factor of module ``m`` to ``v``."""
        m._set_covar_factor(v)

    @property
    def covar_factor(self):
        """The covariance factor after applying the positivity constraint."""
        return self.raw_covar_factor_constraint.transform(self.raw_covar_factor)

    @covar_factor.setter
    def covar_factor(self, value):
        self._set_covar_factor(value)

    def _set_covar_factor(self, value):
        """Initialize the raw parameter so the constrained factor equals ``value``."""
        # This must be a tensor
        self.initialize(
            raw_covar_factor=self.raw_covar_factor_constraint.inverse_transform(value)
        )

    @property
    def _lower_triangle(self):
        """Strictly-lower-triangular entries of the task correlation matrix."""
        covar = self.covar_matrix
        # Build the indices on the covariance's device so no cross-device
        # indexing is needed.
        lower_row, lower_col = torch.tril_indices(
            self.num_tasks, self.num_tasks, offset=-1, device=covar.device
        )
        # Convert covariance to correlation: divide by sqrt(K_ii * K_jj).
        norm_factor = covar.diagonal(dim1=-1, dim2=-2).sqrt()
        corr = covar / (norm_factor.unsqueeze(-1) * norm_factor.unsqueeze(-2))
        return corr[..., lower_row, lower_col]

    @property
    def _diagonal(self):
        """Diagonal of the task covariance matrix."""
        return torch.diagonal(self.covar_matrix, dim1=-2, dim2=-1)

    def _eval_covar_matrix(self):
        """Compute ``F F^T + diag(var)``, optionally scaled by the target entry."""
        cf = self.covar_factor
        # diag_embed places ``var`` on the diagonal while keeping leading batch
        # dimensions intact; multiplying by an (n, n) identity mis-broadcasts
        # for batched kernels.
        # NOTE(review): assumes ``self.var`` has shape (*batch_shape, num_tasks),
        # as provided by the parent IndexKernel -- confirm.
        covar = cf @ cf.transpose(-1, -2) + torch.diag_embed(self.var)
        # Normalize by the target task's diagonal element
        if self.unit_scale_for_target:
            norm_factor = covar[..., self.target_task_index, self.target_task_index]
            covar = covar / norm_factor.unsqueeze(-1).unsqueeze(-1)
        return covar

    @property
    def covar_matrix(self):
        """The task covariance matrix, recomputed on each access."""
        return self._eval_covar_matrix()
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from botorch.models.kernels.positive_index import PositiveIndexKernel
9+
from botorch.utils.testing import BotorchTestCase
10+
from gpytorch.priors import NormalPrior
11+
12+
13+
class TestPositiveIndexKernel(BotorchTestCase):
    """Unit tests for ``PositiveIndexKernel``."""

    def test_positive_index_kernel(self):
        """Walk through the kernel's features, one sub-test per feature."""

        # Construction without a batch shape.
        with self.subTest("basic_initialization"):
            n_tasks, n_rank = 4, 2
            kern = PositiveIndexKernel(num_tasks=n_tasks, rank=n_rank)

            self.assertEqual(kern.num_tasks, n_tasks)
            self.assertEqual(kern.raw_covar_factor.shape, (n_tasks, n_rank))
            self.assertEqual(kern.normalize_covar_matrix, False)

        # Construction with a batch shape prepends the batch dimension.
        with self.subTest("initialization_with_batch_shape"):
            n_tasks, n_rank = 3, 2
            kern = PositiveIndexKernel(
                num_tasks=n_tasks, rank=n_rank, batch_shape=torch.Size([2])
            )

            self.assertEqual(kern.raw_covar_factor.shape, (2, n_tasks, n_rank))

        # A rank above the number of tasks is rejected.
        with self.subTest("rank_validation"):
            with self.assertRaises(RuntimeError):
                PositiveIndexKernel(num_tasks=3, rank=5)

        # The target index must lie in [0, num_tasks - 1].
        with self.subTest("target_task_index_validation"):
            n_tasks = 4
            # Out-of-range indices: negative, and equal to num_tasks.
            for bad_index in (-1, 4):
                with self.assertRaises(ValueError):
                    PositiveIndexKernel(
                        num_tasks=n_tasks, rank=2, target_task_index=bad_index
                    )
            # Boundary indices are accepted without raising.
            for good_index in (0, 3):
                PositiveIndexKernel(
                    num_tasks=n_tasks, rank=2, target_task_index=good_index
                )

        # The constrained factor and the covariance are non-negative.
        with self.subTest("positive_correlations"):
            kern = PositiveIndexKernel(num_tasks=5, rank=3)
            factor = kern.covar_factor

            # Every factor element is strictly positive.
            self.assertTrue(torch.all(factor > 0))

            self.assertTrue(torch.all(kern.covar_matrix >= 0))

        # Default normalization targets task 0.
        with self.subTest("covar_matrix_normalization_default"):
            kern = PositiveIndexKernel(num_tasks=4, rank=2)
            task_covar = kern.covar_matrix

            # The first diagonal entry is normalized to one.
            self.assertAllClose(task_covar[0, 0], torch.tensor(1.0), atol=1e-4)

        # Normalization follows a custom target index.
        with self.subTest("covar_matrix_normalization_custom_target"):
            kern = PositiveIndexKernel(num_tasks=4, rank=2, target_task_index=2)
            task_covar = kern.covar_matrix

            # The third diagonal entry is normalized to one ...
            self.assertAllClose(task_covar[2, 2], torch.tensor(1.0), atol=1e-4)

            # ... while a non-target diagonal entry is not.
            self.assertNotEqual(task_covar[0, 0].item(), 1.0)

        # Forward pass: output shape and values.
        with self.subTest("forward"):
            n_tasks = 4
            kern = PositiveIndexKernel(num_tasks=n_tasks, rank=2)
            kern.eval()

            # Index inputs.
            idx_a = torch.tensor([[0, 1], [2, 3]], dtype=torch.long)
            idx_b = torch.tensor([[1, 2]], dtype=torch.long)

            out = kern(idx_a, idx_b)

            # The kernel output is asserted to have shape (2, 1).
            self.assertEqual(out.shape, torch.Size([2, 1]))

            n_tasks = 3
            kern = PositiveIndexKernel(num_tasks=n_tasks, rank=1)
            kern.eval()

            # Fix the raw factor so the output is predictable.
            kern.initialize(raw_covar_factor=torch.ones(n_tasks, 1))

            idx_a = torch.tensor([[0]], dtype=torch.long)
            idx_b = torch.tensor([[1]], dtype=torch.long)

            dense_out = kern(idx_a, idx_b).to_dense()
            want = kern.covar_matrix[0, 1]

            self.assertAllClose(dense_out.squeeze(), want)

        # Both priors end up registered under their expected names.
        with self.subTest("with_priors"):
            n_tasks = 4
            corr_prior = NormalPrior(0, 1)
            scale_prior = NormalPrior(1, 0.1)

            kern = PositiveIndexKernel(
                num_tasks=n_tasks,
                rank=2,
                task_prior=corr_prior,
                diag_prior=scale_prior,
                initialize_to_mode=False,
            )

            registered = [entry[0] for entry in kern.named_priors()]
            for expected_name in ("IndexKernelPrior", "ScalePrior"):
                self.assertIn(expected_name, registered)

        # Forward pass with a batched kernel.
        with self.subTest("batch_forward"):
            n_tasks = 3
            kern = PositiveIndexKernel(
                num_tasks=n_tasks, rank=2, batch_shape=torch.Size([2])
            )
            kern.eval()

            idx_a = torch.tensor([[[0], [1]]], dtype=torch.long)
            idx_b = torch.tensor([[[1], [2]]], dtype=torch.long)

            out = kern(idx_a, idx_b)

            # The leading batch dimension survives.
            self.assertEqual(out.shape[0], 2)

        # The diagonal property, for the default and a custom target index.
        with self.subTest("diagonal"):
            for extra_kwargs, target in (({}, 0), ({"target_task_index": 1}, 1)):
                kern = PositiveIndexKernel(num_tasks=4, rank=2, **extra_kwargs)
                diag_vals = kern._diagonal

                self.assertEqual(diag_vals.shape, torch.Size([4]))
                # The target task's diagonal entry is normalized to one.
                self.assertAllClose(
                    diag_vals[target], torch.tensor(1.0), atol=1e-4
                )

        # The strict lower triangle of the correlation matrix.
        with self.subTest("lower_triangle"):
            n_tasks = 5
            kern = PositiveIndexKernel(num_tasks=n_tasks, rank=2)
            tri_vals = kern._lower_triangle

            # Count of entries strictly below the diagonal.
            n_below_diag = n_tasks * (n_tasks - 1) // 2
            self.assertEqual(tri_vals.shape[-1], n_below_diag)
            self.assertTrue(torch.all(tri_vals >= 0))

        # A non-Prior task prior is rejected.
        with self.subTest("invalid_prior_type"):
            with self.assertRaises(TypeError):
                PositiveIndexKernel(num_tasks=4, rank=2, task_prior="not_a_prior")

        # Structural properties of the covariance matrix.
        with self.subTest("covar_matrix"):
            kern = PositiveIndexKernel(num_tasks=5, rank=4)
            task_covar = kern.covar_matrix

            # Square.
            n_rows, n_cols = task_covar.shape[-2], task_covar.shape[-1]
            self.assertEqual(n_rows, n_cols)

            # Positive definite: every eigenvalue is strictly positive.
            spectrum = torch.linalg.eigvalsh(task_covar)
            self.assertTrue(torch.all(spectrum > 0))

            # Symmetric.
            self.assertAllClose(task_covar, task_covar.transpose(-2, -1), atol=1e-5)

0 commit comments

Comments
 (0)