16 changes: 6 additions & 10 deletions ax/utils/sensitivity/derivative_measures.py
@@ -5,7 +5,7 @@
 
 # pyre-strict
 
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from copy import deepcopy
 from functools import partial
 from typing import Any
@@ -90,16 +90,12 @@ def __init__(
                 this list are generated using an integer-valued uniform distribution,
                 rather than the default (pseudo-)random continuous uniform distribution.
         """
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.dim = assert_is_instance(model.train_inputs, tuple)[0].shape[-1]
+        self.dim: int = assert_is_instance(model.train_inputs, tuple)[0].shape[-1]
         self.derivative_gp = derivative_gp
         self.kernel_type = kernel_type
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.bootstrap = num_bootstrap_samples > 1
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.num_bootstrap_samples = (
-            num_bootstrap_samples - 1
-        )  # deduct 1 because the first is meant to be the full grid
+        self.bootstrap: bool = num_bootstrap_samples > 1
+        # deduct 1 because the first is meant to be the full grid
+        self.num_bootstrap_samples: int = num_bootstrap_samples - 1
         self.torch_device: torch.device = bounds.device
         if self.derivative_gp and (self.kernel_type is None):
             raise ValueError("Kernel type has to be specified to use derivative GP")
@@ -417,7 +413,7 @@ def aggregation(
 
 
 def compute_derivatives_from_model_list(
-    model_list: list[Model],
+    model_list: Sequence[Model],
     bounds: torch.Tensor,
     discrete_features: list[int] | None = None,
     **kwargs: Any,
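Why `Sequence[Model]` instead of `list[Model]`: `list` is invariant in its element type, so a `list[SingleTaskGP]` is not accepted where `list[Model]` is expected, while the read-only `Sequence` is covariant and accepts any list of `Model` subclasses. A minimal sketch, with toy functions that are illustrative rather than part of this diff:

from collections.abc import Sequence

from botorch.models.gp_regression import SingleTaskGP
from botorch.models.model import Model


def takes_list(models: list[Model]) -> None:
    """Invariant parameter type: only exactly list[Model] type-checks."""


def takes_sequence(models: Sequence[Model]) -> None:
    """Covariant parameter type: any Sequence of Model subclasses type-checks."""


models: list[SingleTaskGP] = []
takes_sequence(models)  # OK: Sequence[SingleTaskGP] is a Sequence[Model]
takes_list(models)  # runs, but pyre/mypy reject it: list is invariant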
43 changes: 20 additions & 23 deletions ax/utils/sensitivity/sobol_measures.py
@@ -22,8 +22,8 @@
     compute_derivatives_from_model_list,
     sample_discrete_parameters,
 )
+from botorch.models.gpytorch import GPyTorchModel
 from botorch.models.model import Model, ModelList
-from botorch.posteriors.gpytorch import GPyTorchPosterior
 from botorch.sampling.normal import SobolQMCNormalSampler
 from botorch.utils.sampling import draw_sobol_samples
 from botorch.utils.transforms import is_ensemble, unnormalize
@@ -444,7 +444,7 @@ def ProbitLinkMean(mean: torch.Tensor, var: torch.Tensor) -> torch.Tensor:
 class SobolSensitivityGPMean:
     def __init__(
         self,
-        model: Model,  # TODO: narrow type down. E.g. ModelListGP does not work.
+        model: GPyTorchModel,
         bounds: torch.Tensor,
         num_mc_samples: int = 10**4,
         second_order: bool = False,
@@ -461,7 +461,7 @@ def __init__(
         first order indices, total indices and second order indices (if specified ).
 
         Args:
-            model: Botorch model
+            model: BoTorch model whose posterior is a `GPyTorchPosterior`.
             bounds: `2 x d` parameter bounds over which to evaluate model sensitivity.
             method: if "predictive mean", the predictive mean is used for indices
                 computation. If "GP samples", posterior sampling is used instead.
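As orientation for the narrowed signature, a hypothetical usage sketch; the toy training data and the `first_order_indices` call are assumptions for illustration, not part of this diff:

import torch
from botorch.models import SingleTaskGP

from ax.utils.sensitivity.sobol_measures import SobolSensitivityGPMean

# Toy single-output GP on [0, 1]^2 (illustrative data only).
train_X = torch.rand(16, 2, dtype=torch.double)
train_Y = train_X.sin().sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)  # a GPyTorchModel, as now required

bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)
sens = SobolSensitivityGPMean(model, bounds=bounds, num_mc_samples=2**10)
print(sens.first_order_indices())  # one Sobol index per input dimension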
@@ -484,28 +484,25 @@ def __init__(
         self.model = model
         self.second_order = second_order
         self.input_qmc = input_qmc
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.bootstrap = num_bootstrap_samples > 1
+        self.bootstrap: bool = num_bootstrap_samples > 1
         self.num_bootstrap_samples = num_bootstrap_samples
         self.num_mc_samples = num_mc_samples
 
         def input_function(x: Tensor) -> Tensor:
             with torch.no_grad():
-                means, variances = [], []
-                # Since we're only looking at mean & variance, we can freely
-                # use mini-batches.
-                for x_split in x.split(split_size=mini_batch_size):
-                    p = assert_is_instance(
-                        self.model.posterior(x_split),
-                        GPyTorchPosterior,
-                    )
-                    means.append(p.mean)
-                    variances.append(p.variance)
-
-                cat_dim = 1 if is_ensemble(self.model) else 0
-                return link_function(
-                    torch.cat(means, dim=cat_dim), torch.cat(variances, dim=cat_dim)
-                )
+                # We only need variances, not covariances, so we use the batch
+                # dimension, turning x from (*batch_dim, n, d) to
+                # (*batch_dim, n, 1, d)
+                p = self.model.posterior(x.unsqueeze(-2))
+                mean = p.mean.squeeze(-2)
+                variance = p.variance.squeeze(-2)
+                if is_ensemble(self.model):
+                    # If x has shape [n, d], the mean will have shape
+                    # [n, s, m], where 's' is the ensemble size.
+                    # Reshape to [s, n, m].
+                    mean = torch.swapaxes(mean, -2, -3)
+                    variance = torch.swapaxes(variance, -2, -3)
+                return link_function(mean, variance)
 
         self.sensitivity = SobolSensitivity(
             bounds=bounds,
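The rewritten `input_function` avoids ever forming the joint `n x n` posterior covariance, which is why the old code needed mini-batches: unsqueezing a singleton dimension turns the `n` points into a batch of `n` single-point queries, so only marginal variances are computed. A shape-only sketch of that trick with an assumed toy model:

import torch
from botorch.models import SingleTaskGP

train_X = torch.rand(8, 3, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

x = torch.rand(5, 3, dtype=torch.double)
joint = model.posterior(x)  # joint posterior over 5 points
batched = model.posterior(x.unsqueeze(-2))  # batch of 5 independent points

print(joint.mvn.covariance_matrix.shape)  # torch.Size([5, 5])
print(batched.mean.squeeze(-2).shape)  # torch.Size([5, 1])
print(batched.variance.squeeze(-2).shape)  # torch.Size([5, 1])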
@@ -796,7 +793,7 @@ def second_order_indices(self) -> Tensor:
 
 
 def compute_sobol_indices_from_model_list(
-    model_list: list[Model],
+    model_list: list[GPyTorchModel],
    bounds: Tensor,
     order: str = "first",
     discrete_features: list[int] | None = None,
@@ -974,7 +971,7 @@ def _get_generator_and_digest(
 
 def _get_model_per_metric(
     generator: LegacyBoTorchGenerator | ModularBoTorchGenerator, metrics: list[str]
-) -> list[Model]:
+) -> list[GPyTorchModel]:
     """For a given TorchGenerator model, returns a list of botorch.models.model.Model
     objects corresponding to - and in the same order as - the given metrics.
     """
@@ -984,7 +981,7 @@ def _get_model_per_metric(
     model_idx = [generator.metric_names.index(m) for m in metrics]
     if not isinstance(gp_model, ModelList):
         if gp_model.num_outputs == 1:  # can accept single output models
-            return [gp_model for _ in model_idx]
+            return [assert_is_instance(gp_model, GPyTorchModel) for _ in model_idx]
         raise NotImplementedError(
             f"type(adapter.generator.model) = {type(gp_model)}, "
             "but only ModelList is supported."
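In the last hunk, `assert_is_instance` (from `pyre_extensions`) does double duty: it narrows the static type to `GPyTorchModel` for the type checker and verifies the instance at runtime. A minimal sketch with a toy function of ours:

from pyre_extensions import assert_is_instance


def head_upper(items: list[object]) -> str:
    # items[0] is statically typed `object`; assert_is_instance checks that
    # it is a str at runtime and narrows the static type so .upper() checks.
    return assert_is_instance(items[0], str).upper()


print(head_upper(["ax"]))  # AX
# head_upper([1]) would raise: 1 is not an instance of str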