Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions botorch/models/transforms/outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def __init__(
OutcomeTransform.__init__(self)
self._stratification_idx = stratification_idx
task_values = task_values.unique(sorted=True)
self.strata_mapping = get_task_value_remapping(task_values, dtype=torch.long)
self.strata_mapping = get_task_value_remapping(task_values, dtype=torch.double)
if self.strata_mapping is None:
self.strata_mapping = task_values
n_strata = self.strata_mapping.shape[0]
Expand Down Expand Up @@ -576,7 +576,7 @@ def forward(
strata = X[..., self._stratification_idx].long()
unique_strata = strata.unique()
for s in unique_strata:
mapped_strata = self.strata_mapping[s]
mapped_strata = self.strata_mapping[s].long()
mask = strata != s
Y_strata = Y.clone()
Y_strata[..., mask, :] = float("nan")
Expand Down Expand Up @@ -616,7 +616,7 @@ def _get_per_input_means_stdvs(
- The per-input stdvs squared.
"""
strata = X[..., self._stratification_idx].long()
mapped_strata = self.strata_mapping[strata].unsqueeze(-1)
mapped_strata = self.strata_mapping[strata].unsqueeze(-1).long()
# get means and stdvs for each strata
n_extra_batch_dims = mapped_strata.ndim - 2 - len(self._batch_shape)
expand_shape = mapped_strata.shape[:n_extra_batch_dims] + self.means.shape
Expand Down
2 changes: 2 additions & 0 deletions botorch/models/utils/assorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor
return value will be `None`, when the task values are contiguous
integers starting from zero.
"""
if dtype not in (torch.float, torch.double):
raise ValueError(f"dtype must be torch.float or torch.double, but got {dtype}.")
task_range = torch.arange(
len(task_values), dtype=task_values.dtype, device=task_values.device
)
Expand Down
9 changes: 9 additions & 0 deletions test/models/test_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,3 +700,12 @@ def test_get_task_value_remapping(self) -> None:
mapping = get_task_value_remapping(task_values, dtype)
self.assertTrue(torch.equal(mapping[[1, 3]], expected_mapping_no_nan))
self.assertTrue(torch.isnan(mapping[[0, 2]]).all())

def test_get_task_value_remapping_invalid_dtype(self) -> None:
task_values = torch.tensor([1, 3])
for dtype in (torch.int32, torch.long, torch.bool):
with self.assertRaisesRegex(
ValueError,
f"dtype must be torch.float or torch.double, but got {dtype}.",
):
get_task_value_remapping(task_values, dtype)
22 changes: 16 additions & 6 deletions test/models/transforms/test_outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,16 +372,24 @@ def test_stratified_standardize(self):
n = 5
seed = randint(0, 100)
torch.manual_seed(seed)
for dtype, batch_shape in itertools.product(
(torch.float, torch.double), (torch.Size([]), torch.Size([3]))
for dtype, batch_shape, task_values in itertools.product(
(torch.float, torch.double),
(torch.Size([]), torch.Size([3])),
(
torch.tensor([0, 1], dtype=torch.long, device=self.device),
torch.tensor([0, 3], dtype=torch.long, device=self.device),
),
):
torch.manual_seed(seed)
tval = task_values[1].item()
X = torch.rand(*batch_shape, n, 2, dtype=dtype, device=self.device)
X[..., -1] = torch.tensor([0, 1, 0, 1, 0], dtype=dtype, device=self.device)
X[..., -1] = torch.tensor(
[0, tval, 0, tval, 0], dtype=dtype, device=self.device
)
Y = torch.randn(*batch_shape, n, 1, dtype=dtype, device=self.device)
Yvar = torch.rand(*batch_shape, n, 1, dtype=dtype, device=self.device)
strata_tf = StratifiedStandardize(
task_values=torch.tensor([0, 1], dtype=torch.long, device=self.device),
task_values=task_values,
stratification_idx=-1,
batch_shape=batch_shape,
)
Expand All @@ -400,9 +408,11 @@ def test_stratified_standardize(self):
tf_Y1, tf_Yvar1 = tf1(Y=Y1, Yvar=Yvar1, X=X1)
# check that stratified means are expected
self.assertAllClose(strata_tf.means[..., :1, :], tf0.means)
self.assertAllClose(strata_tf.means[..., 1:, :], tf1.means)
# use remapped task values to index
self.assertAllClose(strata_tf.means[..., 1:2, :], tf1.means)
self.assertAllClose(strata_tf.stdvs[..., :1, :], tf0.stdvs)
self.assertAllClose(strata_tf.stdvs[..., 1:, :], tf1.stdvs)
# use remapped task values to index
self.assertAllClose(strata_tf.stdvs[..., 1:2, :], tf1.stdvs)
# check the transformed values
self.assertAllClose(tf_Y0, tf_Y[mask0].view(*batch_shape, -1, 1))
self.assertAllClose(tf_Y1, tf_Y[mask1].view(*batch_shape, -1, 1))
Expand Down