1515import torch
1616from botorch .exceptions .errors import InputDataError , UnsupportedError
1717from botorch .utils .containers import BotorchContainer , SliceContainer
18+ from pyre_extensions import none_throws
1819from torch import long , ones , Tensor
1920
2021
@@ -55,6 +56,7 @@ def __init__(
5556 outcome_names : list [str ],
5657 Yvar : BotorchContainer | Tensor | None = None ,
5758 validate_init : bool = True ,
59+ trial_indices : Tensor | None = None ,
5860 ) -> None :
5961 r"""Constructs a `SupervisedDataset`.
6062
@@ -66,13 +68,16 @@ def __init__(
6668 Yvar: An optional `Tensor` or `BotorchContainer` representing
6769 the observation noise.
6870 validate_init: If `True`, validates the input shapes.
71+ trial_indices: A `Tensor` representing the trial indices of X and Y. This is
72+ used to support learning-curve-based modeling. If provided, it must
73+ have compatible shape with X and Y.
6974 """
7075 self ._X = X
7176 self ._Y = Y
7277 self ._Yvar = Yvar
7378 self .feature_names = feature_names
7479 self .outcome_names = outcome_names
75- self .validate_init = validate_init
80+ self .trial_indices = trial_indices
7681 if validate_init :
7782 self ._validate ()
7883
@@ -98,6 +103,7 @@ def _validate(
98103 self ,
99104 validate_feature_names : bool = True ,
100105 validate_outcome_names : bool = True ,
106+ validate_trial_indices : bool = True ,
101107 ) -> None :
102108 r"""Checks that the shapes of the inputs are compatible with each other.
103109
@@ -110,6 +116,8 @@ def _validate(
110116 `outcomes_names` matches the # of columns of `self.Y`. If a
111117 particular dataset, e.g., `RankingDataset`, is known to violate
112118 this assumption, this can be set to `False`.
119+ validate_trial_indices: By default, we validate that the shape of
120+ `trial_indices` matches the shape of X and Y.
113121 """
114122 shape_X = self .X .shape
115123 if isinstance (self ._X , BotorchContainer ):
@@ -135,8 +143,20 @@ def _validate(
135143 "`Y` must have the same number of columns as the number of "
136144 "outcomes in `outcome_names`."
137145 )
146+ if validate_trial_indices and self .trial_indices is not None :
147+ if self .trial_indices .shape != shape_X :
148+ raise ValueError (
149+ f"shape_X ({ shape_X } ) must have the same shape as "
150+ f"trial_indices ({ none_throws (self .trial_indices ).shape } )."
151+ )
138152
139153 def __eq__ (self , other : Any ) -> bool :
154+ if self .trial_indices is None and other .trial_indices is None :
155+ trial_indices_equal = True
156+ elif self .trial_indices is None or other .trial_indices is None :
157+ trial_indices_equal = False
158+ else :
159+ trial_indices_equal = torch .equal (self .trial_indices , other .trial_indices )
140160 return (
141161 type (other ) is type (self )
142162 and torch .equal (self .X , other .X )
@@ -148,6 +168,7 @@ def __eq__(self, other: Any) -> bool:
148168 )
149169 and self .feature_names == other .feature_names
150170 and self .outcome_names == other .outcome_names
171+ and trial_indices_equal
151172 )
152173
153174 def clone (
@@ -256,7 +277,11 @@ def __init__(
256277 )
257278
258279 def _validate (self ) -> None :
259- super ()._validate (validate_feature_names = False , validate_outcome_names = False )
280+ super ()._validate (
281+ validate_feature_names = False ,
282+ validate_outcome_names = False ,
283+ validate_trial_indices = False ,
284+ )
260285 if len (self .feature_names ) != self ._X .values .shape [- 1 ]:
261286 raise ValueError (
262287 "The `values` field of `X` must have the same number of columns as "
@@ -331,6 +356,7 @@ def __init__(
331356 self .has_heterogeneous_features = any (
332357 datasets [0 ].feature_names != ds .feature_names for ds in datasets [1 :]
333358 )
359+ self .trial_indices = None
334360
335361 @classmethod
336362 def from_joint_dataset (
@@ -584,6 +610,7 @@ def __init__(
584610 c : [self .feature_names .index (i ) for i in parameter_decomposition [c ]]
585611 for c in self .context_buckets
586612 }
613+ self .trial_indices = None
587614
588615 @property
589616 def X (self ) -> Tensor :
0 commit comments