2727
2828import torch
2929from botorch .models .transforms .utils import (
30+ nanstd ,
3031 norm_to_lognorm_mean ,
3132 norm_to_lognorm_variance ,
3233)
34+ from botorch .models .utils .assorted import get_task_value_remapping
3335from botorch .posteriors import GPyTorchPosterior , Posterior , TransformedPosterior
3436from botorch .utils .transforms import normalize_indices
3537from linear_operator .operators import CholLinearOperator , DiagLinearOperator
@@ -259,6 +261,25 @@ def __init__(
259261 self ._batch_shape = batch_shape
260262 self ._min_stdv = min_stdv
261263
264+ def _get_per_input_means_stdvs (
265+ self , X : Tensor , include_stdvs_sq : bool
266+ ) -> tuple [Tensor , Tensor , Tensor | None ]:
267+ r"""Get per-input means and stdvs.
268+
269+ Args:
270+ X: A `batch_shape x n x d`-dim tensor of input parameters.
271+ include_stdvs_sq: Whether to include the stdvs squared.
272+ This parameter is not used by this method
273+
274+ Returns:
275+ A three-tuple with the means and stdvs:
276+
277+ - The per-input means.
278+ - The per-input stdvs.
279+ - The per-input stdvs squared.
280+ """
281+ return self .means , self .stdvs , self ._stdvs_sq
282+
262283 def forward (
263284 self , Y : Tensor , Yvar : Tensor | None = None , X : Tensor | None = None
264285 ) -> tuple [Tensor , Tensor | None ]:
@@ -313,9 +334,12 @@ def forward(
313334 self .stdvs = stdvs
314335 self ._stdvs_sq = stdvs .pow (2 )
315336 self ._is_trained = torch .tensor (True )
316-
317- Y_tf = (Y - self .means ) / self .stdvs
318- Yvar_tf = Yvar / self ._stdvs_sq if Yvar is not None else None
337+ include_stdvs_sq = Yvar is not None
338+ means , stdvs , stdvs_sq = self ._get_per_input_means_stdvs (
339+ X = X , include_stdvs_sq = include_stdvs_sq
340+ )
341+ Y_tf = (Y - means ) / stdvs
342+ Yvar_tf = Yvar / stdvs_sq if include_stdvs_sq else None
319343 return Y_tf , Yvar_tf
320344
321345 def subset_output (self , idcs : list [int ]) -> OutcomeTransform :
@@ -376,9 +400,12 @@ def untransform(
376400 "(e.g. `transform(Y)`) before calling `untransform`, since "
377401 "means and standard deviations need to be computed."
378402 )
379-
380- Y_utf = self .means + self .stdvs * Y
381- Yvar_utf = self ._stdvs_sq * Yvar if Yvar is not None else None
403+ include_stdvs_sq = Yvar is not None
404+ means , stdvs , stdvs_sq = self ._get_per_input_means_stdvs (
405+ X = X , include_stdvs_sq = include_stdvs_sq
406+ )
407+ Y_utf = means + stdvs * Y
408+ Yvar_utf = stdvs_sq * Yvar if include_stdvs_sq else None
382409 return Y_utf , Yvar_utf
383410
384411 @property
@@ -433,8 +460,9 @@ def untransform_posterior(
433460 )
434461 # GPyTorchPosterior (TODO: Should we Lazy-evaluate the mean here as well?)
435462 mvn = posterior .distribution
436- offset = self .means
437- scale_fac = self .stdvs
463+ offset , scale_fac , _ = self ._get_per_input_means_stdvs (
464+ X = X , include_stdvs_sq = False
465+ )
438466 if not posterior ._is_mt :
439467 mean_tf = offset .squeeze (- 1 ) + scale_fac .squeeze (- 1 ) * mvn .mean
440468 scale_fac = scale_fac .squeeze (- 1 ).expand_as (mean_tf )
@@ -449,7 +477,7 @@ def untransform_posterior(
449477
450478 if (
451479 not mvn .islazy
452- # TODO: Figure out attribute namming weirdness here
480+ # TODO: Figure out attribute naming weirdness here
453481 or mvn ._MultivariateNormal__unbroadcasted_scale_tril is not None
454482 ):
455483 # if already computed, we can save a lot of time using scale_tril
@@ -465,6 +493,213 @@ def untransform_posterior(
465493 return GPyTorchPosterior (mvn_tf )
466494
467495
496+ class StratifiedStandardize (Standardize ):
497+ r"""Standardize outcomes (zero mean, unit variance) along stratification dimension.
498+
499+ This module is stateful: If in train mode, calling forward updates the
500+ module state (i.e. the mean/std normalizing constants). If in eval mode,
501+ calling forward simply applies the standardization using the current module
502+ state.
503+ """
504+
505+ def __init__ (
506+ self ,
507+ task_values : Tensor ,
508+ stratification_idx : int ,
509+ batch_shape : torch .Size = torch .Size (), # noqa: B008
510+ min_stdv : float = 1e-8 ,
511+ # dtype: torch.dtype = torch.double,
512+ ) -> None :
513+ r"""Standardize outcomes (zero mean, unit variance) along stratification dim.
514+
515+ Note: This currenlty only supports single output models
516+ (including multi-task models that have a single output).
517+
518+ Args:
519+ task_values: `t`-dim tensor of task values.
520+ stratification_idx: The index of the stratification dimension.
521+ batch_shape: The batch_shape of the training targets.
522+ min_stddv: The minimum standard deviation for which to perform
523+ standardization (if lower, only de-mean the data).
524+ """
525+ OutcomeTransform .__init__ (self )
526+ self ._stratification_idx = stratification_idx
527+ task_values = task_values .unique (sorted = True )
528+ self .strata_mapping = get_task_value_remapping (task_values , dtype = torch .long )
529+ if self .strata_mapping is None :
530+ self .strata_mapping = task_values
531+ n_strata = self .strata_mapping .shape [0 ]
532+ self ._min_stdv = min_stdv
533+ self .register_buffer ("means" , torch .zeros (* batch_shape , n_strata , 1 ))
534+ self .register_buffer ("stdvs" , torch .ones (* batch_shape , n_strata , 1 ))
535+ self .register_buffer ("_stdvs_sq" , torch .ones (* batch_shape , n_strata , 1 ))
536+ self .register_buffer ("_is_trained" , torch .tensor (False ))
537+ self ._batch_shape = batch_shape
538+ self ._m = 1 # TODO: support multiple outputs
539+ self ._outputs = None
540+
541+ def forward (
542+ self , Y : Tensor , Yvar : Tensor | None = None , X : Tensor | None = None
543+ ) -> tuple [Tensor , Tensor | None ]:
544+ r"""Standardize outcomes.
545+
546+ If the module is in train mode, this updates the module state (i.e. the
547+ mean/std normalizing constants). If the module is in eval mode, simply
548+ applies the normalization using the module state.
549+
550+ Args:
551+ Y: A `batch_shape x n x m`-dim tensor of training targets.
552+ Yvar: A `batch_shape x n x m`-dim tensor of observation noises
553+ associated with the training targets (if applicable).
554+ X: A `batch_shape x n x d`-dim tensor of input parameters.
555+
556+ Returns:
557+ A two-tuple with the transformed outcomes:
558+
559+ - The transformed outcome observations.
560+ - The transformed observation noise (if applicable).
561+ """
562+ if X is None :
563+ raise ValueError ("X is required for StratifiedStandardize." )
564+ if self .training :
565+ if Y .shape [:- 2 ] != self ._batch_shape :
566+ raise RuntimeError (
567+ f"Expected Y.shape[:-2] to be { self ._batch_shape } , matching "
568+ "the `batch_shape` argument to `StratifiedStandardize`, but got "
569+ f"Y.shape[:-2]={ Y .shape [:- 2 ]} ."
570+ )
571+ elif Y .shape [- 2 ] < 1 :
572+ raise ValueError (f"Can't standardize with no observations. { Y .shape = } ." )
573+ elif Y .size (- 1 ) != self ._m :
574+ raise RuntimeError (
575+ f"Wrong output dimension. Y.size(-1) is { Y .size (- 1 )} ; expected "
576+ f"{ self ._m } ."
577+ )
578+ self .means = self .means .to (dtype = X .dtype , device = X .device )
579+ self .stdvs = self .stdvs .to (dtype = X .dtype , device = X .device )
580+ self ._stdvs_sq = self ._stdvs_sq .to (dtype = X .dtype , device = X .device )
581+ strata = X [..., self ._stratification_idx ].long ()
582+ unique_strata = strata .unique ()
583+ for s in unique_strata :
584+ mapped_strata = self .strata_mapping [s ]
585+ mask = strata != s
586+ Y_strata = Y .clone ()
587+ Y_strata [..., mask , :] = float ("nan" )
588+ if Y .shape [- 2 ] == 1 :
589+ stdvs = torch .ones (
590+ (* Y_strata .shape [:- 2 ], 1 , Y_strata .shape [- 1 ]),
591+ dtype = Y .dtype ,
592+ device = Y .device ,
593+ )
594+ else :
595+ stdvs = nanstd (X = Y_strata , dim = - 2 )
596+ stdvs = stdvs .where (
597+ stdvs >= self ._min_stdv , torch .full_like (stdvs , 1.0 )
598+ )
599+ means = Y_strata .nanmean (dim = - 2 )
600+ self .means [..., mapped_strata , :] = means
601+ self .stdvs [..., mapped_strata , :] = stdvs
602+ self ._stdvs_sq [..., mapped_strata , :] = stdvs .pow (2 )
603+ self ._is_trained = torch .tensor (True )
604+ training = self .training
605+ self .training = False
606+ tf_Y , tf_Yvar = super ().forward (Y = Y , Yvar = Yvar , X = X )
607+ self .training = training
608+ return tf_Y , tf_Yvar
609+
610+ def _get_per_input_means_stdvs (
611+ self , X : Tensor , include_stdvs_sq : bool
612+ ) -> tuple [Tensor , Tensor , Tensor | None ]:
613+ r"""Get per-input means and stdvs.
614+
615+ Args:
616+ X: A `batch_shape x n x d`-dim tensor of input parameters.
617+ include_stdvs_sq: Whether to include the stdvs squared.
618+
619+ Returns:
620+ A three-tuple with the per-input means and stdvs:
621+
622+ - The per-input means.
623+ - The per-input stdvs.
624+ - The per-input stdvs squared.
625+ """
626+ strata = X [..., self ._stratification_idx ].long ()
627+ mapped_strata = self .strata_mapping [strata ].unsqueeze (- 1 )
628+ # get means and stdvs for each strata
629+ n_extra_batch_dims = mapped_strata .ndim - 2 - len (self ._batch_shape )
630+ view_shape = torch .Size ([1 ] * n_extra_batch_dims ) + self .means .shape
631+ expand_shape = mapped_strata .shape [:n_extra_batch_dims ] + self .means .shape
632+ means = torch .gather (
633+ input = self .means .view (view_shape ).expand (expand_shape ),
634+ dim = - 2 ,
635+ index = mapped_strata ,
636+ )
637+ stdvs = torch .gather (
638+ input = self .stdvs .view (view_shape ).expand (expand_shape ),
639+ dim = - 2 ,
640+ index = mapped_strata ,
641+ )
642+ if include_stdvs_sq :
643+ stdvs_sq = torch .gather (
644+ input = self ._stdvs_sq .view (view_shape ).expand (expand_shape ),
645+ dim = - 2 ,
646+ index = mapped_strata ,
647+ )
648+ else :
649+ stdvs_sq = None
650+ return means , stdvs , stdvs_sq
651+
652+ def subset_output (self , idcs : list [int ]) -> OutcomeTransform :
653+ r"""Subset the transform along the output dimension.
654+
655+ Args:
656+ idcs: The output indices to subset the transform to.
657+
658+ Returns:
659+ The current outcome transform, subset to the specified output indices.
660+ """
661+ raise NotImplementedError
662+
663+ def untransform (
664+ self , Y : Tensor , Yvar : Tensor | None = None , X : Tensor | None = None
665+ ) -> tuple [Tensor , Tensor | None ]:
666+ r"""Un-standardize outcomes.
667+
668+ Args:
669+ Y: A `batch_shape x n x m`-dim tensor of standardized targets.
670+ Yvar: A `batch_shape x n x m`-dim tensor of standardized observation
671+ noises associated with the targets (if applicable).
672+ X: A `batch_shape x n x d`-dim tensor of input parameters.
673+
674+ Returns:
675+ A two-tuple with the un-standardized outcomes:
676+
677+ - The un-standardized outcome observations.
678+ - The un-standardized observation noise (if applicable).
679+ """
680+ if X is None :
681+ raise ValueError ("X is required for StratifiedStandardize." )
682+ return super ().untransform (Y = Y , Yvar = Yvar , X = X )
683+
684+ def untransform_posterior (
685+ self , posterior : Posterior , X : Tensor | None = None
686+ ) -> GPyTorchPosterior | TransformedPosterior :
687+ r"""Un-standardize the posterior.
688+
689+ Args:
690+ posterior: A posterior in the standardized space.
691+ X: A `batch_shape x n x d`-dim tensor of training inputs (if applicable).
692+
693+ Returns:
694+ The un-standardized posterior. If the input posterior is a
695+ `GPyTorchPosterior`, return a `GPyTorchPosterior`. Otherwise, return a
696+ `TransformedPosterior`.
697+ """
698+ if X is None :
699+ raise ValueError ("X is required for StratifiedStandardize." )
700+ return super ().untransform_posterior (posterior = posterior , X = X )
701+
702+
468703class Log (OutcomeTransform ):
469704 r"""Log-transform outcomes.
470705
0 commit comments