Skip to content

Commit de46059

Browse files
sdaulton authored and facebook-github-bot committed
add clone method to datasets (#2625)
Summary: Pull Request resolved: #2625. This makes it far easier to obtain slices of different kinds of datasets (Supervised, MultiTask, Contextual), which will be helpful for things like doing LOOCV MBM in Ax. Reviewed By: saitcakmak. Differential Revision: D65616941. fbshipit-source-id: 2f121b6b950ee9b378a1cecc7f5693163264b743
1 parent 3c2ce15 commit de46059

File tree

4 files changed

+352
-57
lines changed

4 files changed

+352
-57
lines changed

botorch/utils/containers.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@
88

99
from __future__ import annotations
1010

11+
import dataclasses
12+
1113
from abc import ABC, abstractmethod
1214
from dataclasses import dataclass, fields
1315
from typing import Any
1416

17+
import torch
18+
1519
from torch import device as Device, dtype as Dtype, LongTensor, Size, Tensor
1620

1721

@@ -102,6 +106,9 @@ def _validate(self) -> None:
102106
f"`event shape` {self.event_shape}."
103107
)
104108

109+
def clone(self) -> DenseContainer:
    r"""Return an independent copy of this container.

    The underlying tensor is cloned so that in-place changes to the copy do
    not leak into the original. This mirrors `SliceContainer.clone`, which
    clones its tensors as well. All remaining fields are carried over
    unchanged by `dataclasses.replace`.

    Returns:
        A new container holding a clone of `self.values`.
    """
    # NOTE(review): assumes `values` is the tensor field of this dataclass
    # (as it is for `SliceContainer`) -- confirm against the class definition.
    return dataclasses.replace(self, values=self.values.clone())
111+
105112

106113
@dataclass(eq=False)
107114
class SliceContainer(BotorchContainer):
@@ -149,3 +156,10 @@ def _validate(self) -> None:
149156
f"Shapes of `values` {values.shape} and `indices` "
150157
f"{indices.shape} incompatible with `event_shape` {event_shape}."
151158
)
159+
160+
def clone(self) -> SliceContainer:
    r"""Return a copy of this container backed by freshly cloned tensors.

    Returns:
        A new container of the same type as `self`, built from clones of
        `values` and `indices` and a new `torch.Size` created from
        `event_shape`.
    """
    cloned_fields = {
        "values": self.values.clone(),
        "indices": self.indices.clone(),
        "event_shape": torch.Size(self.event_shape),
    }
    container_cls = type(self)
    return container_cls(**cloned_fields)

botorch/utils/datasets.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
from __future__ import annotations
1010

11+
import copy
12+
1113
from typing import Any
1214

1315
import torch
@@ -70,6 +72,7 @@ def __init__(
7072
self._Yvar = Yvar
7173
self.feature_names = feature_names
7274
self.outcome_names = outcome_names
75+
self.validate_init = validate_init
7376
if validate_init:
7477
self._validate()
7578

@@ -147,6 +150,52 @@ def __eq__(self, other: Any) -> bool:
147150
and self.outcome_names == other.outcome_names
148151
)
149152

153+
def clone(
    self, deepcopy: bool = False, mask: Tensor | None = None
) -> SupervisedDataset:
    """Build a new dataset from this one, optionally subsetting and/or copying.

    Args:
        deepcopy: If True, `X`, `Y` and `Yvar` (when present) are cloned and
            the feature/outcome name lists are copied. If False, the new
            dataset is constructed from the same objects.
        mask: A `n`-dim boolean mask indicating which rows to keep. It is
            applied along the -2 dimension of `X`, `Y` and `Yvar`.

    Returns:
        A new dataset of the same type as `self`.

    Raises:
        NotImplementedError: If `mask` is given and any of `X`, `Y` or `Yvar`
            is a `BotorchContainer`.
    """
    X, Y, Yvar = self._X, self._Y, self._Yvar
    if mask is not None:
        for data in (X, Y, Yvar):
            if isinstance(data, BotorchContainer):
                raise NotImplementedError(
                    "Masking is not supported for BotorchContainers."
                )
        X = X[..., mask, :]
        Y = Y[..., mask, :]
        Yvar = None if Yvar is None else Yvar[..., mask, :]
    if deepcopy:
        X = X.clone()
        Y = Y.clone()
        if Yvar is not None:
            Yvar = Yvar.clone()
        feature_names = copy.copy(self.feature_names)
        outcome_names = copy.copy(self.outcome_names)
    else:
        feature_names = self.feature_names
        outcome_names = self.outcome_names
    # Only forward `Yvar` to the constructor when it is actually set.
    optional_kwargs = {} if Yvar is None else {"Yvar": Yvar}
    return type(self)(
        X=X,
        Y=Y,
        feature_names=feature_names,
        outcome_names=outcome_names,
        validate_init=self.validate_init,
        **optional_kwargs,
    )
