Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions botorch/acquisition/multi_objective/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from botorch.acquisition.multi_objective.objective import (
IdentityMCMultiOutputObjective,
MCMultiOutputObjective,
UnstandardizeMCMultiOutputObjective,
WeightedMCMultiOutputObjective,
)
from botorch.acquisition.multi_objective.utils import (
Expand All @@ -47,6 +46,5 @@
"MCMultiOutputObjective",
"MultiObjectiveAnalyticAcquisitionFunction",
"MultiObjectiveMCAcquisitionFunction",
"UnstandardizeMCMultiOutputObjective",
"WeightedMCMultiOutputObjective",
]
46 changes: 0 additions & 46 deletions botorch/acquisition/multi_objective/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,49 +207,3 @@ def apply_feasibility_weights(

def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
return self.objective(self.apply_feasibility_weights(samples), X=X)


class UnstandardizeMCMultiOutputObjective(IdentityMCMultiOutputObjective):
r"""Objective that unstandardizes the samples.

TODO: remove this when MultiTask models support outcome transforms.

Example:
>>> unstd_objective = UnstandardizeMCMultiOutputObjective(Y_mean, Y_std)
>>> samples = sampler(posterior)
>>> objective = unstd_objective(samples)
"""

def __init__(
self, Y_mean: Tensor, Y_std: Tensor, outcomes: Optional[List[int]] = None
) -> None:
r"""Initialize objective.

Args:
Y_mean: `m`-dim tensor of outcome means.
Y_std: `m`-dim tensor of outcome standard deviations.
outcomes: A list of `m' <= m` indices that specifies which of the `m` model
outputs should be considered as the outcomes for MOO. If omitted, use
all model outcomes. Typically used for constrained optimization.
"""
if Y_mean.ndim > 1 or Y_std.ndim > 1:
raise BotorchTensorDimensionError(
"Y_mean and Y_std must both be 1-dimensional, but got "
f"{Y_mean.ndim} and {Y_std.ndim}"
)
elif outcomes is not None and len(outcomes) > Y_mean.shape[-1]:
raise BotorchTensorDimensionError(
f"Cannot specify more ({len(outcomes)}) outcomes than are present in "
f"the normalization inputs ({Y_mean.shape[-1]})."
)
super().__init__(outcomes=outcomes, num_outcomes=Y_mean.shape[-1])
if outcomes is not None:
Y_mean = Y_mean.index_select(-1, self.outcomes.to(Y_mean.device))
Y_std = Y_std.index_select(-1, self.outcomes.to(Y_mean.device))

self.register_buffer("Y_mean", Y_mean)
self.register_buffer("Y_std", Y_std)

def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
samples = super().forward(samples=samples)
return samples * self.Y_std + self.Y_mean
42 changes: 1 addition & 41 deletions botorch/acquisition/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
from typing import Callable, List, Optional, TYPE_CHECKING, Union

import torch
from botorch.exceptions.errors import BotorchTensorDimensionError, UnsupportedError
from botorch.exceptions.errors import UnsupportedError
from botorch.exceptions.warnings import InputDataWarning
from botorch.models.model import Model
from botorch.models.transforms.outcome import Standardize
from botorch.posteriors.gpytorch import GPyTorchPosterior, scalarize_posterior
from botorch.sampling import IIDNormalSampler
from botorch.utils import apply_constraints
Expand Down Expand Up @@ -234,45 +233,6 @@ def forward(self, posterior: GPyTorchPosterior) -> GPyTorchPosterior:
return GPyTorchPosterior(distribution=new_mvn)


class UnstandardizePosteriorTransform(PosteriorTransform):
r"""Posterior transform that unstandardizes the posterior.

TODO: remove this when MultiTask models support outcome transforms.

Example:
>>> unstd_transform = UnstandardizePosteriorTransform(Y_mean, Y_std)
>>> unstd_posterior = unstd_transform(posterior)
"""

def __init__(self, Y_mean: Tensor, Y_std: Tensor) -> None:
r"""Initialize objective.

Args:
Y_mean: `m`-dim tensor of outcome means
Y_std: `m`-dim tensor of outcome standard deviations

"""
if Y_mean.ndim > 1 or Y_std.ndim > 1:
raise BotorchTensorDimensionError(
"Y_mean and Y_std must both be 1-dimensional, but got "
f"{Y_mean.ndim} and {Y_std.ndim}"
)
super().__init__()
self.outcome_transform = Standardize(m=Y_mean.shape[0]).to(Y_mean)
Y_std_unsqueezed = Y_std.unsqueeze(0)
self.outcome_transform.means = Y_mean.unsqueeze(0)
self.outcome_transform.stdvs = Y_std_unsqueezed
self.outcome_transform._stdvs_sq = Y_std_unsqueezed.pow(2)
self.outcome_transform._is_trained = torch.tensor(True)
self.outcome_transform.eval()

def evaluate(self, Y: Tensor) -> Tensor:
return self.outcome_transform.untransform(Y)[0]

def forward(self, posterior: GPyTorchPosterior) -> Tensor:
return self.outcome_transform.untransform_posterior(posterior)


class MCAcquisitionObjective(Module, ABC):
r"""Abstract base class for MC-based objectives.

Expand Down
82 changes: 5 additions & 77 deletions test/acquisition/multi_objective/test_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,10 @@
FeasibilityWeightedMCMultiOutputObjective,
IdentityMCMultiOutputObjective,
MCMultiOutputObjective,
UnstandardizeMCMultiOutputObjective,
WeightedMCMultiOutputObjective,
)
from botorch.acquisition.objective import (
IdentityMCObjective,
UnstandardizePosteriorTransform,
)
from botorch.acquisition.objective import IdentityMCObjective
from botorch.exceptions.errors import BotorchError, BotorchTensorDimensionError
from botorch.models.transforms.outcome import Standardize
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior


Expand All @@ -37,14 +32,17 @@ def test_identity_mc_multi_output_objective(self):
objective = IdentityMCMultiOutputObjective()
with self.assertRaises(BotorchTensorDimensionError):
IdentityMCMultiOutputObjective(outcomes=[0])
# test negative outcome without specifying num_outcomes
# Test negative outcome without specifying num_outcomes.
with self.assertRaises(BotorchError):
IdentityMCMultiOutputObjective(outcomes=[0, -1])
for batch_shape, m, dtype in itertools.product(
([], [3]), (2, 3), (torch.float, torch.double)
):
samples = torch.rand(*batch_shape, 2, m, device=self.device, dtype=dtype)
self.assertTrue(torch.equal(objective(samples), samples))
# Test negative outcome with num_outcomes.
objective = IdentityMCMultiOutputObjective(outcomes=[0, -1], num_outcomes=3)
self.assertEqual(objective.outcomes.tolist(), [0, 2])


class TestWeightedMCMultiOutputObjective(BotorchTestCase):
Expand Down Expand Up @@ -138,73 +136,3 @@ def test_feasibility_weighted_mc_multi_output_objective(self):
X_baseline=X,
constraint_idcs=[1, -1],
)


class TestUnstandardizeMultiOutputObjective(BotorchTestCase):
def test_unstandardize_mo_objective(self):
Y_mean = torch.ones(2)
Y_std = torch.ones(2)
with self.assertRaises(BotorchTensorDimensionError):
UnstandardizeMCMultiOutputObjective(
Y_mean=Y_mean, Y_std=Y_std, outcomes=[0, 1, 2]
)
for objective_class in (
UnstandardizeMCMultiOutputObjective,
UnstandardizePosteriorTransform,
):
with self.assertRaises(BotorchTensorDimensionError):
objective_class(Y_mean=Y_mean.unsqueeze(0), Y_std=Y_std)
with self.assertRaises(BotorchTensorDimensionError):
objective_class(Y_mean=Y_mean, Y_std=Y_std.unsqueeze(0))
objective = objective_class(Y_mean=Y_mean, Y_std=Y_std)
for batch_shape, m, outcomes, dtype in itertools.product(
([], [3]), (2, 3), (None, [-2, -1]), (torch.float, torch.double)
):
Y_mean = torch.rand(m, dtype=dtype, device=self.device)
Y_std = torch.rand(m, dtype=dtype, device=self.device).clamp_min(1e-3)
kwargs = {}
if objective_class == UnstandardizeMCMultiOutputObjective:
kwargs["outcomes"] = outcomes
objective = objective_class(Y_mean=Y_mean, Y_std=Y_std, **kwargs)
if objective_class == UnstandardizePosteriorTransform:
objective = objective_class(Y_mean=Y_mean, Y_std=Y_std)
if outcomes is None:
# passing outcomes is not currently supported
mean = torch.rand(2, m, dtype=dtype, device=self.device)
variance = variance = torch.rand(
2, m, dtype=dtype, device=self.device
)
mock_posterior = MockPosterior(mean=mean, variance=variance)
tf_posterior = objective(mock_posterior)
tf = Standardize(m=m)
tf.means = Y_mean
tf.stdvs = Y_std
tf._stdvs_sq = Y_std.pow(2)
tf._is_trained = torch.tensor(True)
tf.eval()
expected_posterior = tf.untransform_posterior(mock_posterior)
self.assertTrue(
torch.equal(tf_posterior.mean, expected_posterior.mean)
)
self.assertTrue(
torch.equal(
tf_posterior.variance, expected_posterior.variance
)
)
# testing evaluate specifically
if objective_class == UnstandardizePosteriorTransform:
Y = torch.randn_like(Y_mean) + Y_mean
val = objective.evaluate(Y)
val_expected = Y_mean + Y * Y_std
self.assertTrue(torch.allclose(val, val_expected))
else:

samples = torch.rand(
*batch_shape, 2, m, dtype=dtype, device=self.device
)
obj_expected = samples * Y_std.to(dtype=dtype) + Y_mean.to(
dtype=dtype
)
if outcomes is not None:
obj_expected = obj_expected[..., outcomes]
self.assertTrue(torch.equal(objective(samples), obj_expected))
16 changes: 4 additions & 12 deletions test/acquisition/multi_objective/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

import torch
from botorch.acquisition.multi_objective.objective import (
IdentityMCMultiOutputObjective,
MCMultiOutputObjective,
UnstandardizeMCMultiOutputObjective,
)
from botorch.acquisition.multi_objective.utils import (
compute_sample_box_decomposition,
Expand Down Expand Up @@ -92,16 +92,8 @@ def test_prune_inferior_points_multi_objective(self):
X_pruned = prune_inferior_points_multi_objective(
model=mm, X=X, ref_point=ref_point
)
self.assertTrue(torch.equal(X_pruned, X[[-1]]))
# test unstd objective
unstd_obj = UnstandardizeMCMultiOutputObjective(
Y_mean=samples.mean(dim=0), Y_std=samples.std(dim=0), outcomes=[0, 1]
)
X_pruned = prune_inferior_points_multi_objective(
model=mm, X=X, ref_point=ref_point, objective=unstd_obj
)
self.assertTrue(torch.equal(X_pruned, X[[-1]]))
# test constraints
objective = IdentityMCMultiOutputObjective(outcomes=[0, 1])
samples_constrained = torch.tensor(
[[1.0, 2.0, -1.0], [2.0, 1.0, -1.0], [3.0, 4.0, 1.0]], **tkwargs
)
Expand All @@ -110,7 +102,7 @@ def test_prune_inferior_points_multi_objective(self):
model=mm_constrained,
X=X,
ref_point=ref_point,
objective=unstd_obj,
objective=objective,
constraints=[lambda Y: Y[..., -1]],
)
self.assertTrue(torch.equal(X_pruned, X[:2]))
Expand Down Expand Up @@ -161,7 +153,7 @@ def test_prune_inferior_points_multi_objective(self):
model=mm,
X=X,
ref_point=ref_point,
objective=unstd_obj,
objective=objective,
constraints=[lambda Y: Y[..., -1] - 3.0],
marginalize_dim=-3,
)
Expand Down