
Commit 2ac2685

SebastianAment authored and facebook-github-bot committed
RobustRelevancePursuitSingleTaskGP with specialized fit_gpytorch_mll (#2690)
Summary: This commit introduces an abstract `RobustRelevancePursuitModel` and `RobustRelevancePursuitSingleTaskGP`, a concrete implementation of the abstract class. The main purpose of the new class is to provide an interface identical to that of a canonical `SingleTaskGP`, while automatically extending the likelihood with the `SparseOutlierGaussianLikelihood` and automatically running the Relevance Pursuit algorithm during marginal likelihood optimization, by having `fit_gpytorch_mll` dispatch on the model type. This makes the model and algorithm easy to use.

Reviewed By: esantorella

Differential Revision: D68353582
1 parent 851df1f commit 2ac2685
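
To make the new interface concrete, here is a hedged usage sketch based on the summary above. The import path for `RobustRelevancePursuitSingleTaskGP` is an assumption (its defining file is among the changed files not shown below); `fit_gpytorch_mll` and `ExactMarginalLogLikelihood` are the standard BoTorch/GPyTorch entry points.

```python
import torch
from botorch.fit import fit_gpytorch_mll
from gpytorch.mlls import ExactMarginalLogLikelihood

# assumed import path; the defining module is not part of the diffs shown below
from botorch.models import RobustRelevancePursuitSingleTaskGP

train_X = torch.rand(20, 2, dtype=torch.float64)
train_Y = torch.sin(train_X.sum(dim=-1, keepdim=True))
train_Y[0] += 5.0  # an outlier for the robust likelihood to explain away

# Same constructor interface as SingleTaskGP; per the summary, the
# SparseOutlierGaussianLikelihood is attached automatically.
model = RobustRelevancePursuitSingleTaskGP(train_X=train_X, train_Y=train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)

# Dispatching on the model type, fit_gpytorch_mll runs Relevance Pursuit
# as part of the marginal likelihood optimization.
mll = fit_gpytorch_mll(mll)
```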

File tree

6 files changed: +547, −65 lines changed


botorch/models/gp_regression.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -186,6 +186,7 @@ def __init__(
                 noise=train_Yvar, batch_shape=self._aug_batch_shape
             )
         else:
+            # This is used to check if the `model_list_to_batched` can be used
             self._is_custom_likelihood = True
         ExactGP.__init__(
             self, train_inputs=train_X, train_targets=train_Y, likelihood=likelihood
```
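
The added comment refers to `model_list_to_batched`, BoTorch's converter from a list of single-output models to one batched model. A sketch of the kind of guard the flag enables follows; the helper is hypothetical, and only the `_is_custom_likelihood` attribute comes from the diff.

```python
# hypothetical guard, not part of the diff
def _can_batch(models) -> bool:
    # models constructed with a user-supplied likelihood are flagged and
    # cannot be assumed to share a batchable likelihood configuration
    return not any(getattr(m, "_is_custom_likelihood", False) for m in models)
```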

botorch/models/relevance_pursuit.py

Lines changed: 71 additions & 32 deletions
```diff
@@ -20,7 +20,7 @@
 import math
 
 from abc import ABC, abstractmethod
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from copy import copy, deepcopy
 from functools import partial
 from typing import Any, cast, Optional
@@ -35,12 +35,13 @@
 
 MLL_ITER = 10_000  # let's take convergence seriously
 MLL_TOL = 1e-8
-RESET_PARAMETERS = False
+RESET_PARAMETERS = True
+RESET_DENSE_PARAMETERS = False
 
 
 class RelevancePursuitMixin(ABC):
     """Mixin class to convert between the sparse and dense representations of the
-    relevance pursuit models' sparse parameters, as well as to compute the generalized
+    relevance pursuit modules' sparse parameters, as well as to compute the generalized
     support acquisition and support deletion criteria.
     """
```
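
The two constants above are the defaults the docstrings below refer to when they say that `optimizer_kwargs` "initializes the 'options' sub-dictionary with `maxiter` and `ftol`, `gtol` values, unless specified". A minimal sketch of that defaulting, assuming a hypothetical helper (the diff does not show this code):

```python
from typing import Any

MLL_ITER = 10_000
MLL_TOL = 1e-8

# hypothetical helper; illustrates the documented defaulting behavior only
def _default_optimizer_kwargs(
    optimizer_kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
    optimizer_kwargs = dict(optimizer_kwargs or {})
    options = dict(optimizer_kwargs.get("options", {}))
    options.setdefault("maxiter", MLL_ITER)
    options.setdefault("ftol", MLL_TOL)
    options.setdefault("gtol", MLL_TOL)
    optimizer_kwargs["options"] = options
    return optimizer_kwargs
```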

```diff
@@ -251,19 +252,21 @@ def support_expansion(
         n: int = 1,
         modifier: Callable[[Tensor], Tensor] | None = None,
     ) -> bool:
-        """Computes the indices of the features that maximize the gradient of the sparse
+        """Computes the indices of the elements that maximize the gradient of the sparse
         parameter and that are not already in the support, and subsequently expands the
-        support to include the features if their gradient is positive.
+        support to include the elements if their gradient is positive.
 
         Args:
             mll: The marginal likelihood, containing the model to optimize.
                 NOTE: Virtually all of the rest of the code is not specific to the
                 marginal likelihood optimization, so we could generalize this to work
                 with any objective.
-            n: The number of features to select.
-            modifier: A function that modifies the gradient of the inactive parameters
+            n: The maximum number of elements to select. NOTE: The actual number of
+                elements that are added could be fewer if there are fewer than `n`
+                elements with a positive gradient.
+            modifier: A function that modifies the gradient of the inactive elements
                 before computing the support expansion criterion. This can be used
-                to select the maximum gradient magnitude for real-valued parameters
+                to select the maximum gradient magnitude for real-valued elements
                 whose gradients are not non-negative, using modifier = torch.abs.
 
         Returns:
```
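
In other words, expansion is a top-`n` selection over the gradients of the inactive elements, keeping only those with a positive gradient. A minimal illustrative sketch follows; the function and variable names are hypothetical, not the module's API.

```python
import torch
from torch import Tensor

# hypothetical standalone version of the documented selection rule
def select_expansion_indices(
    inactive_grad: Tensor, n: int = 1, modifier=None
) -> Tensor:
    g = modifier(inactive_grad) if modifier is not None else inactive_grad
    values, indices = torch.topk(g, k=min(n, g.numel()))
    # fewer than n indices are returned if not enough gradients are positive
    return indices[values > 0]
```

With `modifier=torch.abs`, the rule selects the largest gradient magnitudes, matching the docstring's note on real-valued elements.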
```diff
@@ -354,15 +357,15 @@ def support_contraction(
         n: int = 1,
         modifier: Callable[[Tensor], Tensor] | None = None,
     ) -> bool:
-        """Computes the indices of the features that have the smallest coefficients,
-        and subsequently contracts the exlude the features.
+        """Computes the indices of the elements with the smallest magnitude,
+        and subsequently contracts the support by exluding the elements.
 
         Args:
             mll: The marginal likelihood, containing the model to optimize.
                 NOTE: Virtually all of the rest of the code is not specific to the
                 marginal likelihood optimization, so we could generalize this to work
                 with any objective.
-            n: The number of features to select for removal.
+            n: The number of elements to select for removal.
             modifier: A function that modifies the parameter values before computing
                 the support contraction criterion.
```
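
Contraction mirrors the expansion step with a bottom-`n` selection over the magnitudes of the active elements. Again a hypothetical sketch, not the module's implementation:

```python
import torch
from torch import Tensor

# hypothetical standalone version of the documented removal rule
def select_contraction_indices(
    active_values: Tensor, n: int = 1, modifier=None
) -> Tensor:
    v = modifier(active_values) if modifier is not None else active_values.abs()
    _, indices = torch.topk(v, k=min(n, v.numel()), largest=False)
    return indices
```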
```diff
@@ -395,7 +398,11 @@ def optimize_mll(
         mll: ExactMarginalLogLikelihood,
         model_trace: list[Model] | None = None,
         reset_parameters: bool = RESET_PARAMETERS,
-        reset_dense_parameters: bool = RESET_PARAMETERS,
+        reset_dense_parameters: bool = RESET_DENSE_PARAMETERS,
+        # fit_gpytorch_mll kwargs
+        closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+        optimizer: Callable | None = None,
+        closure_kwargs: dict[str, Any] | None = None,
         optimizer_kwargs: dict[str, Any] | None = None,
     ):
         """Optimizes the marginal likelihood.
@@ -410,6 +417,10 @@ def optimize_mll(
             reset_dense_parameters: If True, re-initializes the dense parameters, e.g.
                 other GP hyper-parameters that are *not* part of the Relevance Pursuit
                 module, to the initial values provided by their associated constraints.
+            closure: A closure to use to compute the loss and the gradients, see
+                docstring of `fit_gpytorch_mll` for details.
+            optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
+            closure_kwargs: Additional arguments to pass to the `closure` function.
             optimizer_kwargs: A dictionary of keyword arguments for the optimizer.
 
         Returns:
@@ -419,7 +430,6 @@
             # this might be beneficial because the parameters can
             # end up at a constraint boundary, which can anecdotally make
             # it more difficult to move the newly added parameters.
-            # should we only do this after expansion?
             with torch.no_grad():
                 self.sparse_parameter.zero_()
 
```
```diff
@@ -430,7 +440,13 @@
         # NOTE: this function should never force the dense representation, because some
         # models might never need it, and it would be inefficient.
         self.to_sparse()
-        fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs)
+        mll = fit_gpytorch_mll(
+            mll,
+            optimizer_kwargs=optimizer_kwargs,
+            closure=closure,
+            optimizer=optimizer,
+            closure_kwargs=closure_kwargs,
+        )
         if model_trace is not None:
             # need to record the full model here, rather than just the sparse parameter
             # since other hyper-parameters are co-adapted to the sparse parameter.
```
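
Because `optimize_mll` now forwards `closure`, `optimizer`, and `closure_kwargs` to `fit_gpytorch_mll`, callers can control how the loss and gradients are computed. A hedged sketch of a caller-supplied closure, inferred from the annotation `Callable[[], tuple[Tensor, Sequence[Tensor | None]]]`; `mll` and `sparse_module` are assumed to exist:

```python
# assumes an already-constructed ExactMarginalLogLikelihood named `mll`
def custom_closure():
    params = [p for p in mll.parameters() if p.requires_grad]
    mll.zero_grad()
    model = mll.model
    output = model(*model.train_inputs)
    # negate: fit_gpytorch_mll minimizes, while the MLL is to be maximized
    loss = -mll(output, model.train_targets).sum()
    loss.backward()
    return loss, [p.grad for p in params]

sparse_module.optimize_mll(mll=mll, closure=custom_closure)
```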
```diff
@@ -443,11 +459,15 @@ def forward_relevance_pursuit(
     sparse_module: RelevancePursuitMixin,
     mll: ExactMarginalLogLikelihood,
     sparsity_levels: list[int] | None = None,
-    optimizer_kwargs: dict[str, Any] | None = None,
     reset_parameters: bool = RESET_PARAMETERS,
-    reset_dense_parameters: bool = RESET_PARAMETERS,
+    reset_dense_parameters: bool = RESET_DENSE_PARAMETERS,
     record_model_trace: bool = True,
     initial_support: list[int] | None = None,
+    # fit_gpytorch_mll kwargs
+    closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+    optimizer: Callable | None = None,
+    closure_kwargs: dict[str, Any] | None = None,
+    optimizer_kwargs: dict[str, Any] | None = None,
 ) -> tuple[RelevancePursuitMixin, Optional[list[Model]]]:
     """Forward Relevance Pursuit.
 
```
```diff
@@ -478,9 +498,6 @@ def forward_relevance_pursuit(
         sparse_module: The relevance pursuit module.
         mll: The marginal likelihood, containing the model to optimize.
         sparsity_levels: The sparsity levels to expand the support to.
-        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
-            By default, initializes the "options" sub-dictionary with `maxiter` and
-            `ftol`, `gtol` values, unless specified.
         reset_parameters: If true, initializes the sparse parameter to the all zeros
             after each iteration.
         reset_dense_parameters: If true, re-initializes the dense parameters, e.g.
@@ -489,6 +506,13 @@
         record_model_trace: If true, records the model state after every iteration.
         initial_support: The support with which to initialize the sparse module. By
             default, the support is initialized to the empty set.
+        closure: A closure to use to compute the loss and the gradients, see docstring
+            of `fit_gpytorch_mll` for details.
+        optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
+        closure_kwargs: Additional arguments to pass to the `closure` function.
+        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
+            By default, initializes the "options" sub-dictionary with `maxiter` and
+            `ftol`, `gtol` values, unless specified.
 
     Returns:
         The relevance pursuit module after forward relevance pursuit optimization, and
```
```diff
@@ -510,14 +534,17 @@
 
     model_trace = [] if record_model_trace else None
 
-    def optimize_mll(mll):
-        return sparse_module.optimize_mll(
-            mll=mll,
-            model_trace=model_trace,
-            reset_parameters=reset_parameters,
-            reset_dense_parameters=reset_dense_parameters,
-            optimizer_kwargs=optimizer_kwargs,
-        )
+    optimize_mll = partial(
+        sparse_module.optimize_mll,
+        model_trace=model_trace,
+        reset_parameters=reset_parameters,
+        reset_dense_parameters=reset_dense_parameters,
+        # These are the args of the canonical mll fit routine
+        closure=closure,
+        optimizer=optimizer,
+        closure_kwargs=closure_kwargs,
+        optimizer_kwargs=optimizer_kwargs,
+    )
 
     # if sparsity levels contains the initial support, remove it
     if sparsity_levels[0] == len(sparse_module.support):
```
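
A hedged usage sketch of the refactored entry point; `sparse_module` and `mll` are assumed to be constructed elsewhere, and the sparsity levels and optimizer options are example values chosen to line up with the documented defaults:

```python
# illustrative call with example values
sparse_module, model_trace = forward_relevance_pursuit(
    sparse_module=sparse_module,
    mll=mll,
    sparsity_levels=[0, 1, 2, 4, 8],
    record_model_trace=True,
    optimizer_kwargs={"options": {"maxiter": 10_000, "ftol": 1e-8, "gtol": 1e-8}},
)
```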
```diff
@@ -548,11 +575,15 @@ def backward_relevance_pursuit(
     sparse_module: RelevancePursuitMixin,
     mll: ExactMarginalLogLikelihood,
     sparsity_levels: list[int] | None = None,
-    optimizer_kwargs: dict[str, Any] | None = None,
     reset_parameters: bool = RESET_PARAMETERS,
-    reset_dense_parameters: bool = RESET_PARAMETERS,
+    reset_dense_parameters: bool = RESET_DENSE_PARAMETERS,
     record_model_trace: bool = True,
     initial_support: list[int] | None = None,
+    # fit_gpytorch_mll kwargs
+    closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+    optimizer: Callable | None = None,
+    closure_kwargs: dict[str, Any] | None = None,
+    optimizer_kwargs: dict[str, Any] | None = None,
 ) -> tuple[RelevancePursuitMixin, Optional[list[Model]]]:
     """Backward Relevance Pursuit.
 
```
```diff
@@ -583,9 +614,6 @@
         sparse_module: The relevance pursuit module.
         mll: The marginal likelihood, containing the model to optimize.
         sparsity_levels: The sparsity levels to expand the support to.
-        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
-            By default, initializes the "options" sub-dictionary with `maxiter` and
-            `ftol`, `gtol` values, unless specified.
         reset_parameters: If true, initializes the sparse parameter to the all zeros
             after each iteration.
         reset_dense_parameters: If true, re-initializes the dense parameters, e.g.
@@ -594,6 +622,13 @@
         record_model_trace: If true, records the model state after every iteration.
         initial_support: The support with which to initialize the sparse module. By
             default, the support is initialized to the full set.
+        closure: A closure to use to compute the loss and the gradients, see docstring
+            of `fit_gpytorch_mll` for details.
+        optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
+        closure_kwargs: Additional arguments to pass to the `closure` function.
+        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
+            By default, initializes the "options" sub-dictionary with `maxiter` and
+            `ftol`, `gtol` values, unless specified.
 
     Returns:
         The relevance pursuit module after forward relevance pursuit optimization, and
```
```diff
@@ -623,6 +658,10 @@ def optimize_mll(mll):
             model_trace=model_trace,
             reset_parameters=reset_parameters,
             reset_dense_parameters=reset_dense_parameters,
+            # These are the args of the canonical mll fit routine
+            closure=closure,
+            optimizer=optimizer,
+            closure_kwargs=closure_kwargs,
             optimizer_kwargs=optimizer_kwargs,
         )
```