Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion botorch/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import torch
from botorch.exceptions.errors import InputDataError, UnsupportedError
from botorch.utils.containers import BotorchContainer, SliceContainer
from pyre_extensions import none_throws
from torch import long, ones, Tensor


Expand Down Expand Up @@ -55,6 +56,7 @@ def __init__(
outcome_names: list[str],
Yvar: BotorchContainer | Tensor | None = None,
validate_init: bool = True,
group_indices: Tensor | None = None,
) -> None:
r"""Constructs a `SupervisedDataset`.

Expand All @@ -66,12 +68,17 @@ def __init__(
Yvar: An optional `Tensor` or `BotorchContainer` representing
the observation noise.
validate_init: If `True`, validates the input shapes.
group_indices: A `Tensor` indicating which rows of X and Y are
grouped together. This is used to support applications in which multiple
observations should be considered as a group, e.g., learning-curve-based
modeling. If provided, its shape must be compatible with X and Y.
"""
self._X = X
self._Y = Y
self._Yvar = Yvar
self.feature_names = feature_names
self.outcome_names = outcome_names
self.group_indices = group_indices
self.validate_init = validate_init
if validate_init:
self._validate()
Expand All @@ -98,6 +105,7 @@ def _validate(
self,
validate_feature_names: bool = True,
validate_outcome_names: bool = True,
validate_group_indices: bool = True,
) -> None:
r"""Checks that the shapes of the inputs are compatible with each other.

Expand All @@ -110,6 +118,8 @@ def _validate(
`outcome_names` matches the # of columns of `self.Y`. If a
particular dataset, e.g., `RankingDataset`, is known to violate
this assumption, this can be set to `False`.
validate_group_indices: By default, we validate that the shape of
`group_indices` matches the shape of X and Y.
"""
shape_X = self.X.shape
if isinstance(self._X, BotorchContainer):
Expand All @@ -135,8 +145,20 @@ def _validate(
"`Y` must have the same number of columns as the number of "
"outcomes in `outcome_names`."
)
if validate_group_indices and self.group_indices is not None:
if self.group_indices.shape != shape_X:
raise ValueError(
f"shape_X ({shape_X}) must have the same shape as "
f"group_indices ({none_throws(self.group_indices).shape})."
)

def __eq__(self, other: Any) -> bool:
if self.group_indices is None and other.group_indices is None:
group_indices_equal = True
elif self.group_indices is None or other.group_indices is None:
group_indices_equal = False
else:
group_indices_equal = torch.equal(self.group_indices, other.group_indices)
return (
type(other) is type(self)
and torch.equal(self.X, other.X)
Expand All @@ -148,6 +170,7 @@ def __eq__(self, other: Any) -> bool:
)
and self.feature_names == other.feature_names
and self.outcome_names == other.outcome_names
and group_indices_equal
)

def clone(
Expand Down Expand Up @@ -256,7 +279,11 @@ def __init__(
)

def _validate(self) -> None:
super()._validate(validate_feature_names=False, validate_outcome_names=False)
super()._validate(
validate_feature_names=False,
validate_outcome_names=False,
validate_group_indices=False,
)
if len(self.feature_names) != self._X.values.shape[-1]:
raise ValueError(
"The `values` field of `X` must have the same number of columns as "
Expand Down Expand Up @@ -331,6 +358,7 @@ def __init__(
self.has_heterogeneous_features = any(
datasets[0].feature_names != ds.feature_names for ds in datasets[1:]
)
self.group_indices = None

@classmethod
def from_joint_dataset(
Expand Down Expand Up @@ -584,6 +612,7 @@ def __init__(
c: [self.feature_names.index(i) for i in parameter_decomposition[c]]
for c in self.context_buckets
}
self.group_indices = None

@property
def X(self) -> Tensor:
Expand Down
22 changes: 19 additions & 3 deletions test/utils/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,27 +98,35 @@ def make_contextual_dataset(
class TestDatasets(BotorchTestCase):
def test_supervised(self):
# Generate some data
X = rand(3, 2)
Y = rand(3, 1)
n_rows = 3
X = rand(n_rows, 2)
Y = rand(n_rows, 1)
feature_names = ["x1", "x2"]
outcome_names = ["y"]
group_indices = tensor(range(n_rows))

# Test `__init__`
dataset = SupervisedDataset(
X=X, Y=Y, feature_names=feature_names, outcome_names=outcome_names
X=X,
Y=Y,
feature_names=feature_names,
outcome_names=outcome_names,
group_indices=group_indices,
)
self.assertIsInstance(dataset.X, Tensor)
self.assertIsInstance(dataset._X, Tensor)
self.assertIsInstance(dataset.Y, Tensor)
self.assertIsInstance(dataset._Y, Tensor)
self.assertEqual(dataset.feature_names, feature_names)
self.assertEqual(dataset.outcome_names, outcome_names)
self.assertTrue(torch.equal(dataset.group_indices, group_indices))

dataset2 = SupervisedDataset(
X=DenseContainer(X, X.shape[-1:]),
Y=DenseContainer(Y, Y.shape[-1:]),
feature_names=feature_names,
outcome_names=outcome_names,
group_indices=group_indices,
)
self.assertIsInstance(dataset2.X, Tensor)
self.assertIsInstance(dataset2._X, DenseContainer)
Expand Down Expand Up @@ -156,6 +164,14 @@ def test_supervised(self):
feature_names=feature_names,
outcome_names=[],
)
with self.assertRaisesRegex(ValueError, "group_indices"):
SupervisedDataset(
X=rand(2, 2),
Y=rand(2, 1),
feature_names=feature_names,
outcome_names=outcome_names,
group_indices=tensor(range(n_rows + 1)),
)

# Test with Yvar.
dataset = SupervisedDataset(
Expand Down