Skip to content

Commit 77f8bdc

Browse files
introduce trial_indices argument to SupervisedDataset (#2595)
Summary: X-link: facebook/Ax#2960 Adds optional `trial_indices` to SupervisedDataset, whose dimensionality should correspond 1:1 with the first few dimensions of X and Y tensors, as validated in `_validate` ([pointer](https://www.internalfb.com/diff/D64764019?permalink=1739375523489084)). Reviewed By: Balandat Differential Revision: D64764019
1 parent 3f2e2c7 commit 77f8bdc

File tree

2 files changed

+48
-5
lines changed

2 files changed

+48
-5
lines changed

botorch/utils/datasets.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import torch
1616
from botorch.exceptions.errors import InputDataError, UnsupportedError
1717
from botorch.utils.containers import BotorchContainer, SliceContainer
18+
from pyre_extensions import none_throws
1819
from torch import long, ones, Tensor
1920

2021

@@ -55,6 +56,7 @@ def __init__(
5556
outcome_names: list[str],
5657
Yvar: BotorchContainer | Tensor | None = None,
5758
validate_init: bool = True,
59+
trial_indices: Tensor | None = None,
5860
) -> None:
5961
r"""Constructs a `SupervisedDataset`.
6062
@@ -66,13 +68,16 @@ def __init__(
6668
Yvar: An optional `Tensor` or `BotorchContainer` representing
6769
the observation noise.
6870
validate_init: If `True`, validates the input shapes.
71+
trial_indices: A `Tensor` representing the trial indices of X and Y. This is
72+
used to support learning-curve-based modeling. If provided, it must
73+
have compatible shape with X and Y.
6974
"""
7075
self._X = X
7176
self._Y = Y
7277
self._Yvar = Yvar
7378
self.feature_names = feature_names
7479
self.outcome_names = outcome_names
75-
self.validate_init = validate_init
80+
self.trial_indices = trial_indices
7681
if validate_init:
7782
self._validate()
7883

@@ -98,6 +103,7 @@ def _validate(
98103
self,
99104
validate_feature_names: bool = True,
100105
validate_outcome_names: bool = True,
106+
validate_trial_indices: bool = True,
101107
) -> None:
102108
r"""Checks that the shapes of the inputs are compatible with each other.
103109
@@ -110,6 +116,8 @@ def _validate(
110116
`outcomes_names` matches the # of columns of `self.Y`. If a
111117
particular dataset, e.g., `RankingDataset`, is known to violate
112118
this assumption, this can be set to `False`.
119+
validate_trial_indices: By default, we validate that the shape of
120+
`trial_indices` matches the shape of X and Y.
113121
"""
114122
shape_X = self.X.shape
115123
if isinstance(self._X, BotorchContainer):
@@ -135,8 +143,20 @@ def _validate(
135143
"`Y` must have the same number of columns as the number of "
136144
"outcomes in `outcome_names`."
137145
)
146+
if validate_trial_indices and self.trial_indices is not None:
147+
if self.trial_indices.shape != shape_X:
148+
raise ValueError(
149+
f"shape_X ({shape_X}) must have the same shape as "
150+
f"trial_indices ({none_throws(self.trial_indices).shape})."
151+
)
138152

139153
def __eq__(self, other: Any) -> bool:
154+
if self.trial_indices is None and other.trial_indices is None:
155+
trial_indices_equal = True
156+
elif self.trial_indices is None or other.trial_indices is None:
157+
trial_indices_equal = False
158+
else:
159+
trial_indices_equal = torch.equal(self.trial_indices, other.trial_indices)
140160
return (
141161
type(other) is type(self)
142162
and torch.equal(self.X, other.X)
@@ -148,6 +168,7 @@ def __eq__(self, other: Any) -> bool:
148168
)
149169
and self.feature_names == other.feature_names
150170
and self.outcome_names == other.outcome_names
171+
and trial_indices_equal
151172
)
152173

153174
def clone(
@@ -256,7 +277,11 @@ def __init__(
256277
)
257278

258279
def _validate(self) -> None:
259-
super()._validate(validate_feature_names=False, validate_outcome_names=False)
280+
super()._validate(
281+
validate_feature_names=False,
282+
validate_outcome_names=False,
283+
validate_trial_indices=False,
284+
)
260285
if len(self.feature_names) != self._X.values.shape[-1]:
261286
raise ValueError(
262287
"The `values` field of `X` must have the same number of columns as "
@@ -331,6 +356,7 @@ def __init__(
331356
self.has_heterogeneous_features = any(
332357
datasets[0].feature_names != ds.feature_names for ds in datasets[1:]
333358
)
359+
self.trial_indices = None
334360

335361
@classmethod
336362
def from_joint_dataset(
@@ -584,6 +610,7 @@ def __init__(
584610
c: [self.feature_names.index(i) for i in parameter_decomposition[c]]
585611
for c in self.context_buckets
586612
}
613+
self.trial_indices = None
587614

588615
@property
589616
def X(self) -> Tensor:

test/utils/test_datasets.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,27 +98,35 @@ def make_contextual_dataset(
9898
class TestDatasets(BotorchTestCase):
9999
def test_supervised(self):
100100
# Generate some data
101-
X = rand(3, 2)
102-
Y = rand(3, 1)
101+
n_rows = 3
102+
X = rand(n_rows, 2)
103+
Y = rand(n_rows, 1)
103104
feature_names = ["x1", "x2"]
104105
outcome_names = ["y"]
106+
trial_indices = tensor(range(n_rows))
105107

106108
# Test `__init__`
107109
dataset = SupervisedDataset(
108-
X=X, Y=Y, feature_names=feature_names, outcome_names=outcome_names
110+
X=X,
111+
Y=Y,
112+
feature_names=feature_names,
113+
outcome_names=outcome_names,
114+
trial_indices=trial_indices,
109115
)
110116
self.assertIsInstance(dataset.X, Tensor)
111117
self.assertIsInstance(dataset._X, Tensor)
112118
self.assertIsInstance(dataset.Y, Tensor)
113119
self.assertIsInstance(dataset._Y, Tensor)
114120
self.assertEqual(dataset.feature_names, feature_names)
115121
self.assertEqual(dataset.outcome_names, outcome_names)
122+
self.assertTrue(torch.equal(dataset.trial_indices, trial_indices))
116123

117124
dataset2 = SupervisedDataset(
118125
X=DenseContainer(X, X.shape[-1:]),
119126
Y=DenseContainer(Y, Y.shape[-1:]),
120127
feature_names=feature_names,
121128
outcome_names=outcome_names,
129+
trial_indices=trial_indices,
122130
)
123131
self.assertIsInstance(dataset2.X, Tensor)
124132
self.assertIsInstance(dataset2._X, DenseContainer)
@@ -156,6 +164,14 @@ def test_supervised(self):
156164
feature_names=feature_names,
157165
outcome_names=[],
158166
)
167+
with self.assertRaisesRegex(ValueError, "trial_indices"):
168+
SupervisedDataset(
169+
X=rand(2, 2),
170+
Y=rand(2, 1),
171+
feature_names=feature_names,
172+
outcome_names=outcome_names,
173+
trial_indices=tensor(range(n_rows + 1)),
174+
)
159175

160176
# Test with Yvar.
161177
dataset = SupervisedDataset(

0 commit comments

Comments
 (0)