198+
150199

151200
class RankingDataset(SupervisedDataset):
152201
r"""A SupervisedDataset whose labelled pairs `(x, y)` consist of m-ary combinations
@@ -339,7 +388,7 @@ def from_joint_dataset(
339388
outcome_names=[outcome_name],
340389
)
341390
datasets.append(new_dataset)
342-
# Return the new
391+
# Return the new dataset
343392
return cls(
344393
datasets=datasets,
345394
target_outcome_name=outcome_names_per_task.get(
@@ -466,6 +515,37 @@ def __eq__(self, other: Any) -> bool:
466515
and self.task_feature_index == other.task_feature_index
467516
)
468517

518+
def clone(
    self, deepcopy: bool = False, mask: Tensor | None = None
) -> MultiTaskDataset:
    """Build a new multi-task dataset from this one.

    Args:
        deepcopy: If True, every member dataset is cloned with
            `deepcopy=True`. Otherwise the same tensors/lists/datasets are
            reused.
        mask: A `n`-dim boolean mask indicating which rows to keep from the
            target dataset. This is used along the -2 dimension. Only the
            dataset stored under `target_outcome_name` receives the mask.

    Returns:
        The new dataset.
    """
    if deepcopy or mask is not None:
        target_name = self.target_outcome_name
        member_datasets = [
            ds.clone(
                deepcopy=deepcopy,
                # The non-target datasets are never masked.
                mask=mask if name == target_name else None,
            )
            for name, ds in self.datasets.items()
        ]
    else:
        member_datasets = list(self.datasets.values())
    return MultiTaskDataset(
        datasets=member_datasets,
        target_outcome_name=self.target_outcome_name,
        task_feature_index=self.task_feature_index,
    )
548+
469549

470550
class ContextualDataset(SupervisedDataset):
471551
"""This is a contextual dataset that is constructed from either a single
@@ -627,3 +707,33 @@ def _validate_decompositions(self) -> None:
627707
raise InputDataError(
628708
f"{outcome} is missing in metric_decomposition."
629709
)
710+
711+
def clone(
    self, deepcopy: bool = False, mask: Tensor | None = None
) -> ContextualDataset:
    """Build a new contextual dataset from this one.

    Args:
        deepcopy: If True, every member dataset is cloned with
            `deepcopy=True` and both decompositions are deep-copied.
            Otherwise the same tensors/lists/datasets are reused.
        mask: A `n`-dim boolean mask indicating which rows to keep. This is
            used along the -2 dimension. `n` here corresponds to the number
            of rows in an individual dataset; the same mask is passed to
            every member dataset.

    Returns:
        The new dataset.
    """
    member_datasets = list(self.datasets.values())
    if deepcopy or mask is not None:
        member_datasets = [
            ds.clone(deepcopy=deepcopy, mask=mask) for ds in member_datasets
        ]
    parameter_decomposition = self.parameter_decomposition
    metric_decomposition = self.metric_decomposition
    if deepcopy:
        parameter_decomposition = copy.deepcopy(parameter_decomposition)
        metric_decomposition = copy.deepcopy(metric_decomposition)
    return ContextualDataset(
        datasets=member_datasets,
        parameter_decomposition=parameter_decomposition,
        metric_decomposition=metric_decomposition,
    )

test/utils/test_containers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ def test_dense(self):
8484
# Test `__call__`
8585
self.assertTrue(X().equal(values))
8686

87+
# Test `clone`
88+
self.assertEqual(X.clone(), X)
89+
8790
def test_slice(self):
8891
for arity in (2, 4):
8992
for vals in (

0 commit comments

Comments
 (0)