2727
2828import torch
2929from botorch .models .transforms .utils import (
30+ nanstd ,
3031 norm_to_lognorm_mean ,
3132 norm_to_lognorm_variance ,
3233)
34+ from botorch .models .utils .assorted import get_task_value_remapping
3335from botorch .posteriors import GPyTorchPosterior , Posterior , TransformedPosterior
3436from botorch .utils .transforms import normalize_indices
3537from linear_operator .operators import CholLinearOperator , DiagLinearOperator
@@ -259,6 +261,46 @@ def __init__(
259261 self ._batch_shape = batch_shape
260262 self ._min_stdv = min_stdv
261263
264+ def _get_per_input_means_stdvs (
265+ self , X : Tensor , include_stdvs_sq : bool
266+ ) -> tuple [Tensor , Tensor , Tensor | None ]:
267+ r"""Get per-input means and stdvs.
268+
269+ Args:
270+ X: A `batch_shape x n x d`-dim tensor of input parameters.
271+ include_stdvs_sq: Whether to include the stdvs squared.
272+ This parameter is not used by this method
273+
274+ Returns:
275+ A three-tuple with the means and stdvs:
276+
277+ - The per-input means.
278+ - The per-input stdvs.
279+ - The per-input stdvs squared.
280+ """
281+ return self .means , self .stdvs , self ._stdvs_sq
282+
283+ def _validate_training_inputs (self , Y : Tensor , Yvar : Tensor | None = None ) -> None :
284+ """Validate training inputs.
285+
286+ Args:
287+ Y: A `batch_shape x n x m`-dim tensor of training targets.
288+ Yvar: A `batch_shape x n x m`-dim tensor of observation noises.
289+ """
290+ if Y .shape [:- 2 ] != self ._batch_shape :
291+ raise RuntimeError (
292+ f"Expected Y.shape[:-2] to be { self ._batch_shape } , matching "
293+ f"the `batch_shape` argument to `{ self .__class__ .__name__ } `, but got "
294+ f"Y.shape[:-2]={ Y .shape [:- 2 ]} ."
295+ )
296+ elif Y .shape [- 2 ] < 1 :
297+ raise ValueError (f"Can't standardize with no observations. { Y .shape = } ." )
298+ elif Y .size (- 1 ) != self ._m :
299+ raise RuntimeError (
300+ f"Wrong output dimension. Y.size(-1) is { Y .size (- 1 )} ; expected "
301+ f"{ self ._m } ."
302+ )
303+
262304 def forward (
263305 self , Y : Tensor , Yvar : Tensor | None = None , X : Tensor | None = None
264306 ) -> tuple [Tensor , Tensor | None ]:
@@ -283,21 +325,8 @@ def forward(
283325 - The transformed observation noise (if applicable).
284326 """
285327 if self .training :
286- if Y .shape [:- 2 ] != self ._batch_shape :
287- raise RuntimeError (
288- f"Expected Y.shape[:-2] to be { self ._batch_shape } , matching "
289- "the `batch_shape` argument to `Standardize`, but got "
290- f"Y.shape[:-2]={ Y .shape [:- 2 ]} ."
291- )
292- if Y .size (- 1 ) != self ._m :
293- raise RuntimeError (
294- f"Wrong output dimension. Y.size(-1) is { Y .size (- 1 )} ; expected "
295- f"{ self ._m } ."
296- )
297- if Y .shape [- 2 ] < 1 :
298- raise ValueError (f"Can't standardize with no observations. { Y .shape = } ." )
299-
300- elif Y .shape [- 2 ] == 1 :
328+ self ._validate_training_inputs (Y = Y , Yvar = Yvar )
329+ if Y .shape [- 2 ] == 1 :
301330 stdvs = torch .ones (
302331 (* Y .shape [:- 2 ], 1 , Y .shape [- 1 ]), dtype = Y .dtype , device = Y .device
303332 )
@@ -313,9 +342,12 @@ def forward(
313342 self .stdvs = stdvs
314343 self ._stdvs_sq = stdvs .pow (2 )
315344 self ._is_trained = torch .tensor (True )
316-
317- Y_tf = (Y - self .means ) / self .stdvs
318- Yvar_tf = Yvar / self ._stdvs_sq if Yvar is not None else None
345+ include_stdvs_sq = Yvar is not None
346+ means , stdvs , stdvs_sq = self ._get_per_input_means_stdvs (
347+ X = X , include_stdvs_sq = include_stdvs_sq
348+ )
349+ Y_tf = (Y - means ) / stdvs
350+ Yvar_tf = Yvar / stdvs_sq if include_stdvs_sq else None
319351 return Y_tf , Yvar_tf
320352
321353 def subset_output (self , idcs : list [int ]) -> OutcomeTransform :
@@ -376,9 +408,12 @@ def untransform(
376408 "(e.g. `transform(Y)`) before calling `untransform`, since "
377409 "means and standard deviations need to be computed."
378410 )
379-
380- Y_utf = self .means + self .stdvs * Y
381- Yvar_utf = self ._stdvs_sq * Yvar if Yvar is not None else None
411+ include_stdvs_sq = Yvar is not None
412+ means , stdvs , stdvs_sq = self ._get_per_input_means_stdvs (
413+ X = X , include_stdvs_sq = include_stdvs_sq
414+ )
415+ Y_utf = means + stdvs * Y
416+ Yvar_utf = stdvs_sq * Yvar if include_stdvs_sq else None
382417 return Y_utf , Yvar_utf
383418
384419 @property
@@ -433,8 +468,9 @@ def untransform_posterior(
433468 )
434469 # GPyTorchPosterior (TODO: Should we Lazy-evaluate the mean here as well?)
435470 mvn = posterior .distribution
436- offset = self .means
437- scale_fac = self .stdvs
471+ offset , scale_fac , _ = self ._get_per_input_means_stdvs (
472+ X = X , include_stdvs_sq = False
473+ )
438474 if not posterior ._is_mt :
439475 mean_tf = offset .squeeze (- 1 ) + scale_fac .squeeze (- 1 ) * mvn .mean
440476 scale_fac = scale_fac .squeeze (- 1 ).expand_as (mean_tf )
@@ -449,7 +485,6 @@ def untransform_posterior(
449485
450486 if (
451487 not mvn .islazy
452- # TODO: Figure out attribute namming weirdness here
453488 or mvn ._MultivariateNormal__unbroadcasted_scale_tril is not None
454489 ):
455490 # if already computed, we can save a lot of time using scale_tril
@@ -465,6 +500,197 @@ def untransform_posterior(
465500 return GPyTorchPosterior (mvn_tf )
466501
467502
503+ class StratifiedStandardize (Standardize ):
504+ r"""Standardize outcomes (zero mean, unit variance) along stratification dimension.
505+
506+ This module is stateful: If in train mode, calling forward updates the
507+ module state (i.e. the mean/std normalizing constants). If in eval mode,
508+ calling forward simply applies the standardization using the current module
509+ state.
510+ """
511+
512+ def __init__ (
513+ self ,
514+ task_values : Tensor ,
515+ stratification_idx : int ,
516+ batch_shape : torch .Size = torch .Size (), # noqa: B008
517+ min_stdv : float = 1e-8 ,
518+ # dtype: torch.dtype = torch.double,
519+ ) -> None :
520+ r"""Standardize outcomes (zero mean, unit variance) along stratification dim.
521+
522+ Note: This currenlty only supports single output models
523+ (including multi-task models that have a single output).
524+
525+ Args:
526+ task_values: `t`-dim tensor of task values.
527+ stratification_idx: The index of the stratification dimension.
528+ batch_shape: The batch_shape of the training targets.
529+ min_stddv: The minimum standard deviation for which to perform
530+ standardization (if lower, only de-mean the data).
531+ """
532+ OutcomeTransform .__init__ (self )
533+ self ._stratification_idx = stratification_idx
534+ task_values = task_values .unique (sorted = True )
535+ self .strata_mapping = get_task_value_remapping (task_values , dtype = torch .long )
536+ if self .strata_mapping is None :
537+ self .strata_mapping = task_values
538+ n_strata = self .strata_mapping .shape [0 ]
539+ self ._min_stdv = min_stdv
540+ self .register_buffer ("means" , torch .zeros (* batch_shape , n_strata , 1 ))
541+ self .register_buffer ("stdvs" , torch .ones (* batch_shape , n_strata , 1 ))
542+ self .register_buffer ("_stdvs_sq" , torch .ones (* batch_shape , n_strata , 1 ))
543+ self .register_buffer ("_is_trained" , torch .tensor (False ))
544+ self ._batch_shape = batch_shape
545+ self ._m = 1 # TODO: support multiple outputs
546+ self ._outputs = None
547+
548+ def forward (
549+ self , Y : Tensor , Yvar : Tensor | None = None , X : Tensor | None = None
550+ ) -> tuple [Tensor , Tensor | None ]:
551+ r"""Standardize outcomes.
552+
553+ If the module is in train mode, this updates the module state (i.e. the
554+ mean/std normalizing constants). If the module is in eval mode, simply
555+ applies the normalization using the module state.
556+
557+ Args:
558+ Y: A `batch_shape x n x m`-dim tensor of training targets.
559+ Yvar: A `batch_shape x n x m`-dim tensor of observation noises
560+ associated with the training targets (if applicable).
561+ X: A `batch_shape x n x d`-dim tensor of input parameters.
562+
563+ Returns:
564+ A two-tuple with the transformed outcomes:
565+
566+ - The transformed outcome observations.
567+ - The transformed observation noise (if applicable).
568+ """
569+ if X is None :
570+ raise ValueError ("X is required for StratifiedStandardize." )
571+ if self .training :
572+ self ._validate_training_inputs (Y = Y , Yvar = Yvar )
573+ self .means = self .means .to (dtype = X .dtype , device = X .device )
574+ self .stdvs = self .stdvs .to (dtype = X .dtype , device = X .device )
575+ self ._stdvs_sq = self ._stdvs_sq .to (dtype = X .dtype , device = X .device )
576+ strata = X [..., self ._stratification_idx ].long ()
577+ unique_strata = strata .unique ()
578+ for s in unique_strata :
579+ mapped_strata = self .strata_mapping [s ]
580+ mask = strata != s
581+ Y_strata = Y .clone ()
582+ Y_strata [..., mask , :] = float ("nan" )
583+ stdvs = (
584+ torch .ones_like (Y_strata )
585+ if Y .shape [- 2 ] == 1
586+ else nanstd (X = Y_strata , dim = - 2 )
587+ )
588+ stdvs = stdvs .where (
589+ stdvs >= self ._min_stdv , torch .full_like (stdvs , 1.0 )
590+ )
591+ means = Y_strata .nanmean (dim = - 2 )
592+ self .means [..., mapped_strata , :] = means
593+ self .stdvs [..., mapped_strata , :] = stdvs
594+ self ._stdvs_sq [..., mapped_strata , :] = stdvs .pow (2 )
595+ self ._is_trained = torch .tensor (True )
596+ training = self .training
597+ self .training = False
598+ tf_Y , tf_Yvar = super ().forward (Y = Y , Yvar = Yvar , X = X )
599+ self .training = training
600+ return tf_Y , tf_Yvar
601+
602+ def _get_per_input_means_stdvs (
603+ self , X : Tensor , include_stdvs_sq : bool
604+ ) -> tuple [Tensor , Tensor , Tensor | None ]:
605+ r"""Get per-input means and stdvs.
606+
607+ Args:
608+ X: A `batch_shape x n x d`-dim tensor of input parameters.
609+ include_stdvs_sq: Whether to include the stdvs squared.
610+
611+ Returns:
612+ A three-tuple with the per-input means and stdvs:
613+
614+ - The per-input means.
615+ - The per-input stdvs.
616+ - The per-input stdvs squared.
617+ """
618+ strata = X [..., self ._stratification_idx ].long ()
619+ mapped_strata = self .strata_mapping [strata ].unsqueeze (- 1 )
620+ # get means and stdvs for each strata
621+ n_extra_batch_dims = mapped_strata .ndim - 2 - len (self ._batch_shape )
622+ expand_shape = mapped_strata .shape [:n_extra_batch_dims ] + self .means .shape
623+ means = torch .gather (
624+ input = self .means .expand (expand_shape ),
625+ dim = - 2 ,
626+ index = mapped_strata ,
627+ )
628+ stdvs = torch .gather (
629+ input = self .stdvs .expand (expand_shape ),
630+ dim = - 2 ,
631+ index = mapped_strata ,
632+ )
633+ if include_stdvs_sq :
634+ stdvs_sq = torch .gather (
635+ input = self ._stdvs_sq .expand (expand_shape ),
636+ dim = - 2 ,
637+ index = mapped_strata ,
638+ )
639+ else :
640+ stdvs_sq = None
641+ return means , stdvs , stdvs_sq
642+
643+ def subset_output (self , idcs : list [int ]) -> OutcomeTransform :
644+ r"""Subset the transform along the output dimension.
645+
646+ Args:
647+ idcs: The output indices to subset the transform to.
648+
649+ Returns:
650+ The current outcome transform, subset to the specified output indices.
651+ """
652+ raise NotImplementedError
653+
654+ def untransform (
655+ self , Y : Tensor , Yvar : Tensor | None = None , X : Tensor | None = None
656+ ) -> tuple [Tensor , Tensor | None ]:
657+ r"""Un-standardize outcomes.
658+
659+ Args:
660+ Y: A `batch_shape x n x m`-dim tensor of standardized targets.
661+ Yvar: A `batch_shape x n x m`-dim tensor of standardized observation
662+ noises associated with the targets (if applicable).
663+ X: A `batch_shape x n x d`-dim tensor of input parameters.
664+
665+ Returns:
666+ A two-tuple with the un-standardized outcomes:
667+
668+ - The un-standardized outcome observations.
669+ - The un-standardized observation noise (if applicable).
670+ """
671+ if X is None :
672+ raise ValueError ("X is required for StratifiedStandardize." )
673+ return super ().untransform (Y = Y , Yvar = Yvar , X = X )
674+
675+ def untransform_posterior (
676+ self , posterior : Posterior , X : Tensor | None = None
677+ ) -> GPyTorchPosterior | TransformedPosterior :
678+ r"""Un-standardize the posterior.
679+
680+ Args:
681+ posterior: A posterior in the standardized space.
682+ X: A `batch_shape x n x d`-dim tensor of training inputs (if applicable).
683+
684+ Returns:
685+ The un-standardized posterior. If the input posterior is a
686+ `GPyTorchPosterior`, return a `GPyTorchPosterior`. Otherwise, return a
687+ `TransformedPosterior`.
688+ """
689+ if X is None :
690+ raise ValueError ("X is required for StratifiedStandardize." )
691+ return super ().untransform_posterior (posterior = posterior , X = X )
692+
693+
468694class Log (OutcomeTransform ):
469695 r"""Log-transform outcomes.
470696
0 commit comments