
Commit 8d4c70b

SebastianAment authored and facebook-github-bot committed
RobustRelevancePursuitSingleTaskGP with specialized fit_gpytorch_mll (#2690)
Summary: This commit introduces an abstract `RobustRelevancePursuitModel` and `RobustRelevancePursuitSingleTaskGP`, a concrete implementation of the abstract class. The main purpose of the new class is to provide an interface identical to that of a canonical `SingleTaskGP`, while automatically extending the likelihood with the `SparseOutlierGaussianLikelihood` and automatically invoking the Relevance Pursuit algorithm during marginal likelihood optimization, by dispatching `fit_gpytorch_mll` on the model type. This makes the model and algorithm easy to use.

Differential Revision: D68353582
1 parent: 7b803bd · commit: 8d4c70b

File tree

5 files changed: +503 -56 lines

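As a usage sketch of the interface described in the summary above (synthetic data; the import path and exact constructor signature are assumptions based on this summary, not verified against the final API):

import torch
from botorch.fit import fit_gpytorch_mll
from gpytorch.mlls import ExactMarginalLogLikelihood
# Assumed import path for the new class (the new file is not shown in this diff).
from botorch.models.robust_relevance_pursuit_model import (
    RobustRelevancePursuitSingleTaskGP,
)

# Synthetic training data with a few injected outliers.
train_X = torch.rand(32, 2, dtype=torch.float64)
train_Y = torch.sin(2 * train_X.sum(dim=-1, keepdim=True))
train_Y[::8] += 5.0  # corrupt every 8th observation

# Same constructor interface as a canonical SingleTaskGP; the
# SparseOutlierGaussianLikelihood is attached automatically.
model = RobustRelevancePursuitSingleTaskGP(train_X=train_X, train_Y=train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)

# fit_gpytorch_mll dispatches on the model type and runs Relevance Pursuit
# as part of the marginal likelihood optimization.
fit_gpytorch_mll(mll)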

botorch/models/gp_regression.py

Lines changed: 1 addition & 0 deletions
@@ -186,6 +186,7 @@ def __init__(
                 noise=train_Yvar, batch_shape=self._aug_batch_shape
             )
         else:
+            # This is used to check if the `model_list_to_batched` can be used
             self._is_custom_likelihood = True
         ExactGP.__init__(
             self, train_inputs=train_X, train_targets=train_Y, likelihood=likelihood
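For context, the comment added above refers to the compatibility check performed by BoTorch's model converter. A rough, illustrative sketch of such a guard (not the converter's actual code):

from botorch.exceptions.errors import UnsupportedError

def _check_batchable(models) -> None:
    # Illustrative only: models built with a user-supplied likelihood set
    # `_is_custom_likelihood = True`, and `model_list_to_batched` cannot
    # combine such models into a single batched model.
    if any(getattr(m, "_is_custom_likelihood", False) for m in models):
        raise UnsupportedError("Cannot batch models with custom likelihoods.")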

botorch/models/relevance_pursuit.py

Lines changed: 61 additions & 24 deletions
@@ -20,7 +20,7 @@
 import math
 
 from abc import ABC, abstractmethod
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from copy import copy, deepcopy
 from functools import partial
 from typing import Any, cast, Optional
@@ -35,12 +35,13 @@
 
 MLL_ITER = 10_000  # let's take convergence seriously
 MLL_TOL = 1e-8
-RESET_PARAMETERS = False
+RESET_PARAMETERS = True
+RESET_DENSE_PARAMETERS = False
 
 
 class RelevancePursuitMixin(ABC):
     """Mixin class to convert between the sparse and dense representations of the
-    relevance pursuit models' sparse parameters, as well as to compute the generalized
+    relevance pursuit modules' sparse parameters, as well as to compute the generalized
     support acquisition and support deletion criteria.
     """
 
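Net effect of the new defaults: the sparse parameter is re-initialized to zeros around each support change (RESET_PARAMETERS = True), while the dense hyper-parameters, which are co-adapted to the sparse parameter, are left in place unless reset_dense_parameters=True is requested explicitly (RESET_DENSE_PARAMETERS = False).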
@@ -395,7 +396,11 @@ def optimize_mll(
         mll: ExactMarginalLogLikelihood,
         model_trace: list[Model] | None = None,
         reset_parameters: bool = RESET_PARAMETERS,
-        reset_dense_parameters: bool = RESET_PARAMETERS,
+        reset_dense_parameters: bool = RESET_DENSE_PARAMETERS,
+        # fit_gpytorch_mll kwargs
+        closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+        optimizer: Callable | None = None,
+        closure_kwargs: dict[str, Any] | None = None,
         optimizer_kwargs: dict[str, Any] | None = None,
     ):
         """Optimizes the marginal likelihood.
@@ -410,6 +415,10 @@ def optimize_mll(
             reset_dense_parameters: If True, re-initializes the dense parameters, e.g.
                 other GP hyper-parameters that are *not* part of the Relevance Pursuit
                 module, to the initial values provided by their associated constraints.
+            closure: A closure to use to compute the loss and the gradients, see
+                docstring of `fit_gpytorch_mll` for details.
+            optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
+            closure_kwargs: Additional arguments to pass to the `closure` function.
             optimizer_kwargs: A dictionary of keyword arguments for the optimizer.
 
         Returns:
@@ -419,7 +428,6 @@
             # this might be beneficial because the parameters can
             # end up at a constraint boundary, which can anecdotally make
             # it more difficult to move the newly added parameters.
-            # should we only do this after expansion?
             with torch.no_grad():
                 self.sparse_parameter.zero_()
 
@@ -430,7 +438,13 @@
         # NOTE: this function should never force the dense representation, because some
         # models might never need it, and it would be inefficient.
         self.to_sparse()
-        fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs)
+        mll = fit_gpytorch_mll(
+            mll,
+            optimizer_kwargs=optimizer_kwargs,
+            closure=closure,
+            optimizer=optimizer,
+            closure_kwargs=closure_kwargs,
+        )
         if model_trace is not None:
             # need to record the full model here, rather than just the sparse parameter
             # since other hyper-parameters are co-adapted to the sparse parameter.
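Since optimize_mll now forwards these arguments verbatim, callers can customize the inner fit. A minimal sketch, assuming `sparse_module` and `mll` have already been constructed (setup elided), swapping in BoTorch's torch-based fit routine:

from botorch.optim.fit import fit_gpytorch_mll_torch

# `sparse_module` and `mll` are assumed to exist; construction elided.
sparse_module.optimize_mll(
    mll=mll,
    optimizer=fit_gpytorch_mll_torch,        # replaces the default (scipy-based) routine
    optimizer_kwargs={"step_limit": 1_000},  # forwarded through fit_gpytorch_mll
)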
@@ -443,11 +457,15 @@ def forward_relevance_pursuit(
     sparse_module: RelevancePursuitMixin,
     mll: ExactMarginalLogLikelihood,
     sparsity_levels: list[int] | None = None,
-    optimizer_kwargs: dict[str, Any] | None = None,
     reset_parameters: bool = RESET_PARAMETERS,
-    reset_dense_parameters: bool = RESET_PARAMETERS,
+    reset_dense_parameters: bool = RESET_DENSE_PARAMETERS,
     record_model_trace: bool = True,
     initial_support: list[int] | None = None,
+    # fit_gpytorch_mll kwargs
+    closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+    optimizer: Callable | None = None,
+    closure_kwargs: dict[str, Any] | None = None,
+    optimizer_kwargs: dict[str, Any] | None = None,
 ) -> tuple[RelevancePursuitMixin, Optional[list[Model]]]:
     """Forward Relevance Pursuit.
 
@@ -478,9 +496,6 @@ def forward_relevance_pursuit(
         sparse_module: The relevance pursuit module.
         mll: The marginal likelihood, containing the model to optimize.
         sparsity_levels: The sparsity levels to expand the support to.
-        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
-            By default, initializes the "options" sub-dictionary with `maxiter` and
-            `ftol`, `gtol` values, unless specified.
         reset_parameters: If true, initializes the sparse parameter to the all zeros
             after each iteration.
         reset_dense_parameters: If true, re-initializes the dense parameters, e.g.
@@ -489,6 +504,13 @@ def forward_relevance_pursuit(
         record_model_trace: If true, records the model state after every iteration.
         initial_support: The support with which to initialize the sparse module. By
             default, the support is initialized to the empty set.
+        closure: A closure to use to compute the loss and the gradients, see docstring
+            of `fit_gpytorch_mll` for details.
+        optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
+        closure_kwargs: Additional arguments to pass to the `closure` function.
+        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
+            By default, initializes the "options" sub-dictionary with `maxiter` and
+            `ftol`, `gtol` values, unless specified.
 
     Returns:
         The relevance pursuit module after forward relevance pursuit optimization, and
@@ -510,14 +532,17 @@
 
     model_trace = [] if record_model_trace else None
 
-    def optimize_mll(mll):
-        return sparse_module.optimize_mll(
-            mll=mll,
-            model_trace=model_trace,
-            reset_parameters=reset_parameters,
-            reset_dense_parameters=reset_dense_parameters,
-            optimizer_kwargs=optimizer_kwargs,
-        )
+    optimize_mll = partial(
+        sparse_module.optimize_mll,
+        model_trace=model_trace,
+        reset_parameters=reset_parameters,
+        reset_dense_parameters=reset_dense_parameters,
+        # These are the args of the canonical mll fit routine
+        closure=closure,
+        optimizer=optimizer,
+        closure_kwargs=closure_kwargs,
+        optimizer_kwargs=optimizer_kwargs,
+    )
 
     # if sparsity levels contains the initial support, remove it
     if sparsity_levels[0] == len(sparse_module.support):
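Replacing the nested def with functools.partial forwards the enlarged keyword set to sparse_module.optimize_mll without restating every argument; only the mll remains to be supplied at each sparsity level.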
@@ -548,11 +573,15 @@ def backward_relevance_pursuit(
     sparse_module: RelevancePursuitMixin,
     mll: ExactMarginalLogLikelihood,
     sparsity_levels: list[int] | None = None,
-    optimizer_kwargs: dict[str, Any] | None = None,
     reset_parameters: bool = RESET_PARAMETERS,
-    reset_dense_parameters: bool = RESET_PARAMETERS,
+    reset_dense_parameters: bool = RESET_DENSE_PARAMETERS,
     record_model_trace: bool = True,
     initial_support: list[int] | None = None,
+    # fit_gpytorch_mll kwargs
+    closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+    optimizer: Callable | None = None,
+    closure_kwargs: dict[str, Any] | None = None,
+    optimizer_kwargs: dict[str, Any] | None = None,
 ) -> tuple[RelevancePursuitMixin, Optional[list[Model]]]:
     """Backward Relevance Pursuit.
 
@@ -583,9 +612,6 @@ def backward_relevance_pursuit(
         sparse_module: The relevance pursuit module.
         mll: The marginal likelihood, containing the model to optimize.
         sparsity_levels: The sparsity levels to expand the support to.
-        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
-            By default, initializes the "options" sub-dictionary with `maxiter` and
-            `ftol`, `gtol` values, unless specified.
         reset_parameters: If true, initializes the sparse parameter to the all zeros
             after each iteration.
         reset_dense_parameters: If true, re-initializes the dense parameters, e.g.
@@ -594,6 +620,13 @@ def backward_relevance_pursuit(
         record_model_trace: If true, records the model state after every iteration.
         initial_support: The support with which to initialize the sparse module. By
             default, the support is initialized to the full set.
+        closure: A closure to use to compute the loss and the gradients, see docstring
+            of `fit_gpytorch_mll` for details.
+        optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
+        closure_kwargs: Additional arguments to pass to the `closure` function.
+        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
+            By default, initializes the "options" sub-dictionary with `maxiter` and
+            `ftol`, `gtol` values, unless specified.
 
     Returns:
         The relevance pursuit module after forward relevance pursuit optimization, and
@@ -623,6 +656,10 @@ def optimize_mll(mll):
             model_trace=model_trace,
             reset_parameters=reset_parameters,
             reset_dense_parameters=reset_dense_parameters,
+            # These are the args of the canonical mll fit routine
+            closure=closure,
+            optimizer=optimizer,
+            closure_kwargs=closure_kwargs,
             optimizer_kwargs=optimizer_kwargs,
         )
 
