Commit e220777

sdaulton authored and facebook-github-bot committed
StratifiedStandardize OutcomeTransform (#2671)
Summary: see title. This allows applying stratified standardization at the model level, which enables selecting between a single-task and a multi-task model in Ax while using the appropriate outcome transform. For example, one could specify ModelConfigs that use 1) `SingleTaskGP` + `Standardize`, or 2) `MultiTaskGP` + `StratifiedStandardize`.

Differential Revision: D67728920
1 parent 6026c6f commit e220777

File tree

5 files changed: +372, -43 lines changed

botorch/models/multitask.py (1 addition, 34 deletions)

```diff
@@ -39,6 +39,7 @@
 from botorch.models.model import FantasizeMixin
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform, Standardize
+from botorch.models.utils.assorted import get_task_value_remapping
 from botorch.models.utils.gpytorch_modules import (
     get_covar_module_with_dim_scaled_prior,
     get_gaussian_likelihood_with_lognormal_prior,
@@ -82,40 +83,6 @@
 from torch import Tensor


-def get_task_value_remapping(task_values: Tensor, dtype: torch.dtype) -> Tensor | None:
-    """Construct a mapping of discrete task values to contiguous int-valued floats.
-
-    Args:
-        task_values: A sorted long-valued tensor of task values.
-        dtype: The dtype of the model inputs (e.g. `X`), which the new
-            task values should have mapped to (e.g. float, double).
-
-    Returns:
-        A tensor of shape `task_values.max() + 1` that maps task values
-        to new task values. The indexing operation `mapper[task_value]`
-        will produce a tensor of new task values, of the same shape as
-        the original. The elements of the `mapper` tensor that do not
-        appear in the original `task_values` are mapped to `nan`. The
-        return value will be `None` when the task values are contiguous
-        integers starting from zero.
-    """
-    task_range = torch.arange(
-        len(task_values), dtype=task_values.dtype, device=task_values.device
-    )
-    mapper = None
-    if not torch.equal(task_values, task_range):
-        # Create a tensor that maps task values to new task values.
-        # The number of tasks should be small, so this should be quite efficient.
-        mapper = torch.full(
-            (int(task_values.max().item()) + 1,),
-            float("nan"),
-            dtype=dtype,
-            device=task_values.device,
-        )
-        mapper[task_values] = task_range.to(dtype=dtype)
-    return mapper
-
-
 class MultiTaskGP(ExactGP, MultiTaskGPyTorchModel, FantasizeMixin):
     r"""Multi-Task exact GP model using an ICM (intrinsic co-regionalization model)
     kernel. See [Bonilla2007MTGP]_ and [Swersky2013MTBO]_ for a reference on the
```
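For a concrete sense of the relocated helper's contract, here is a minimal sketch based on the docstring above (illustrative, not part of the commit; the import path follows the new hunk):

```python
import torch
from botorch.models.utils.assorted import get_task_value_remapping

# Non-contiguous task values: a mapper of shape task_values.max() + 1 = 4
# is returned, with slots for unseen task values filled with NaN.
task_values = torch.tensor([1, 3], dtype=torch.long)
mapper = get_task_value_remapping(task_values, dtype=torch.double)
print(mapper)  # tensor([nan, 0., nan, 1.], dtype=torch.float64)

# Contiguous task values starting at zero need no remapping.
assert get_task_value_remapping(torch.tensor([0, 1, 2]), dtype=torch.double) is None
```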

botorch/models/transforms/outcome.py (244 additions, 9 deletions)

```diff
@@ -27,9 +27,11 @@
 import torch
 from botorch.models.transforms.utils import (
+    nanstd,
     norm_to_lognorm_mean,
     norm_to_lognorm_variance,
 )
+from botorch.models.utils.assorted import get_task_value_remapping
 from botorch.posteriors import GPyTorchPosterior, Posterior, TransformedPosterior
 from botorch.utils.transforms import normalize_indices
 from linear_operator.operators import CholLinearOperator, DiagLinearOperator
```
```diff
@@ -259,6 +261,25 @@ def __init__(
         self._batch_shape = batch_shape
         self._min_stdv = min_stdv

+    def _get_per_input_means_stdvs(
+        self, X: Tensor, include_stdvs_sq: bool
+    ) -> tuple[Tensor, Tensor, Tensor | None]:
+        r"""Get per-input means and stdvs.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of input parameters.
+            include_stdvs_sq: Whether to include the stdvs squared.
+                This parameter is not used by this method.
+
+        Returns:
+            A three-tuple with the means and stdvs:
+
+            - The per-input means.
+            - The per-input stdvs.
+            - The per-input stdvs squared.
+        """
+        return self.means, self.stdvs, self._stdvs_sq
+
     def forward(
         self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
     ) -> tuple[Tensor, Tensor | None]:
```
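The new `_get_per_input_means_stdvs` is the extension point that the later hunks route `forward`, `untransform`, and `untransform_posterior` through. For the base `Standardize` it ignores `X` and hands back the global statistics, so existing behavior is unchanged; subclasses can instead return statistics that vary per input. A minimal sketch (illustrative only, poking at the private helper):

```python
import torch
from botorch.models.transforms.outcome import Standardize

tf = Standardize(m=1)
Y = torch.randn(10, 1)
tf(Y)  # train mode: fits the global means/stdvs

# For the base class the hook ignores X: any X yields the stored global stats.
means, stdvs, stdvs_sq = tf._get_per_input_means_stdvs(
    X=torch.rand(10, 3), include_stdvs_sq=True
)
assert means is tf.means and stdvs is tf.stdvs and stdvs_sq is tf._stdvs_sq
```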
```diff
@@ -313,9 +334,12 @@ def forward(
             self.stdvs = stdvs
             self._stdvs_sq = stdvs.pow(2)
             self._is_trained = torch.tensor(True)
-
-        Y_tf = (Y - self.means) / self.stdvs
-        Yvar_tf = Yvar / self._stdvs_sq if Yvar is not None else None
+        include_stdvs_sq = Yvar is not None
+        means, stdvs, stdvs_sq = self._get_per_input_means_stdvs(
+            X=X, include_stdvs_sq=include_stdvs_sq
+        )
+        Y_tf = (Y - means) / stdvs
+        Yvar_tf = Yvar / stdvs_sq if include_stdvs_sq else None
         return Y_tf, Yvar_tf

     def subset_output(self, idcs: list[int]) -> OutcomeTransform:
@@ -376,9 +400,12 @@ def untransform(
                 "(e.g. `transform(Y)`) before calling `untransform`, since "
                 "means and standard deviations need to be computed."
             )
-
-        Y_utf = self.means + self.stdvs * Y
-        Yvar_utf = self._stdvs_sq * Yvar if Yvar is not None else None
+        include_stdvs_sq = Yvar is not None
+        means, stdvs, stdvs_sq = self._get_per_input_means_stdvs(
+            X=X, include_stdvs_sq=include_stdvs_sq
+        )
+        Y_utf = means + stdvs * Y
+        Yvar_utf = stdvs_sq * Yvar if include_stdvs_sq else None
         return Y_utf, Yvar_utf

     @property
```
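With both `forward` and `untransform` routed through the shared helper, they remain exact inverses of each other, for observation noise as well as for observations. A quick round-trip check (illustrative, not from the commit):

```python
import torch
from botorch.models.transforms.outcome import Standardize

tf = Standardize(m=1)
Y = torch.randn(20, 1)
Yvar = torch.rand(20, 1)

Y_tf, Yvar_tf = tf(Y, Yvar)  # train mode: fits means/stdvs, then standardizes
tf.eval()
Y_utf, Yvar_utf = tf.untransform(Y_tf, Yvar_tf)

assert torch.allclose(Y_utf, Y, atol=1e-6)
assert torch.allclose(Yvar_utf, Yvar, atol=1e-6)
```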
```diff
@@ -433,8 +460,9 @@ def untransform_posterior(
             )
         # GPyTorchPosterior (TODO: Should we Lazy-evaluate the mean here as well?)
         mvn = posterior.distribution
-        offset = self.means
-        scale_fac = self.stdvs
+        offset, scale_fac, _ = self._get_per_input_means_stdvs(
+            X=X, include_stdvs_sq=False
+        )
         if not posterior._is_mt:
             mean_tf = offset.squeeze(-1) + scale_fac.squeeze(-1) * mvn.mean
             scale_fac = scale_fac.squeeze(-1).expand_as(mean_tf)
@@ -449,7 +477,7 @@

         if (
             not mvn.islazy
-            # TODO: Figure out attribute namming weirdness here
+            # TODO: Figure out attribute naming weirdness here
             or mvn._MultivariateNormal__unbroadcasted_scale_tril is not None
         ):
             # if already computed, we can save a lot of time using scale_tril
```
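For reference, the un-standardization applied here is the usual affine transform of a Gaussian: N(mu, Sigma) maps to N(offset + scale * mu, scale Sigma scale^T); the change above just sources offset and scale from the helper so they can vary per input. A quick numeric check of the covariance identity (illustrative, made-up tensors):

```python
import torch

# Un-standardizing N(mu, Sigma) with per-point scales s and offsets m gives
# N(m + s * mu, diag(s) @ Sigma @ diag(s)); verify the broadcasted form.
mu = torch.randn(5)
A = torch.randn(5, 5)
Sigma = A @ A.T + torch.eye(5)  # a valid covariance matrix
s = torch.rand(5) + 0.5         # per-point scale factors ("stdvs")
m = torch.randn(5)              # per-point offsets ("means")

mean_tf = m + s * mu
covar_tf = s.unsqueeze(-1) * Sigma * s.unsqueeze(-2)
assert torch.allclose(covar_tf, torch.diag(s) @ Sigma @ torch.diag(s))
```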
```diff
@@ -465,6 +493,213 @@ def untransform_posterior(
         return GPyTorchPosterior(mvn_tf)


+class StratifiedStandardize(Standardize):
+    r"""Standardize outcomes (zero mean, unit variance) along stratification dimension.
+
+    This module is stateful: If in train mode, calling forward updates the
+    module state (i.e. the mean/std normalizing constants). If in eval mode,
+    calling forward simply applies the standardization using the current module
+    state.
+    """
+
+    def __init__(
+        self,
+        task_values: Tensor,
+        stratification_idx: int,
+        batch_shape: torch.Size = torch.Size(),  # noqa: B008
+        min_stdv: float = 1e-8,
+        # dtype: torch.dtype = torch.double,
+    ) -> None:
+        r"""Standardize outcomes (zero mean, unit variance) along stratification dim.
+
+        Note: This currently only supports single-output models
+        (including multi-task models that have a single output).
+
+        Args:
+            task_values: `t`-dim tensor of task values.
+            stratification_idx: The index of the stratification dimension.
+            batch_shape: The batch_shape of the training targets.
+            min_stdv: The minimum standard deviation for which to perform
+                standardization (if lower, only de-mean the data).
+        """
+        OutcomeTransform.__init__(self)
+        self._stratification_idx = stratification_idx
+        task_values = task_values.unique(sorted=True)
+        self.strata_mapping = get_task_value_remapping(task_values, dtype=torch.long)
+        if self.strata_mapping is None:
+            self.strata_mapping = task_values
+        n_strata = self.strata_mapping.shape[0]
+        self._min_stdv = min_stdv
+        self.register_buffer("means", torch.zeros(*batch_shape, n_strata, 1))
+        self.register_buffer("stdvs", torch.ones(*batch_shape, n_strata, 1))
+        self.register_buffer("_stdvs_sq", torch.ones(*batch_shape, n_strata, 1))
+        self.register_buffer("_is_trained", torch.tensor(False))
+        self._batch_shape = batch_shape
+        self._m = 1  # TODO: support multiple outputs
+        self._outputs = None
+
+    def forward(
+        self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
+    ) -> tuple[Tensor, Tensor | None]:
+        r"""Standardize outcomes.
+
+        If the module is in train mode, this updates the module state (i.e. the
+        mean/std normalizing constants). If the module is in eval mode, simply
+        applies the normalization using the module state.
+
+        Args:
+            Y: A `batch_shape x n x m`-dim tensor of training targets.
+            Yvar: A `batch_shape x n x m`-dim tensor of observation noises
+                associated with the training targets (if applicable).
+            X: A `batch_shape x n x d`-dim tensor of input parameters.
+
+        Returns:
+            A two-tuple with the transformed outcomes:
+
+            - The transformed outcome observations.
+            - The transformed observation noise (if applicable).
+        """
+        if X is None:
+            raise ValueError("X is required for StratifiedStandardize.")
+        if self.training:
+            if Y.shape[:-2] != self._batch_shape:
+                raise RuntimeError(
+                    f"Expected Y.shape[:-2] to be {self._batch_shape}, matching "
+                    "the `batch_shape` argument to `StratifiedStandardize`, but got "
+                    f"Y.shape[:-2]={Y.shape[:-2]}."
+                )
+            elif Y.shape[-2] < 1:
+                raise ValueError(f"Can't standardize with no observations. {Y.shape=}.")
+            elif Y.size(-1) != self._m:
+                raise RuntimeError(
+                    f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected "
+                    f"{self._m}."
+                )
+            self.means = self.means.to(dtype=X.dtype, device=X.device)
+            self.stdvs = self.stdvs.to(dtype=X.dtype, device=X.device)
+            self._stdvs_sq = self._stdvs_sq.to(dtype=X.dtype, device=X.device)
+            strata = X[..., self._stratification_idx].long()
+            unique_strata = strata.unique()
+            for s in unique_strata:
+                mapped_strata = self.strata_mapping[s]
+                mask = strata != s
+                Y_strata = Y.clone()
+                Y_strata[..., mask, :] = float("nan")
+                if Y.shape[-2] == 1:
+                    stdvs = torch.ones(
+                        (*Y_strata.shape[:-2], 1, Y_strata.shape[-1]),
+                        dtype=Y.dtype,
+                        device=Y.device,
+                    )
+                else:
+                    stdvs = nanstd(X=Y_strata, dim=-2)
+                stdvs = stdvs.where(
+                    stdvs >= self._min_stdv, torch.full_like(stdvs, 1.0)
+                )
+                means = Y_strata.nanmean(dim=-2)
+                self.means[..., mapped_strata, :] = means
+                self.stdvs[..., mapped_strata, :] = stdvs
+                self._stdvs_sq[..., mapped_strata, :] = stdvs.pow(2)
+            self._is_trained = torch.tensor(True)
+        training = self.training
+        self.training = False
+        tf_Y, tf_Yvar = super().forward(Y=Y, Yvar=Yvar, X=X)
+        self.training = training
+        return tf_Y, tf_Yvar
+
+    def _get_per_input_means_stdvs(
+        self, X: Tensor, include_stdvs_sq: bool
+    ) -> tuple[Tensor, Tensor, Tensor | None]:
+        r"""Get per-input means and stdvs.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of input parameters.
+            include_stdvs_sq: Whether to include the stdvs squared.
+
+        Returns:
+            A three-tuple with the per-input means and stdvs:
+
+            - The per-input means.
+            - The per-input stdvs.
+            - The per-input stdvs squared.
+        """
+        strata = X[..., self._stratification_idx].long()
+        mapped_strata = self.strata_mapping[strata].unsqueeze(-1)
+        # get means and stdvs for each strata
+        n_extra_batch_dims = mapped_strata.ndim - 2 - len(self._batch_shape)
+        view_shape = torch.Size([1] * n_extra_batch_dims) + self.means.shape
+        expand_shape = mapped_strata.shape[:n_extra_batch_dims] + self.means.shape
+        means = torch.gather(
+            input=self.means.view(view_shape).expand(expand_shape),
+            dim=-2,
+            index=mapped_strata,
+        )
+        stdvs = torch.gather(
+            input=self.stdvs.view(view_shape).expand(expand_shape),
+            dim=-2,
+            index=mapped_strata,
+        )
+        if include_stdvs_sq:
+            stdvs_sq = torch.gather(
+                input=self._stdvs_sq.view(view_shape).expand(expand_shape),
+                dim=-2,
+                index=mapped_strata,
+            )
+        else:
+            stdvs_sq = None
+        return means, stdvs, stdvs_sq
+
+    def subset_output(self, idcs: list[int]) -> OutcomeTransform:
+        r"""Subset the transform along the output dimension.
+
+        Args:
+            idcs: The output indices to subset the transform to.
+
+        Returns:
+            The current outcome transform, subset to the specified output indices.
+        """
+        raise NotImplementedError
+
+    def untransform(
+        self, Y: Tensor, Yvar: Tensor | None = None, X: Tensor | None = None
+    ) -> tuple[Tensor, Tensor | None]:
+        r"""Un-standardize outcomes.
+
+        Args:
+            Y: A `batch_shape x n x m`-dim tensor of standardized targets.
+            Yvar: A `batch_shape x n x m`-dim tensor of standardized observation
+                noises associated with the targets (if applicable).
+            X: A `batch_shape x n x d`-dim tensor of input parameters.
+
+        Returns:
+            A two-tuple with the un-standardized outcomes:
+
+            - The un-standardized outcome observations.
+            - The un-standardized observation noise (if applicable).
+        """
+        if X is None:
+            raise ValueError("X is required for StratifiedStandardize.")
+        return super().untransform(Y=Y, Yvar=Yvar, X=X)
+
+    def untransform_posterior(
+        self, posterior: Posterior, X: Tensor | None = None
+    ) -> GPyTorchPosterior | TransformedPosterior:
+        r"""Un-standardize the posterior.
+
+        Args:
+            posterior: A posterior in the standardized space.
+            X: A `batch_shape x n x d`-dim tensor of training inputs (if applicable).
+
+        Returns:
+            The un-standardized posterior. If the input posterior is a
+            `GPyTorchPosterior`, return a `GPyTorchPosterior`. Otherwise, return a
+            `TransformedPosterior`.
+        """
+        if X is None:
+            raise ValueError("X is required for StratifiedStandardize.")
+        return super().untransform_posterior(posterior=posterior, X=X)
+
+
 class Log(OutcomeTransform):
     r"""Log-transform outcomes.
```
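Putting the new transform to work end to end, here is a usage sketch (illustrative; the data and the choice of stratification column are made up). Per the summary, this same pattern is what lets a `MultiTaskGP` be paired with `StratifiedStandardize` while a `SingleTaskGP` keeps plain `Standardize`:

```python
import torch
from botorch.models.transforms.outcome import StratifiedStandardize

# Toy data: the last of d=3 input columns holds the task value (0 or 1),
# and stratum 1 has a shifted outcome distribution.
X = torch.rand(20, 3)
X[:, 2] = (torch.arange(20) % 2).to(X)
Y = torch.randn(20, 1) + 5.0 * X[:, 2:]

tf = StratifiedStandardize(
    task_values=torch.tensor([0, 1]),
    stratification_idx=2,
)
Y_tf, _ = tf(Y, X=X)  # train mode: fits per-stratum means/stdvs

tf.eval()
Y_utf, _ = tf.untransform(Y_tf, X=X)
assert torch.allclose(Y_utf, Y, atol=1e-6)
```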

botorch/models/transforms/utils.py (9 additions, 0 deletions)

```diff
@@ -141,3 +141,12 @@ def interaction_features(X: Tensor) -> Tensor:
     dim = X.shape[-1]
     row_idcs, col_idcs = torch.triu_indices(dim, dim, offset=1)
     return (X.unsqueeze(-1) @ X.unsqueeze(-2))[..., row_idcs, col_idcs].unsqueeze(-2)
+
+
+def nanstd(X: Tensor, dim: int, keepdim: bool = False) -> Tensor:
+    n = (~torch.isnan(X)).sum(dim=dim)
+    return (
+        (X - X.nanmean(dim=dim, keepdim=True)).pow(2).nanmean(dim=dim, keepdim=keepdim)
+        * n
+        / (n - 1)
+    ).sqrt()
```
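`nanstd` computes the Bessel-corrected sample standard deviation while ignoring NaNs, which is what lets `StratifiedStandardize.forward` mask out all other strata with NaNs before computing per-stratum statistics. A quick equivalence check (illustrative):

```python
import torch
from botorch.models.transforms.utils import nanstd

x = torch.tensor([1.0, 2.0, float("nan"), 4.0])
# Matches torch.std (unbiased) computed over the non-NaN entries only.
expected = x[~x.isnan()].std()
assert torch.allclose(nanstd(x, dim=0), expected)
```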
