3737
3838import torch
3939from botorch .utils .sampling import PolytopeSampler
40+ from linear_operator .operators import DiagLinearOperator , LinearOperator
4041from torch import Tensor
4142
4243_twopi = 2.0 * math .pi
@@ -58,8 +59,8 @@ def __init__(
5859 interior_point : Optional [Tensor ] = None ,
5960 fixed_indices : Optional [Union [List [int ], Tensor ]] = None ,
6061 mean : Optional [Tensor ] = None ,
61- covariance_matrix : Optional [Tensor ] = None ,
62- covariance_root : Optional [Tensor ] = None ,
62+ covariance_matrix : Optional [Union [ Tensor , LinearOperator ] ] = None ,
63+ covariance_root : Optional [Union [ Tensor , LinearOperator ] ] = None ,
6364 check_feasibility : bool = False ,
6465 burnin : int = 0 ,
6566 thinning : int = 0 ,
@@ -88,7 +89,10 @@ def __init__(
8889 distribution (if omitted, use the identity).
8990 covariance_root: A `d x d`-dim root of the covariance matrix such that
9091 covariance_root @ covariance_root.T = covariance_matrix. NOTE: This
91- matrix is assumed to be lower triangular.
92+ matrix is assumed to be lower triangular. covariance_root can only be
93+ passed in conjunction with fixed_indices if covariance_root is a
94+ DiagLinearOperator. Otherwise the factorization would need to be re-
95+ computed, as we need to solve in `standardize`.
9296 check_feasibility: If True, raise an error if the sampling results in an
9397 infeasible sample. This creates some overhead and so is switched off
9498 by default.
@@ -123,14 +127,16 @@ def __init__(
123127 self ._Az , self ._bz = A , b
124128 self ._is_fixed , self ._not_fixed = None , None
125129 if fixed_indices is not None :
126- mean , covariance_matrix = self ._fixed_features_initialization (
127- A = A ,
128- b = b ,
129- interior_point = interior_point ,
130- fixed_indices = fixed_indices ,
131- mean = mean ,
132- covariance_matrix = covariance_matrix ,
133- covariance_root = covariance_root ,
130+ mean , covariance_matrix , covariance_root = (
131+ self ._fixed_features_initialization (
132+ A = A ,
133+ b = b ,
134+ interior_point = interior_point ,
135+ fixed_indices = fixed_indices ,
136+ mean = mean ,
137+ covariance_matrix = covariance_matrix ,
138+ covariance_root = covariance_root ,
139+ )
134140 )
135141
136142 self ._mean = mean
@@ -176,6 +182,9 @@ def _fixed_features_initialization(
176182 """Modifies the constraint system (A, b) due to fixed indices and assigns
177183 the modified constraints system to `self._Az`, `self._bz`. NOTE: Needs to be
178184 called prior to `self._standardization_initialization` in the constructor.
185+ covariance_root and fixed_indices can both not be None only if covariance_root
186+ is a DiagLinearOperator. Otherwise, the covariance matrix would need to be
187+ refactorized.
179188
180189 Returns:
181190 Tuple of `mean` and `covariance_matrix` tensors of the non-fixed dimensions.
@@ -185,10 +194,16 @@ def _fixed_features_initialization(
185194 "If `fixed_indices` are provided, an interior point must also be "
186195 "provided in order to infer feasible values of the fixed features."
187196 )
188- if covariance_root is not None :
189- raise ValueError (
190- "Provide either covariance_root or fixed_indices, not both."
191- )
197+
198+ root_is_diag = isinstance (covariance_root , DiagLinearOperator )
199+ if covariance_root is not None and not root_is_diag :
200+ root_is_diag = (covariance_root .diag ().diag () == covariance_root ).all ()
201+ if root_is_diag : # convert the diagonal root to a DiagLinearOperator
202+ covariance_root = DiagLinearOperator (covariance_root .diagonal ())
203+ else : # otherwise, fail
204+ raise ValueError (
205+ "Provide either covariance_root or fixed_indices, not both."
206+ )
192207 d = interior_point .shape [0 ]
193208 is_fixed , not_fixed = get_index_tensors (fixed_indices = fixed_indices , d = d )
194209 self ._is_fixed = is_fixed
@@ -205,7 +220,10 @@ def _fixed_features_initialization(
205220 covariance_matrix = covariance_matrix [
206221 not_fixed .unsqueeze (- 1 ), not_fixed .unsqueeze (0 )
207222 ]
208- return mean , covariance_matrix
223+ if root_is_diag : # in the special case of diagonal root, can subselect
224+ covariance_root = DiagLinearOperator (covariance_root .diagonal ()[not_fixed ])
225+
226+ return mean , covariance_matrix , covariance_root
209227
210228 def _standardization_initialization (self ) -> None :
211229 """For non-standard mean and covariance, we're going to rewrite the problem as
@@ -482,8 +500,10 @@ def _standardize(self, x: Tensor) -> Tensor:
482500 z = x
483501 if self ._mean is not None :
484502 z = z - self ._mean
485- if self ._covariance_root is not None :
486- z = torch .linalg .solve_triangular (self ._covariance_root , z , upper = False )
503+ root = self ._covariance_root
504+ if root is not None :
505+ z = torch .linalg .solve_triangular (root , z , upper = False )
506+
487507 return z
488508
489509 def _unstandardize (self , z : Tensor ) -> Tensor :
0 commit comments