|  | 
|  | 1 | +#!/usr/bin/env python3 | 
|  | 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. | 
|  | 3 | +# | 
|  | 4 | +# This source code is licensed under the MIT license found in the | 
|  | 5 | +# LICENSE file in the root directory of this source tree. | 
|  | 6 | + | 
|  | 7 | +import torch | 
|  | 8 | +from botorch.models.kernels.positive_index import PositiveIndexKernel | 
|  | 9 | +from botorch.utils.testing import BotorchTestCase | 
|  | 10 | +from gpytorch.priors import NormalPrior | 
|  | 11 | + | 
|  | 12 | + | 
|  | 13 | +class TestPositiveIndexKernel(BotorchTestCase): | 
|  | 14 | +    def test_positive_index_kernel(self): | 
|  | 15 | +        # Test initialization | 
|  | 16 | +        with self.subTest("basic_initialization"): | 
|  | 17 | +            num_tasks = 4 | 
|  | 18 | +            rank = 2 | 
|  | 19 | +            kernel = PositiveIndexKernel(num_tasks=num_tasks, rank=rank) | 
|  | 20 | + | 
|  | 21 | +            self.assertEqual(kernel.num_tasks, num_tasks) | 
|  | 22 | +            self.assertEqual(kernel.raw_covar_factor.shape, (num_tasks, rank)) | 
|  | 23 | +            self.assertEqual(kernel.normalize_covar_matrix, False) | 
|  | 24 | + | 
|  | 25 | +        # Test initialization with batch shape | 
|  | 26 | +        with self.subTest("initialization_with_batch_shape"): | 
|  | 27 | +            num_tasks = 3 | 
|  | 28 | +            rank = 2 | 
|  | 29 | +            batch_shape = torch.Size([2]) | 
|  | 30 | +            kernel = PositiveIndexKernel( | 
|  | 31 | +                num_tasks=num_tasks, rank=rank, batch_shape=batch_shape | 
|  | 32 | +            ) | 
|  | 33 | + | 
|  | 34 | +            self.assertEqual(kernel.raw_covar_factor.shape, (2, num_tasks, rank)) | 
|  | 35 | + | 
|  | 36 | +        # Test rank validation | 
|  | 37 | +        with self.subTest("rank_validation"): | 
|  | 38 | +            num_tasks = 3 | 
|  | 39 | +            rank = 5 | 
|  | 40 | +            with self.assertRaises(RuntimeError): | 
|  | 41 | +                PositiveIndexKernel(num_tasks=num_tasks, rank=rank) | 
|  | 42 | + | 
|  | 43 | +        # Test target_task_index validation | 
|  | 44 | +        with self.subTest("target_task_index_validation"): | 
|  | 45 | +            num_tasks = 4 | 
|  | 46 | +            # Test invalid negative index | 
|  | 47 | +            with self.assertRaises(ValueError): | 
|  | 48 | +                PositiveIndexKernel(num_tasks=num_tasks, rank=2, target_task_index=-1) | 
|  | 49 | +            # Test invalid index >= num_tasks | 
|  | 50 | +            with self.assertRaises(ValueError): | 
|  | 51 | +                PositiveIndexKernel(num_tasks=num_tasks, rank=2, target_task_index=4) | 
|  | 52 | +            # Test valid indices (should not raise) | 
|  | 53 | +            PositiveIndexKernel(num_tasks=num_tasks, rank=2, target_task_index=0) | 
|  | 54 | +            PositiveIndexKernel(num_tasks=num_tasks, rank=2, target_task_index=3) | 
|  | 55 | + | 
|  | 56 | +        # Test covar_factor constraint | 
|  | 57 | +        with self.subTest("positive_correlations"): | 
|  | 58 | +            kernel = PositiveIndexKernel(num_tasks=5, rank=3) | 
|  | 59 | +            covar_factor = kernel.covar_factor | 
|  | 60 | + | 
|  | 61 | +            # All elements should be positive | 
|  | 62 | +            self.assertTrue((covar_factor > 0).all()) | 
|  | 63 | + | 
|  | 64 | +            self.assertTrue((kernel.covar_matrix >= 0).all()) | 
|  | 65 | + | 
|  | 66 | +        # Test covariance matrix normalization (default target_task_index=0) | 
|  | 67 | +        with self.subTest("covar_matrix_normalization_default"): | 
|  | 68 | +            kernel = PositiveIndexKernel(num_tasks=4, rank=2) | 
|  | 69 | +            covar = kernel.covar_matrix | 
|  | 70 | + | 
|  | 71 | +            # First diagonal element should be 1.0 (normalized by default) | 
|  | 72 | +            self.assertAllClose(covar[0, 0], torch.tensor(1.0), atol=1e-4) | 
|  | 73 | + | 
|  | 74 | +        # Test covariance matrix normalization with custom target_task_index | 
|  | 75 | +        with self.subTest("covar_matrix_normalization_custom_target"): | 
|  | 76 | +            kernel = PositiveIndexKernel(num_tasks=4, rank=2, target_task_index=2) | 
|  | 77 | +            covar = kernel.covar_matrix | 
|  | 78 | + | 
|  | 79 | +            # Third diagonal element should be 1.0 (target_task_index=2) | 
|  | 80 | +            self.assertAllClose(covar[2, 2], torch.tensor(1.0), atol=1e-4) | 
|  | 81 | + | 
|  | 82 | +            # Other diagonal elements should not be 1.0 | 
|  | 83 | +            self.assertNotEqual(covar[0, 0].item(), 1.0) | 
|  | 84 | + | 
|  | 85 | +        # Test forward pass shape | 
|  | 86 | +        with self.subTest("forward"): | 
|  | 87 | +            num_tasks = 4 | 
|  | 88 | +            kernel = PositiveIndexKernel(num_tasks=num_tasks, rank=2) | 
|  | 89 | +            kernel.eval() | 
|  | 90 | + | 
|  | 91 | +            i1 = torch.tensor([[0, 1], [2, 3]], dtype=torch.long) | 
|  | 92 | +            i2 = torch.tensor([[1, 2]], dtype=torch.long) | 
|  | 93 | + | 
|  | 94 | +            result = kernel(i1, i2) | 
|  | 95 | +            self.assertEqual(result.shape, torch.Size([2, 1])) | 
|  | 96 | +            num_tasks = 3 | 
|  | 97 | +            kernel = PositiveIndexKernel(num_tasks=num_tasks, rank=1) | 
|  | 98 | +            kernel.eval() | 
|  | 99 | + | 
|  | 100 | +            kernel.initialize(raw_covar_factor=torch.ones(num_tasks, 1)) | 
|  | 101 | +            i1 = torch.tensor([[0]], dtype=torch.long) | 
|  | 102 | +            i2 = torch.tensor([[1]], dtype=torch.long) | 
|  | 103 | + | 
|  | 104 | +            result = kernel(i1, i2).to_dense() | 
|  | 105 | +            covar_matrix = kernel.covar_matrix | 
|  | 106 | +            expected = covar_matrix[0, 1] | 
|  | 107 | + | 
|  | 108 | +            self.assertAllClose(result.squeeze(), expected) | 
|  | 109 | + | 
|  | 110 | +        # Test with priors | 
|  | 111 | +        with self.subTest("with_priors"): | 
|  | 112 | +            num_tasks = 4 | 
|  | 113 | +            task_prior = NormalPrior(0, 1) | 
|  | 114 | +            diag_prior = NormalPrior(1, 0.1) | 
|  | 115 | + | 
|  | 116 | +            kernel = PositiveIndexKernel( | 
|  | 117 | +                num_tasks=num_tasks, | 
|  | 118 | +                rank=2, | 
|  | 119 | +                task_prior=task_prior, | 
|  | 120 | +                diag_prior=diag_prior, | 
|  | 121 | +                initialize_to_mode=False, | 
|  | 122 | +            ) | 
|  | 123 | +            prior_names = [p[0] for p in kernel.named_priors()] | 
|  | 124 | +            self.assertIn("IndexKernelPrior", prior_names) | 
|  | 125 | +            self.assertIn("ScalePrior", prior_names) | 
|  | 126 | + | 
|  | 127 | +        # Test batch forward | 
|  | 128 | +        with self.subTest("batch_forward"): | 
|  | 129 | +            num_tasks = 3 | 
|  | 130 | +            batch_shape = torch.Size([2]) | 
|  | 131 | +            kernel = PositiveIndexKernel( | 
|  | 132 | +                num_tasks=num_tasks, rank=2, batch_shape=batch_shape | 
|  | 133 | +            ) | 
|  | 134 | +            kernel.eval() | 
|  | 135 | + | 
|  | 136 | +            i1 = torch.tensor([[[0], [1]]], dtype=torch.long) | 
|  | 137 | +            i2 = torch.tensor([[[1], [2]]], dtype=torch.long) | 
|  | 138 | + | 
|  | 139 | +            result = kernel(i1, i2) | 
|  | 140 | + | 
|  | 141 | +            # Check that batch dimensions are preserved | 
|  | 142 | +            self.assertEqual(result.shape[0], 2) | 
|  | 143 | + | 
|  | 144 | +        # Test diagonal property (default target_task_index=0) | 
|  | 145 | +        with self.subTest("diagonal"): | 
|  | 146 | +            kernel = PositiveIndexKernel(num_tasks=4, rank=2) | 
|  | 147 | +            diag = kernel._diagonal | 
|  | 148 | + | 
|  | 149 | +            self.assertEqual(diag.shape, torch.Size([4])) | 
|  | 150 | +            # First diagonal element should be 1.0 (default target_task_index=0) | 
|  | 151 | +            self.assertAllClose(diag[0], torch.tensor(1.0), atol=1e-4) | 
|  | 152 | + | 
|  | 153 | +            # Test diagonal property with custom target_task_index | 
|  | 154 | +            kernel = PositiveIndexKernel(num_tasks=4, rank=2, target_task_index=1) | 
|  | 155 | +            diag = kernel._diagonal | 
|  | 156 | + | 
|  | 157 | +            self.assertEqual(diag.shape, torch.Size([4])) | 
|  | 158 | +            # Second diagonal element should be 1.0 (target_task_index=1) | 
|  | 159 | +            self.assertAllClose(diag[1], torch.tensor(1.0), atol=1e-4) | 
|  | 160 | + | 
|  | 161 | +        # Test lower triangle property | 
|  | 162 | +        with self.subTest("lower_triangle"): | 
|  | 163 | +            num_tasks = 5 | 
|  | 164 | +            kernel = PositiveIndexKernel(num_tasks=num_tasks, rank=2) | 
|  | 165 | +            lower_tri = kernel._lower_triangle | 
|  | 166 | + | 
|  | 167 | +            # Number of lower triangular elements (excluding diagonal) | 
|  | 168 | +            expected_size = num_tasks * (num_tasks - 1) // 2 | 
|  | 169 | +            self.assertEqual(lower_tri.shape[-1], expected_size) | 
|  | 170 | +            self.assertTrue((lower_tri >= 0).all()) | 
|  | 171 | + | 
|  | 172 | +        # Test invalid prior type | 
|  | 173 | +        with self.subTest("invalid_prior_type"): | 
|  | 174 | +            with self.assertRaises(TypeError): | 
|  | 175 | +                PositiveIndexKernel(num_tasks=4, rank=2, task_prior="not_a_prior") | 
|  | 176 | + | 
|  | 177 | +        # Test covariance matrix properties | 
|  | 178 | +        with self.subTest("covar_matrix"): | 
|  | 179 | +            kernel = PositiveIndexKernel(num_tasks=5, rank=4) | 
|  | 180 | +            covar = kernel.covar_matrix | 
|  | 181 | + | 
|  | 182 | +            # Should be square | 
|  | 183 | +            self.assertEqual(covar.shape[-2], covar.shape[-1]) | 
|  | 184 | + | 
|  | 185 | +            # Should be positive definite (all eigenvalues > 0) | 
|  | 186 | +            eigvals = torch.linalg.eigvalsh(covar) | 
|  | 187 | +            self.assertTrue((eigvals > 0).all()) | 
|  | 188 | + | 
|  | 189 | +            # Should be symmetric | 
|  | 190 | +            self.assertAllClose(covar, covar.T, atol=1e-5) | 
|  | 191 | + | 
|  | 192 | +        # Test covar_factor setter and getter | 
|  | 193 | +        with self.subTest("covar_factor"): | 
|  | 194 | +            kernel = PositiveIndexKernel(num_tasks=3, rank=2) | 
|  | 195 | +            new_covar_factor = torch.ones(3, 2) * 2.0 | 
|  | 196 | +            kernel.covar_factor = new_covar_factor | 
|  | 197 | +            self.assertAllClose(kernel.covar_factor, new_covar_factor, atol=1e-5) | 
|  | 198 | + | 
|  | 199 | +            kernel = PositiveIndexKernel(num_tasks=3, rank=2) | 
|  | 200 | +            params = kernel._covar_factor_params(kernel) | 
|  | 201 | +            self.assertEqual(params.shape, torch.Size([3, 2])) | 
|  | 202 | +            self.assertTrue((params > 0).all()) | 
|  | 203 | + | 
|  | 204 | +            kernel = PositiveIndexKernel(num_tasks=3, rank=2) | 
|  | 205 | +            new_value = torch.ones(3, 2) * 3.0 | 
|  | 206 | +            kernel._covar_factor_closure(kernel, new_value) | 
|  | 207 | +            self.assertAllClose(kernel.covar_factor, new_value, atol=1e-5) | 
0 commit comments