|
30 | 30 | from gpytorch.likelihoods import ( |
31 | 31 | FixedNoiseGaussianLikelihood, |
32 | 32 | GaussianLikelihood, |
| 33 | + HadamardGaussianLikelihood, |
33 | 34 | MultitaskGaussianLikelihood, |
34 | 35 | ) |
35 | 36 | from gpytorch.means import ConstantMean, MultitaskMean |
@@ -154,7 +155,9 @@ def test_MultiTaskGP(self) -> None: |
154 | 155 | if fixed_noise: |
155 | 156 | self.assertIsInstance(model.likelihood, FixedNoiseGaussianLikelihood) |
156 | 157 | else: |
157 | | - self.assertIsInstance(model.likelihood, GaussianLikelihood) |
| 158 | + self.assertIsInstance(model.likelihood, HadamardGaussianLikelihood) |
| 159 | + self.assertEqual(model.likelihood.noise.shape, torch.Size([2])) |
| 160 | + self.assertEqual(model.likelihood.task_feature_index, 0) |
158 | 161 | data_covar_module, task_covar_module = model.covar_module.kernels |
159 | 162 | self.assertIsInstance(model.mean_module, ConstantMean) |
160 | 163 | self.assertIsInstance(data_covar_module, RBFKernel) |
@@ -195,8 +198,8 @@ def test_MultiTaskGP(self) -> None: |
195 | 198 | torch.tensor([0.05, 0.1], **tkwargs).repeat_interleave(2) |
196 | 199 | ).expand(3, 4, 4) |
197 | 200 | else: |
198 | | - noise_covar = model.likelihood.noise_covar.noise * torch.eye( |
199 | | - 4, **tkwargs |
| 201 | + noise_covar = torch.diag( |
| 202 | + model.likelihood.noise_covar.noise.repeat_interleave(2) |
200 | 203 | ).expand(3, 4, 4) |
201 | 204 | expected_y_covar = posterior_f.covariance_matrix + noise_covar |
202 | 205 | self.assertTrue( |
@@ -337,7 +340,7 @@ def test_MultiTaskGP_single_output(self) -> None: |
337 | 340 | data_covar_module, task_covar_module = model.covar_module.kernels |
338 | 341 | self.assertIsInstance(model, MultiTaskGP) |
339 | 342 | self.assertEqual(model.num_outputs, 1) |
340 | | - self.assertIsInstance(model.likelihood, GaussianLikelihood) |
| 343 | + self.assertIsInstance(model.likelihood, HadamardGaussianLikelihood) |
341 | 344 | self.assertIsInstance(model.mean_module, ConstantMean) |
342 | 345 | self.assertIsInstance(data_covar_module, RBFKernel) |
343 | 346 | self.assertIsInstance(data_covar_module.lengthscale_prior, LogNormalPrior) |
|
0 commit comments