80 changes: 76 additions & 4 deletions botorch/test_functions/base.py
@@ -75,11 +75,52 @@ def validate_parameter_indices(
)


def validate_inputs(
X: Tensor,
dim: int,
bounds: Tensor,
discrete_inds: list[int],
categorical_inds: list[int],
) -> None:
r"""Check that the inputs are valid.

This function checks that the input tensor `X` has the correct shape and lies
within the bounds, and that the discrete and categorical parameters are integer-valued.

Args:
X: A `(batch_shape) x n x d`-dim tensor of point(s) at which to evaluate.
dim: Number of search space dimensions.
bounds: A `2 x d`-dim tensor of lower and upper bounds.
discrete_inds: List of unique integers corresponding to discrete parameters.
categorical_inds: List of unique integers corresponding to categorical
parameters.

Raises:
ValueError: If the parameter indices are invalid.
"""

if not X.shape[-1] == dim:
raise ValueError(
"Expected `X` to have shape `(batch_shape) x n x d`. "
f"Got {X.shape=} and {dim=}"
)
if not ((X >= bounds[0]).all() and (X <= bounds[1]).all()):
raise ValueError("Expected `X` to be within the bounds of the test problem.")
for inds in [discrete_inds, categorical_inds]:
if not (X[..., inds] == X[..., inds].round()).all():
raise ValueError(
"Expected `X` to have integer values for the discrete and "
"categorical parameters."
)
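
A minimal usage sketch (hypothetical, not part of this PR) of calling `validate_inputs` directly; the bounds, dimensionality, and parameter indices below are made-up values, and `torch` is assumed to be imported:

bounds = torch.tensor([[0.0, 0.0, 0.0], [1.0, 5.0, 3.0]])
X = torch.tensor([[0.5, 2.0, 1.0]])  # continuous, discrete, categorical entries
validate_inputs(X=X, dim=3, bounds=bounds, discrete_inds=[1], categorical_inds=[2])  # passes
X_bad = torch.tensor([[0.5, 2.3, 1.0]])  # 2.3 is not integer-valued at discrete index 1
# The call below would raise ValueError (non-integer value for a discrete parameter):
# validate_inputs(X=X_bad, dim=3, bounds=bounds, discrete_inds=[1], categorical_inds=[2])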


class BaseTestProblem(Module, ABC):
r"""Base class for test functions."""

dim: int
_bounds: list[tuple[float, float]]
_bounds: list[
tuple[float, float]
] # Bounds, must be integers for discrete/categorical parameters
_check_grad_at_opt: bool = True
continuous_inds: list[int] = [] # Float-valued range parameters (bounds inclusive)
discrete_inds: list[int] = [] # Ordered integer parameters (bounds inclusive)
@@ -139,11 +180,26 @@ def forward(self, X: Tensor, noise: bool = True) -> Tensor:
f = -f
return f

@abstractmethod
def evaluate_true(self, X: Tensor) -> Tensor:
r"""
Evaluate the function (w/o observation noise) on a set of points.

Args:
X: A `(batch_shape) x d`-dim tensor of point(s) at which to
evaluate.

Returns:
A `batch_shape`-dim tensor.
"""
validate_inputs(
X, self.dim, self.bounds, self.discrete_inds, self.categorical_inds
)
return self._evaluate_true(X=X)
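
A hypothetical minimal subclass (illustration only, not part of this diff) of the new pattern: subclasses implement `_evaluate_true`, while the public `evaluate_true` validates `X` against the bounds and parameter indices before delegating. The `ToyQuadratic` name and objective are made up:

class ToyQuadratic(BaseTestProblem):
    dim = 2
    _bounds = [(-1.0, 1.0), (-1.0, 1.0)]
    continuous_inds = [0, 1]

    def _evaluate_true(self, X: Tensor) -> Tensor:
        # Simple smooth objective: sum of squares along the last dimension.
        return X.pow(2).sum(dim=-1)

problem = ToyQuadratic()
problem.evaluate_true(torch.tensor([[0.5, -0.5]]))  # tensor([0.5000])
# problem.evaluate_true(torch.tensor([[2.0, 0.0]]))  # raises ValueError (out of bounds)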

@abstractmethod
def _evaluate_true(self, X: Tensor) -> Tensor:
r"""Evaluate the function (w/o observation noise) on a set of points.

Args:
X: A `(batch_shape) x d`-dim tensor of point(s) at which to
evaluate.
@@ -206,7 +262,6 @@ def is_feasible(self, X: Tensor, noise: bool = True) -> Tensor:
"""
return (self.evaluate_slack(X=X, noise=noise) >= 0.0).all(dim=-1)

@abstractmethod
def evaluate_slack_true(self, X: Tensor) -> Tensor:
r"""Evaluate the constraint slack (w/o observation noise) on a set of points.

@@ -218,6 +273,23 @@ def evaluate_slack_true(self, X: Tensor) -> Tensor:
A `batch_shape x n_c`-dim tensor of constraint slack (where positive slack
corresponds to the constraint being feasible).
"""
validate_inputs(
X, self.dim, self.bounds, self.discrete_inds, self.categorical_inds
)
return self._evaluate_slack_true(X=X)

@abstractmethod
def _evaluate_slack_true(self, X: Tensor) -> Tensor:
r"""Evaluate the constraint slack (w/o observation noise) on a set of points.

Args:
X: A `batch_shape x d`-dim tensor of point(s) at which to evaluate the
constraint slacks: `c_1(X), ..., c_{n_c}(X)`.

Returns:
A `batch_shape x n_c`-dim tensor of constraint slack (where positive slack
corresponds to the constraint being feasible).
"""
pass # pragma: no cover
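
The same delegation applies to constrained problems: `evaluate_slack_true` validates and then calls `_evaluate_slack_true`. A hypothetical sketch (not part of this PR) extending the `ToyQuadratic` example above, assuming `ConstrainedBaseTestProblem` subclasses `BaseTestProblem` and declares `num_constraints` as existing BoTorch constrained problems do:

class ToyConstrainedQuadratic(ToyQuadratic, ConstrainedBaseTestProblem):
    num_constraints = 1

    def _evaluate_slack_true(self, X: Tensor) -> Tensor:
        # Feasible when x_0 + x_1 <= 1; positive slack means feasible.
        return (1.0 - X.sum(dim=-1)).unsqueeze(-1)

problem = ToyConstrainedQuadratic()
problem.is_feasible(torch.tensor([[0.25, 0.25]]), noise=False)  # tensor([True])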


@@ -372,7 +444,7 @@ def __init__(
self._current_seed: int | None = None
self._seeds: Iterator[int] | None = None if seeds is None else iter(seeds)

def evaluate_true(self, X: Tensor) -> Tensor:
def _evaluate_true(self, X: Tensor) -> Tensor:
return self.base_test_problem.evaluate_true(X)

def forward(self, X: Tensor, noise: bool = True) -> Tensor:
10 changes: 5 additions & 5 deletions botorch/test_functions/multi_fidelity.py
@@ -50,7 +50,7 @@ class AugmentedBranin(SyntheticTestFunction):
(math.pi, 2.1763039559891064, 0.9),
]

def evaluate_true(self, X: Tensor) -> Tensor:
def _evaluate_true(self, X: Tensor) -> Tensor:
t1 = (
X[..., 1]
- (5.1 / (4 * math.pi**2) - 0.1 * (1 - X[..., 2])) * X[..., 0].pow(2)
@@ -112,7 +112,7 @@ def __init__(
self.register_buffer("A", torch.tensor(A))
self.register_buffer("P", torch.tensor(P))

def evaluate_true(self, X: Tensor) -> Tensor:
def _evaluate_true(self, X: Tensor) -> Tensor:
self.to(device=X.device, dtype=X.dtype)
inner_sum = torch.sum(
self.A * (X[..., :6].unsqueeze(-2) - 0.0001 * self.P).pow(2), dim=-1
@@ -164,7 +164,7 @@ def __init__(
self._optimizers = [tuple(1.0 for _ in range(self.dim))]
super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)

def evaluate_true(self, X: Tensor) -> Tensor:
def _evaluate_true(self, X: Tensor) -> Tensor:
X_curr = X[..., :-3]
X_next = X[..., 1:-2]
t1 = 100 * (X_next - X_curr.pow(2) + 0.1 * (1 - X[..., -2:-1])).pow(2)
@@ -215,7 +215,7 @@ class WingWeightMultiFidelity(SyntheticTestFunction):
fidelities = [0, 1, 2, 3]
_optimal_value = 123.25

def evaluate_true(self, X: torch.Tensor) -> Tensor:
def _evaluate_true(self, X: torch.Tensor) -> Tensor:
s_w, w_fw, A, Lambda_deg, q, lam, t_c, N_z, w_dg, w_pp, fidelity = X.unbind(
dim=-1
)
@@ -304,7 +304,7 @@ class BoreholeMultiFidelity(SyntheticTestFunction):
fidelities = [0, 1, 2, 3, 4]
_optimal_value = 3.98

def evaluate_true(self, X: torch.Tensor) -> torch.Tensor:
def _evaluate_true(self, X: torch.Tensor) -> torch.Tensor:
r_w, r, T_u, T_l, H_u, H_l, L, K_w, fidelity = X.unbind(dim=-1)
LTu = L * T_u
two_pi_T_u = 2.0 * math.pi * T_u