Skip to content

Commit 2b4054d

Browse files
introduce trial_indices argument to SupervisedDataset (#2595)
Summary: X-link: facebook/Ax#2960. Adds an optional `trial_indices` argument to `SupervisedDataset`; its shape should correspond 1:1 with the leading dimensions of the X and Y tensors, as validated in `_validate` ([pointer](https://www.internalfb.com/diff/D64764019?permalink=1739375523489084)). Reviewed By: Balandat. Differential Revision: D64764019.
1 parent 9d37e90 commit 2b4054d

File tree

2 files changed

+48
-4
lines changed

2 files changed

+48
-4
lines changed

botorch/utils/datasets.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import torch
1515
from botorch.exceptions.errors import InputDataError, UnsupportedError
1616
from botorch.utils.containers import BotorchContainer, SliceContainer
17+
from pyre_extensions import none_throws
1718
from torch import long, ones, Tensor
1819

1920

@@ -54,6 +55,7 @@ def __init__(
5455
outcome_names: list[str],
5556
Yvar: BotorchContainer | Tensor | None = None,
5657
validate_init: bool = True,
58+
trial_indices: Tensor | None = None,
5759
) -> None:
5860
r"""Constructs a `SupervisedDataset`.
5961
@@ -65,12 +67,16 @@ def __init__(
6567
Yvar: An optional `Tensor` or `BotorchContainer` representing
6668
the observation noise.
6769
validate_init: If `True`, validates the input shapes.
70+
trial_indices: A `Tensor` representing the trial indices of X and Y. This is
71+
used to support learning-curve-based modeling. If provided, it must
72+
have compatible shape with X and Y.
6873
"""
6974
self._X = X
7075
self._Y = Y
7176
self._Yvar = Yvar
7277
self.feature_names = feature_names
7378
self.outcome_names = outcome_names
79+
self.trial_indices = trial_indices
7480
if validate_init:
7581
self._validate()
7682

@@ -96,6 +102,7 @@ def _validate(
96102
self,
97103
validate_feature_names: bool = True,
98104
validate_outcome_names: bool = True,
105+
validate_trial_indices: bool = True,
99106
) -> None:
100107
r"""Checks that the shapes of the inputs are compatible with each other.
101108
@@ -108,6 +115,8 @@ def _validate(
108115
`outcomes_names` matches the # of columns of `self.Y`. If a
109116
particular dataset, e.g., `RankingDataset`, is known to violate
110117
this assumption, this can be set to `False`.
118+
validate_trial_indices: By default, we validate that the shape of
119+
`trial_indices` matches the shape of X and Y.
111120
"""
112121
shape_X = self.X.shape
113122
if isinstance(self._X, BotorchContainer):
@@ -133,8 +142,20 @@ def _validate(
133142
"`Y` must have the same number of columns as the number of "
134143
"outcomes in `outcome_names`."
135144
)
145+
if validate_trial_indices and self.trial_indices is not None:
146+
if self.trial_indices.shape != shape_X:
147+
raise ValueError(
148+
f"shape_X ({shape_X}) must have the same shape as "
149+
f"trial_indices ({none_throws(self.trial_indices).shape})."
150+
)
136151

137152
def __eq__(self, other: Any) -> bool:
153+
if self.trial_indices is None and other.trial_indices is None:
154+
trial_indices_equal = True
155+
elif self.trial_indices is None or other.trial_indices is None:
156+
trial_indices_equal = False
157+
else:
158+
trial_indices_equal = torch.equal(self.trial_indices, other.trial_indices)
138159
return (
139160
type(other) is type(self)
140161
and torch.equal(self.X, other.X)
@@ -146,6 +167,7 @@ def __eq__(self, other: Any) -> bool:
146167
)
147168
and self.feature_names == other.feature_names
148169
and self.outcome_names == other.outcome_names
170+
and trial_indices_equal
149171
)
150172

151173

@@ -241,7 +263,11 @@ def __init__(
241263
)
242264

243265
def _validate(self) -> None:
244-
super()._validate(validate_feature_names=False, validate_outcome_names=False)
266+
super()._validate(
267+
validate_feature_names=False,
268+
validate_outcome_names=False,
269+
validate_trial_indices=False,
270+
)
245271
if len(self.feature_names) != self._X.values.shape[-1]:
246272
raise ValueError(
247273
"The `values` field of `X` must have the same number of columns as "
@@ -316,6 +342,7 @@ def __init__(
316342
self.has_heterogeneous_features = any(
317343
datasets[0].feature_names != ds.feature_names for ds in datasets[1:]
318344
)
345+
self.trial_indices = None
319346

320347
@classmethod
321348
def from_joint_dataset(
@@ -538,6 +565,7 @@ def __init__(
538565
c: [self.feature_names.index(i) for i in parameter_decomposition[c]]
539566
for c in self.context_buckets
540567
}
568+
self.trial_indices = None
541569

542570
@property
543571
def X(self) -> Tensor:

test/utils/test_datasets.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,27 +43,35 @@ def make_dataset(
4343
class TestDatasets(BotorchTestCase):
4444
def test_supervised(self):
4545
# Generate some data
46-
X = rand(3, 2)
47-
Y = rand(3, 1)
46+
n_rows = 3
47+
X = rand(n_rows, 2)
48+
Y = rand(n_rows, 1)
4849
feature_names = ["x1", "x2"]
4950
outcome_names = ["y"]
51+
trial_indices = tensor(range(n_rows))
5052

5153
# Test `__init__`
5254
dataset = SupervisedDataset(
53-
X=X, Y=Y, feature_names=feature_names, outcome_names=outcome_names
55+
X=X,
56+
Y=Y,
57+
feature_names=feature_names,
58+
outcome_names=outcome_names,
59+
trial_indices=trial_indices,
5460
)
5561
self.assertIsInstance(dataset.X, Tensor)
5662
self.assertIsInstance(dataset._X, Tensor)
5763
self.assertIsInstance(dataset.Y, Tensor)
5864
self.assertIsInstance(dataset._Y, Tensor)
5965
self.assertEqual(dataset.feature_names, feature_names)
6066
self.assertEqual(dataset.outcome_names, outcome_names)
67+
self.assertTrue(torch.equal(dataset.trial_indices, trial_indices))
6168

6269
dataset2 = SupervisedDataset(
6370
X=DenseContainer(X, X.shape[-1:]),
6471
Y=DenseContainer(Y, Y.shape[-1:]),
6572
feature_names=feature_names,
6673
outcome_names=outcome_names,
74+
trial_indices=trial_indices,
6775
)
6876
self.assertIsInstance(dataset2.X, Tensor)
6977
self.assertIsInstance(dataset2._X, DenseContainer)
@@ -101,6 +109,14 @@ def test_supervised(self):
101109
feature_names=feature_names,
102110
outcome_names=[],
103111
)
112+
with self.assertRaisesRegex(ValueError, "trial_indices"):
113+
SupervisedDataset(
114+
X=rand(2, 2),
115+
Y=rand(2, 1),
116+
feature_names=feature_names,
117+
outcome_names=outcome_names,
118+
trial_indices=tensor(range(n_rows + 1)),
119+
)
104120

105121
# Test with Yvar.
106122
dataset = SupervisedDataset(

0 commit comments

Comments
 (0)