1414import torch
1515from botorch .exceptions .errors import InputDataError , UnsupportedError
1616from botorch .utils .containers import BotorchContainer , SliceContainer
17+ from pyre_extensions import none_throws
1718from torch import long , ones , Tensor
1819
1920
@@ -54,6 +55,7 @@ def __init__(
5455 outcome_names : list [str ],
5556 Yvar : BotorchContainer | Tensor | None = None ,
5657 validate_init : bool = True ,
58+ trial_indices : Tensor | None = None ,
5759 ) -> None :
5860 r"""Constructs a `SupervisedDataset`.
5961
@@ -65,12 +67,16 @@ def __init__(
6567 Yvar: An optional `Tensor` or `BotorchContainer` representing
6668 the observation noise.
6769 validate_init: If `True`, validates the input shapes.
70+ trial_indices: A `Tensor` representing the trial indices of X and Y. This is
71+ used to support learning-curve-based modeling. If provided, it must
72+ have compatible shape with X and Y.
6873 """
6974 self ._X = X
7075 self ._Y = Y
7176 self ._Yvar = Yvar
7277 self .feature_names = feature_names
7378 self .outcome_names = outcome_names
79+ self .trial_indices = trial_indices
7480 if validate_init :
7581 self ._validate ()
7682
@@ -96,6 +102,7 @@ def _validate(
96102 self ,
97103 validate_feature_names : bool = True ,
98104 validate_outcome_names : bool = True ,
105+ validate_trial_indices : bool = True ,
99106 ) -> None :
100107 r"""Checks that the shapes of the inputs are compatible with each other.
101108
@@ -108,6 +115,8 @@ def _validate(
108115 `outcomes_names` matches the # of columns of `self.Y`. If a
109116 particular dataset, e.g., `RankingDataset`, is known to violate
110117 this assumption, this can be set to `False`.
118+ validate_trial_indices: By default, we validate that the shape of
119+ `trial_indices` matches the shape of X and Y.
111120 """
112121 shape_X = self .X .shape
113122 if isinstance (self ._X , BotorchContainer ):
@@ -133,8 +142,19 @@ def _validate(
133142 "`Y` must have the same number of columns as the number of "
134143 "outcomes in `outcome_names`."
135144 )
145+ if validate_trial_indices and self .trial_indices is not None :
146+ if self .trial_indices .shape != shape_X :
147+ raise ValueError (
148+ f"{ shape_X = } must have the same shape as { none_throws (self .trial_indices ).shape = } ."
149+ )
136150
137151 def __eq__ (self , other : Any ) -> bool :
152+ if self .trial_indices is None and other .trial_indices is None :
153+ trial_indices_equal = True
154+ elif self .trial_indices is None or other .trial_indices is None :
155+ trial_indices_equal = False
156+ else :
157+ trial_indices_equal = torch .equal (self .trial_indices , other .trial_indices )
138158 return (
139159 type (other ) is type (self )
140160 and torch .equal (self .X , other .X )
@@ -146,6 +166,7 @@ def __eq__(self, other: Any) -> bool:
146166 )
147167 and self .feature_names == other .feature_names
148168 and self .outcome_names == other .outcome_names
169+ and trial_indices_equal
149170 )
150171
151172
@@ -241,7 +262,11 @@ def __init__(
241262 )
242263
243264 def _validate (self ) -> None :
244- super ()._validate (validate_feature_names = False , validate_outcome_names = False )
265+ super ()._validate (
266+ validate_feature_names = False ,
267+ validate_outcome_names = False ,
268+ validate_trial_indices = False ,
269+ )
245270 if len (self .feature_names ) != self ._X .values .shape [- 1 ]:
246271 raise ValueError (
247272 "The `values` field of `X` must have the same number of columns as "
@@ -316,6 +341,7 @@ def __init__(
316341 self .has_heterogeneous_features = any (
317342 datasets [0 ].feature_names != ds .feature_names for ds in datasets [1 :]
318343 )
344+ self .trial_indices = None
319345
320346 @classmethod
321347 def from_joint_dataset (
@@ -538,6 +564,7 @@ def __init__(
538564 c : [self .feature_names .index (i ) for i in parameter_decomposition [c ]]
539565 for c in self .context_buckets
540566 }
567+ self .trial_indices = None
541568
542569 @property
543570 def X (self ) -> Tensor :
0 commit comments