Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 38 additions & 18 deletions botorch/utils/probability/lin_ess.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

import torch
from botorch.utils.sampling import PolytopeSampler
from linear_operator.operators import DiagLinearOperator, LinearOperator
from torch import Tensor

_twopi = 2.0 * math.pi
Expand All @@ -58,8 +59,8 @@ def __init__(
interior_point: Optional[Tensor] = None,
fixed_indices: Optional[Union[List[int], Tensor]] = None,
mean: Optional[Tensor] = None,
covariance_matrix: Optional[Tensor] = None,
covariance_root: Optional[Tensor] = None,
covariance_matrix: Optional[Union[Tensor, LinearOperator]] = None,
covariance_root: Optional[Union[Tensor, LinearOperator]] = None,
check_feasibility: bool = False,
burnin: int = 0,
thinning: int = 0,
Expand Down Expand Up @@ -88,7 +89,10 @@ def __init__(
distribution (if omitted, use the identity).
covariance_root: A `d x d`-dim root of the covariance matrix such that
covariance_root @ covariance_root.T = covariance_matrix. NOTE: This
matrix is assumed to be lower triangular.
matrix is assumed to be lower triangular. `covariance_root` can only be
passed together with `fixed_indices` if it is diagonal (a
DiagLinearOperator or an exactly diagonal tensor). Otherwise the
factorization would need to be recomputed, as we need to solve in `_standardize`.
check_feasibility: If True, raise an error if the sampling results in an
infeasible sample. This creates some overhead and so is switched off
by default.
Expand Down Expand Up @@ -123,14 +127,16 @@ def __init__(
self._Az, self._bz = A, b
self._is_fixed, self._not_fixed = None, None
if fixed_indices is not None:
mean, covariance_matrix = self._fixed_features_initialization(
A=A,
b=b,
interior_point=interior_point,
fixed_indices=fixed_indices,
mean=mean,
covariance_matrix=covariance_matrix,
covariance_root=covariance_root,
mean, covariance_matrix, covariance_root = (
self._fixed_features_initialization(
A=A,
b=b,
interior_point=interior_point,
fixed_indices=fixed_indices,
mean=mean,
covariance_matrix=covariance_matrix,
covariance_root=covariance_root,
)
)

self._mean = mean
Expand Down Expand Up @@ -176,6 +182,9 @@ def _fixed_features_initialization(
"""Modifies the constraint system (A, b) due to fixed indices and assigns
the modified constraints system to `self._Az`, `self._bz`. NOTE: Needs to be
called prior to `self._standardization_initialization` in the constructor.
`covariance_root` and `fixed_indices` may both be provided only if
`covariance_root` is diagonal (a DiagLinearOperator or an exactly diagonal
tensor). Otherwise, the covariance matrix would need to be refactorized.

Returns:
Tuple of `mean`, `covariance_matrix`, and `covariance_root` of the non-fixed dimensions.
Expand All @@ -185,10 +194,16 @@ def _fixed_features_initialization(
"If `fixed_indices` are provided, an interior point must also be "
"provided in order to infer feasible values of the fixed features."
)
if covariance_root is not None:
raise ValueError(
"Provide either covariance_root or fixed_indices, not both."
)

root_is_diag = isinstance(covariance_root, DiagLinearOperator)
if covariance_root is not None and not root_is_diag:
root_is_diag = (covariance_root.diag().diag() == covariance_root).all()
if root_is_diag: # convert the diagonal root to a DiagLinearOperator
covariance_root = DiagLinearOperator(covariance_root.diagonal())
else: # otherwise, fail
raise ValueError(
"Provide either covariance_root or fixed_indices, not both."
)
d = interior_point.shape[0]
is_fixed, not_fixed = get_index_tensors(fixed_indices=fixed_indices, d=d)
self._is_fixed = is_fixed
Expand All @@ -205,7 +220,10 @@ def _fixed_features_initialization(
covariance_matrix = covariance_matrix[
not_fixed.unsqueeze(-1), not_fixed.unsqueeze(0)
]
return mean, covariance_matrix
if root_is_diag: # in the special case of diagonal root, can subselect
covariance_root = DiagLinearOperator(covariance_root.diagonal()[not_fixed])

return mean, covariance_matrix, covariance_root

def _standardization_initialization(self) -> None:
"""For non-standard mean and covariance, we're going to rewrite the problem as
Expand Down Expand Up @@ -482,8 +500,10 @@ def _standardize(self, x: Tensor) -> Tensor:
z = x
if self._mean is not None:
z = z - self._mean
if self._covariance_root is not None:
z = torch.linalg.solve_triangular(self._covariance_root, z, upper=False)
root = self._covariance_root
if root is not None:
z = torch.linalg.solve_triangular(root, z, upper=False)

return z

def _unstandardize(self, z: Tensor) -> Tensor:
Expand Down
18 changes: 17 additions & 1 deletion test/utils/probability/test_lin_ess.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from botorch.utils.constraints import get_monotonicity_constraints
from botorch.utils.probability.lin_ess import LinearEllipticalSliceSampler
from botorch.utils.testing import BotorchTestCase
from linear_operator.operators import DiagLinearOperator
from torch import Tensor


Expand Down Expand Up @@ -428,9 +429,24 @@ def test_multivariate(self):
inequality_constraints=(A, b),
interior_point=interior_point,
fixed_indices=[0],
covariance_root=torch.eye(d, **tkwargs),
covariance_root=torch.full((d, d), 100, **tkwargs),
)

# providing a diagonal covariance_root should work with fixed indices
diag_root = torch.full((d,), 100, **tkwargs)
for covariance_root in [DiagLinearOperator(diag_root), diag_root.diag()]:
torch.manual_seed(1234)
sampler = LinearEllipticalSliceSampler(
inequality_constraints=(A, b),
interior_point=interior_point,
fixed_indices=[0],
covariance_root=covariance_root,
)
num_samples = 16
X_fixed = sampler.draw(n=num_samples)
self.assertTrue((X_fixed[:, 0] == interior_point[0]).all())
self.assertGreater(X_fixed.std().item(), 10.0) # false if sigma = 1

# high dimensional test case
# Encodes order constraints on all d variables: Ax < b <-> x[i] < x[i + 1]
d = 128
Expand Down