Skip to content

Commit fb43775

Browse files
SebastianAment authored and facebook-github-bot committed
ESS: Allowing diagonal covariance root with fixed indices (#2283)
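Summary: This commit adds support for passing a diagonal covariance root together with fixed indices to the linear elliptical slice sampler (ESS). Combining a covariance root with fixed indices is not generally supported, because the root of the covariance restricted to the non-fixed dimensions would have to be re-factorized. In the diagonal case, the restricted root is simply the sub-selected diagonal, which allows for an efficient implementation without re-factorization. Differential Revision: D55808235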
1 parent 968e465 commit fb43775

File tree

2 files changed

+41
-15
lines changed

2 files changed

+41
-15
lines changed

botorch/utils/probability/lin_ess.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
import torch
3939
from botorch.utils.sampling import PolytopeSampler
40+
from linear_operator.operators import DiagLinearOperator, LinearOperator
4041
from torch import Tensor
4142

4243
_twopi = 2.0 * math.pi
@@ -58,8 +59,8 @@ def __init__(
5859
interior_point: Optional[Tensor] = None,
5960
fixed_indices: Optional[Union[List[int], Tensor]] = None,
6061
mean: Optional[Tensor] = None,
61-
covariance_matrix: Optional[Tensor] = None,
62-
covariance_root: Optional[Tensor] = None,
62+
covariance_matrix: Optional[Union[Tensor, LinearOperator]] = None,
63+
covariance_root: Optional[Union[Tensor, LinearOperator]] = None,
6364
check_feasibility: bool = False,
6465
burnin: int = 0,
6566
thinning: int = 0,
@@ -88,7 +89,10 @@ def __init__(
8889
distribution (if omitted, use the identity).
8990
covariance_root: A `d x d`-dim root of the covariance matrix such that
9091
covariance_root @ covariance_root.T = covariance_matrix. NOTE: This
91-
matrix is assumed to be lower triangular.
92+
matrix is assumed to be lower triangular. covariance_root can only be
93+
passed in conjunction with fixed_indices if covariance_root is a
94+
DiagLinearOperator. Otherwise the factorization would need to be re-
95+
computed, as we need to solve in `standardize`.
9296
check_feasibility: If True, raise an error if the sampling results in an
9397
infeasible sample. This creates some overhead and so is switched off
9498
by default.
@@ -123,14 +127,16 @@ def __init__(
123127
self._Az, self._bz = A, b
124128
self._is_fixed, self._not_fixed = None, None
125129
if fixed_indices is not None:
126-
mean, covariance_matrix = self._fixed_features_initialization(
127-
A=A,
128-
b=b,
129-
interior_point=interior_point,
130-
fixed_indices=fixed_indices,
131-
mean=mean,
132-
covariance_matrix=covariance_matrix,
133-
covariance_root=covariance_root,
130+
mean, covariance_matrix, covariance_root = (
131+
self._fixed_features_initialization(
132+
A=A,
133+
b=b,
134+
interior_point=interior_point,
135+
fixed_indices=fixed_indices,
136+
mean=mean,
137+
covariance_matrix=covariance_matrix,
138+
covariance_root=covariance_root,
139+
)
134140
)
135141

136142
self._mean = mean
@@ -185,7 +191,8 @@ def _fixed_features_initialization(
185191
"If `fixed_indices` are provided, an interior point must also be "
186192
"provided in order to infer feasible values of the fixed features."
187193
)
188-
if covariance_root is not None:
194+
root_is_diag = isinstance(covariance_root, DiagLinearOperator)
195+
if covariance_root is not None and not root_is_diag:
189196
raise ValueError(
190197
"Provide either covariance_root or fixed_indices, not both."
191198
)
@@ -205,7 +212,10 @@ def _fixed_features_initialization(
205212
covariance_matrix = covariance_matrix[
206213
not_fixed.unsqueeze(-1), not_fixed.unsqueeze(0)
207214
]
208-
return mean, covariance_matrix
215+
if root_is_diag: # in the special case of diagonal root, can subselect
216+
covariance_root = DiagLinearOperator(covariance_root.diagonal()[not_fixed])
217+
218+
return mean, covariance_matrix, covariance_root
209219

210220
def _standardization_initialization(self) -> None:
211221
"""For non-standard mean and covariance, we're going to rewrite the problem as
@@ -482,8 +492,10 @@ def _standardize(self, x: Tensor) -> Tensor:
482492
z = x
483493
if self._mean is not None:
484494
z = z - self._mean
485-
if self._covariance_root is not None:
486-
z = torch.linalg.solve_triangular(self._covariance_root, z, upper=False)
495+
root = self._covariance_root
496+
if root is not None:
497+
z = torch.linalg.solve_triangular(root, z, upper=False)
498+
487499
return z
488500

489501
def _unstandardize(self, z: Tensor) -> Tensor:

test/utils/probability/test_lin_ess.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from botorch.utils.constraints import get_monotonicity_constraints
1818
from botorch.utils.probability.lin_ess import LinearEllipticalSliceSampler
1919
from botorch.utils.testing import BotorchTestCase
20+
from linear_operator.operators import DiagLinearOperator
2021
from torch import Tensor
2122

2223

@@ -431,6 +432,19 @@ def test_multivariate(self):
431432
covariance_root=torch.eye(d, **tkwargs),
432433
)
433434

435+
# providing a diagonal covariance_root should work with fixed indices
436+
torch.manual_seed(1234)
437+
sampler = LinearEllipticalSliceSampler(
438+
inequality_constraints=(A, b),
439+
interior_point=interior_point,
440+
fixed_indices=[0],
441+
covariance_root=DiagLinearOperator(torch.full((d,), 100, **tkwargs)),
442+
)
443+
num_samples = 16
444+
X_fixed = sampler.draw(n=num_samples)
445+
self.assertTrue((X_fixed[:, 0] == interior_point[0]).all())
446+
self.assertGreater(X_fixed.std().item(), 10.0) # false if sigma = 1
447+
434448
# high dimensional test case
435449
# Encodes order constraints on all d variables: Ax < b <-> x[i] < x[i + 1]
436450
d = 128

0 commit comments

Comments
 (0)