1414import torch
1515from botorch .exceptions .errors import InputDataError , UnsupportedError
1616from botorch .utils .containers import BotorchContainer , SliceContainer
17+ from pyre_extensions import none_throws
1718from torch import long , ones , Tensor
1819
1920
@@ -54,6 +55,7 @@ def __init__(
5455 outcome_names : list [str ],
5556 Yvar : BotorchContainer | Tensor | None = None ,
5657 validate_init : bool = True ,
58+ trial_indices : Tensor | None = None ,
5759 ) -> None :
5860 r"""Constructs a `SupervisedDataset`.
5961
@@ -65,12 +67,16 @@ def __init__(
6567 Yvar: An optional `Tensor` or `BotorchContainer` representing
6668 the observation noise.
6769 validate_init: If `True`, validates the input shapes.
70+ trial_indices: A `Tensor` representing the trial indices of X and Y. This is
71+ used to support learning-curve-based modeling. If provided, it must
72+ have compatible shape with X and Y.
6873 """
6974 self ._X = X
7075 self ._Y = Y
7176 self ._Yvar = Yvar
7277 self .feature_names = feature_names
7378 self .outcome_names = outcome_names
79+ self .trial_indices = trial_indices
7480 if validate_init :
7581 self ._validate ()
7682
@@ -96,6 +102,7 @@ def _validate(
96102 self ,
97103 validate_feature_names : bool = True ,
98104 validate_outcome_names : bool = True ,
105+ validate_trial_indices : bool = True ,
99106 ) -> None :
100107 r"""Checks that the shapes of the inputs are compatible with each other.
101108
@@ -108,6 +115,8 @@ def _validate(
108115 `outcomes_names` matches the # of columns of `self.Y`. If a
109116 particular dataset, e.g., `RankingDataset`, is known to violate
110117 this assumption, this can be set to `False`.
118+ validate_trial_indices: By default, we validate that the shape of
119+ `trial_indices` matches the shape of X and Y.
111120 """
112121 shape_X = self .X .shape
113122 if isinstance (self ._X , BotorchContainer ):
@@ -133,8 +142,20 @@ def _validate(
133142 "`Y` must have the same number of columns as the number of "
134143 "outcomes in `outcome_names`."
135144 )
145+ if validate_trial_indices and self .trial_indices is not None :
146+ if self .trial_indices .shape != shape_X :
147+ raise ValueError (
148+ f"shape_X ({ shape_X } ) must have the same shape as "
149+ f"trial_indices ({ none_throws (self .trial_indices ).shape } )."
150+ )
136151
137152 def __eq__ (self , other : Any ) -> bool :
153+ if self .trial_indices is None and other .trial_indices is None :
154+ trial_indices_equal = True
155+ elif self .trial_indices is None or other .trial_indices is None :
156+ trial_indices_equal = False
157+ else :
158+ trial_indices_equal = torch .equal (self .trial_indices , other .trial_indices )
138159 return (
139160 type (other ) is type (self )
140161 and torch .equal (self .X , other .X )
@@ -146,6 +167,7 @@ def __eq__(self, other: Any) -> bool:
146167 )
147168 and self .feature_names == other .feature_names
148169 and self .outcome_names == other .outcome_names
170+ and trial_indices_equal
149171 )
150172
151173
@@ -241,7 +263,11 @@ def __init__(
241263 )
242264
243265 def _validate (self ) -> None :
244- super ()._validate (validate_feature_names = False , validate_outcome_names = False )
266+ super ()._validate (
267+ validate_feature_names = False ,
268+ validate_outcome_names = False ,
269+ validate_trial_indices = False ,
270+ )
245271 if len (self .feature_names ) != self ._X .values .shape [- 1 ]:
246272 raise ValueError (
247273 "The `values` field of `X` must have the same number of columns as "
@@ -316,6 +342,7 @@ def __init__(
316342 self .has_heterogeneous_features = any (
317343 datasets [0 ].feature_names != ds .feature_names for ds in datasets [1 :]
318344 )
345+ self .trial_indices = None
319346
320347 @classmethod
321348 def from_joint_dataset (
@@ -538,6 +565,7 @@ def __init__(
538565 c : [self .feature_names .index (i ) for i in parameter_decomposition [c ]]
539566 for c in self .context_buckets
540567 }
568+ self .trial_indices = None
541569
542570 @property
543571 def X (self ) -> Tensor :
0 commit comments