8 | 8 |
9 | 9 | from __future__ import annotations |
10 | 10 |
| 11 | +import copy |
| 12 | + |
11 | 13 | from typing import Any |
12 | 14 |
13 | 15 | import torch |
@@ -70,6 +72,7 @@ def __init__( |
70 | 72 | self._Yvar = Yvar |
71 | 73 | self.feature_names = feature_names |
72 | 74 | self.outcome_names = outcome_names |
| 75 | + self.validate_init = validate_init |
73 | 76 | if validate_init: |
74 | 77 | self._validate() |
75 | 78 |
@@ -147,6 +150,52 @@ def __eq__(self, other: Any) -> bool: |
147 | 150 | and self.outcome_names == other.outcome_names |
148 | 151 | ) |
149 | 152 |
| 153 | + def clone( |
| 154 | + self, deepcopy: bool = False, mask: Tensor | None = None |
| 155 | + ) -> SupervisedDataset: |
| 156 | + """Return a copy of the dataset. |
| 157 | +
| 158 | + Args: |
| 159 | + deepcopy: If True, perform a deep copy. Otherwise, use the same |
| 160 | + tensors/lists. |
| 161 | + mask: A `n`-dim boolean mask indicating which rows to keep. This is used |
| 162 | + along the -2 dimension. |
| 163 | +
| 164 | + Returns: |
| 165 | + The new dataset. |
| 166 | + """ |
| 167 | + new_X = self._X |
| 168 | + new_Y = self._Y |
| 169 | + new_Yvar = self._Yvar |
| 170 | + feature_names = self.feature_names |
| 171 | + outcome_names = self.outcome_names |
| 172 | + if mask is not None: |
| 173 | + if any(isinstance(x, BotorchContainer) for x in [new_X, new_Y, new_Yvar]): |
| 174 | + raise NotImplementedError( |
| 175 | + "Masking is not supported for BotorchContainers." |
| 176 | + ) |
| 177 | + new_X = new_X[..., mask, :] |
| 178 | + new_Y = new_Y[..., mask, :] |
| 179 | + if new_Yvar is not None: |
| 180 | + new_Yvar = new_Yvar[..., mask, :] |
| 181 | + if deepcopy: |
| 182 | + new_X = new_X.clone() |
| 183 | + new_Y = new_Y.clone() |
| 184 | + new_Yvar = new_Yvar.clone() if new_Yvar is not None else None |
| 185 | + feature_names = copy.copy(self.feature_names) |
| 186 | + outcome_names = copy.copy(self.outcome_names) |
| 187 | + kwargs = {} |
| 188 | + if new_Yvar is not None: |
| 189 | + kwargs = {"Yvar": new_Yvar} |
| 190 | + return type(self)( |
| 191 | + X=new_X, |
| 192 | + Y=new_Y, |
| 193 | + feature_names=feature_names, |
| 194 | + outcome_names=outcome_names, |
| 195 | + validate_init=self.validate_init, |
| 196 | + **kwargs, |
| 197 | + ) |
| 198 | + |
150 | 199 |
151 | 200 | class RankingDataset(SupervisedDataset): |
152 | 201 | r"""A SupervisedDataset whose labelled pairs `(x, y)` consist of m-ary combinations |
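For illustration, here is a minimal usage sketch of the `SupervisedDataset.clone` method added above. The tensors, feature names, and outcome names are made up for the example; only the `clone` calls exercise the new code path.

import torch
from botorch.utils.datasets import SupervisedDataset

# Toy dataset: 5 points in 2 dimensions with a single outcome.
dataset = SupervisedDataset(
    X=torch.rand(5, 2),
    Y=torch.rand(5, 1),
    feature_names=["x1", "x2"],
    outcome_names=["y"],
)

# Shallow clone: the copy references the same underlying tensors.
shallow = dataset.clone()
assert shallow.X is dataset.X

# Deep clone with a row mask: keeps only the rows where the mask is True,
# backed by freshly copied tensors.
mask = torch.tensor([True, True, False, True, False])
subset = dataset.clone(deepcopy=True, mask=mask)
assert subset.X.shape == (3, 2)
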
@@ -339,7 +388,7 @@ def from_joint_dataset( |
339 | 388 | outcome_names=[outcome_name], |
340 | 389 | ) |
341 | 390 | datasets.append(new_dataset) |
342 | | - # Return the new |
| 391 | + # Return the new dataset |
343 | 392 | return cls( |
344 | 393 | datasets=datasets, |
345 | 394 | target_outcome_name=outcome_names_per_task.get( |
@@ -466,6 +515,37 @@ def __eq__(self, other: Any) -> bool: |
466 | 515 | and self.task_feature_index == other.task_feature_index |
467 | 516 | ) |
468 | 517 |
| 518 | + def clone( |
| 519 | + self, deepcopy: bool = False, mask: Tensor | None = None |
| 520 | + ) -> MultiTaskDataset: |
| 521 | + """Return a copy of the dataset. |
| 522 | +
| 523 | + Args: |
| 524 | + deepcopy: If True, perform a deep copy. Otherwise, use the same |
| 525 | + tensors/lists/datasets. |
| 526 | + mask: A `n`-dim boolean mask indicating which rows to keep from the target |
| 527 | + dataset. This is used along the -2 dimension. |
| 528 | +
| 529 | + Returns: |
| 530 | + The new dataset. |
| 531 | + """ |
| 532 | + datasets = list(self.datasets.values()) |
| 533 | + if mask is not None or deepcopy: |
| 534 | + new_datasets = [] |
| 535 | + for outcome, ds in self.datasets.items(): |
| 536 | + new_datasets.append( |
| 537 | + ds.clone( |
| 538 | + deepcopy=deepcopy, |
| 539 | + mask=mask if outcome == self.target_outcome_name else None, |
| 540 | + ) |
| 541 | + ) |
| 542 | + datasets = new_datasets |
| 543 | + return MultiTaskDataset( |
| 544 | + datasets=datasets, |
| 545 | + target_outcome_name=self.target_outcome_name, |
| 546 | + task_feature_index=self.task_feature_index, |
| 547 | + ) |
| 548 | + |
469 | 549 |
470 | 550 | class ContextualDataset(SupervisedDataset): |
471 | 551 | """This is a contextual dataset that is constructed from either a single |
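As a rough sketch of the `MultiTaskDataset.clone` semantics above: the mask is applied only to the target dataset, while the other task datasets are carried over unchanged (or deep-copied when `deepcopy=True`). The two single-outcome datasets below are invented for the example.

import torch
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset

feature_names = ["x1", "x2"]
target = SupervisedDataset(
    X=torch.rand(5, 2),
    Y=torch.rand(5, 1),
    feature_names=feature_names,
    outcome_names=["target"],
)
source = SupervisedDataset(
    X=torch.rand(8, 2),
    Y=torch.rand(8, 1),
    feature_names=feature_names,
    outcome_names=["source"],
)
mtd = MultiTaskDataset(datasets=[target, source], target_outcome_name="target")

# The mask is sized for the target dataset; the source dataset keeps all 8 rows.
mask = torch.tensor([True, False, True, True, False])
cloned = mtd.clone(deepcopy=True, mask=mask)
assert cloned.datasets["target"].X.shape == (3, 2)
assert cloned.datasets["source"].X.shape == (8, 2)
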
@@ -627,3 +707,33 @@ def _validate_decompositions(self) -> None: |
627 | 707 | raise InputDataError( |
628 | 708 | f"{outcome} is missing in metric_decomposition." |
629 | 709 | ) |
| 710 | + |
| 711 | + def clone( |
| 712 | + self, deepcopy: bool = False, mask: Tensor | None = None |
| 713 | + ) -> ContextualDataset: |
| 714 | + """Return a copy of the dataset. |
| 715 | +
| 716 | + Args: |
| 717 | + deepcopy: If True, perform a deep copy. Otherwise, use the same |
| 718 | + tensors/lists/datasets. |
| 719 | + mask: A `n`-dim boolean mask indicating which rows to keep. This is used |
| 720 | + along the -2 dimension. `n` here corresponds to the number of rows in |
| 721 | + an individual dataset. |
| 722 | +
| 723 | + Returns: |
| 724 | + The new dataset. |
| 725 | + """ |
| 726 | + datasets = list(self.datasets.values()) |
| 727 | + if mask is not None or deepcopy: |
| 728 | + datasets = [ds.clone(deepcopy=deepcopy, mask=mask) for ds in datasets] |
| 729 | + if deepcopy: |
| 730 | + parameter_decomposition = copy.deepcopy(self.parameter_decomposition) |
| 731 | + metric_decomposition = copy.deepcopy(self.metric_decomposition) |
| 732 | + else: |
| 733 | + parameter_decomposition = self.parameter_decomposition |
| 734 | + metric_decomposition = self.metric_decomposition |
| 735 | + return ContextualDataset( |
| 736 | + datasets=datasets, |
| 737 | + parameter_decomposition=parameter_decomposition, |
| 738 | + metric_decomposition=metric_decomposition, |
| 739 | + ) |
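
Finally, a rough sketch of `ContextualDataset.clone`, assuming the simplest construction path (a single dataset holding the overall outcome plus an illustrative parameter decomposition; all names are made up). Shallow clones share the decomposition dicts with the original, while deep clones copy them, as the code above shows.

import torch
from botorch.utils.datasets import ContextualDataset, SupervisedDataset

overall = SupervisedDataset(
    X=torch.rand(6, 4),
    Y=torch.rand(6, 1),
    feature_names=["x1_c1", "x2_c1", "x1_c2", "x2_c2"],
    outcome_names=["overall"],
)
cds = ContextualDataset(
    datasets=[overall],
    parameter_decomposition={
        "context_1": ["x1_c1", "x2_c1"],
        "context_2": ["x1_c2", "x2_c2"],
    },
)

# Shallow clone shares the decomposition dict; deep clone copies it.
shallow = cds.clone()
assert shallow.parameter_decomposition is cds.parameter_decomposition

deep = cds.clone(deepcopy=True, mask=torch.tensor([True] * 5 + [False]))
assert deep.parameter_decomposition is not cds.parameter_decomposition
assert deep.X.shape == (5, 4)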