3737
3838import torch
3939from botorch .utils .sampling import PolytopeSampler
40+ from linear_operator .operators import DiagLinearOperator , LinearOperator
4041from torch import Tensor
4142
4243_twopi = 2.0 * math .pi
@@ -58,8 +59,8 @@ def __init__(
5859 interior_point : Optional [Tensor ] = None ,
5960 fixed_indices : Optional [Union [List [int ], Tensor ]] = None ,
6061 mean : Optional [Tensor ] = None ,
61- covariance_matrix : Optional [Tensor ] = None ,
62- covariance_root : Optional [Tensor ] = None ,
62+ covariance_matrix : Optional [Union [ Tensor , LinearOperator ] ] = None ,
63+ covariance_root : Optional [Union [ Tensor , LinearOperator ] ] = None ,
6364 check_feasibility : bool = False ,
6465 burnin : int = 0 ,
6566 thinning : int = 0 ,
@@ -88,7 +89,10 @@ def __init__(
8889 distribution (if omitted, use the identity).
8990 covariance_root: A `d x d`-dim root of the covariance matrix such that
9091 covariance_root @ covariance_root.T = covariance_matrix. NOTE: This
91- matrix is assumed to be lower triangular.
92+ matrix is assumed to be lower triangular. covariance_root can only be
93+ passed in conjunction with fixed_indices if covariance_root is a
94+ DiagLinearOperator. Otherwise the factorization would need to be
95+ recomputed, as we need to solve in `_standardize`.
9296 check_feasibility: If True, raise an error if the sampling results in an
9397 infeasible sample. This creates some overhead and so is switched off
9498 by default.
@@ -123,14 +127,16 @@ def __init__(
123127 self ._Az , self ._bz = A , b
124128 self ._is_fixed , self ._not_fixed = None , None
125129 if fixed_indices is not None :
126- mean , covariance_matrix = self ._fixed_features_initialization (
127- A = A ,
128- b = b ,
129- interior_point = interior_point ,
130- fixed_indices = fixed_indices ,
131- mean = mean ,
132- covariance_matrix = covariance_matrix ,
133- covariance_root = covariance_root ,
130+ mean , covariance_matrix , covariance_root = (
131+ self ._fixed_features_initialization (
132+ A = A ,
133+ b = b ,
134+ interior_point = interior_point ,
135+ fixed_indices = fixed_indices ,
136+ mean = mean ,
137+ covariance_matrix = covariance_matrix ,
138+ covariance_root = covariance_root ,
139+ )
134140 )
135141
136142 self ._mean = mean
@@ -185,7 +191,8 @@ def _fixed_features_initialization(
185191 "If `fixed_indices` are provided, an interior point must also be "
186192 "provided in order to infer feasible values of the fixed features."
187193 )
188- if covariance_root is not None :
194+ root_is_diag = isinstance (covariance_root , DiagLinearOperator )
195+ if covariance_root is not None and not root_is_diag :
189196 raise ValueError (
190197 "Provide either covariance_root or fixed_indices, not both."
191198 )
@@ -205,7 +212,10 @@ def _fixed_features_initialization(
205212 covariance_matrix = covariance_matrix [
206213 not_fixed .unsqueeze (- 1 ), not_fixed .unsqueeze (0 )
207214 ]
208- return mean , covariance_matrix
215+ if root_is_diag : # in the special case of diagonal root, can subselect
216+ covariance_root = DiagLinearOperator (covariance_root .diagonal ()[not_fixed ])
217+
218+ return mean , covariance_matrix , covariance_root
209219
210220 def _standardization_initialization (self ) -> None :
211221 """For non-standard mean and covariance, we're going to rewrite the problem as
@@ -482,8 +492,10 @@ def _standardize(self, x: Tensor) -> Tensor:
482492 z = x
483493 if self ._mean is not None :
484494 z = z - self ._mean
485- if self ._covariance_root is not None :
486- z = torch .linalg .solve_triangular (self ._covariance_root , z , upper = False )
495+ root = self ._covariance_root
496+ if root is not None :
497+ z = torch .linalg .solve_triangular (root , z , upper = False )
498+
487499 return z
488500
489501 def _unstandardize (self , z : Tensor ) -> Tensor :
0 commit comments