
Commit 9acb78c

SebastianAment authored and facebook-github-bot committed
RobustRelevancePursuitSingleTaskGP with specialized fit_gpytorch_mll
Summary: This commit introduces an abstract `RobustRelevancePursuitModel` class and `RobustRelevancePursuitSingleTaskGP`, a concrete implementation of it. The new class exposes the same interface as a canonical `SingleTaskGP`, but automatically extends the likelihood with the `SparseOutlierGaussianLikelihood` and triggers the Relevance Pursuit algorithm during marginal likelihood optimization by having `fit_gpytorch_mll` dispatch on the model type. This makes the model and the algorithm easy to use.

Differential Revision: D68353582
1 parent 589260b commit 9acb78c
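The sketch below illustrates how the interface described in the summary is intended to be used. It is a minimal, hypothetical example: the import path, constructor arguments, and toy data are assumptions not confirmed by this diff; only the class name, the `SparseOutlierGaussianLikelihood` extension, and the `fit_gpytorch_mll` dispatch are taken from the commit summary.

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import RobustRelevancePursuitSingleTaskGP  # assumed import path
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy data with a few gross outliers (purely illustrative).
train_X = torch.rand(32, 2, dtype=torch.double)
train_Y = torch.sin(2 * torch.pi * train_X.sum(dim=-1, keepdim=True))
train_Y[::8] += 5.0  # corrupt every 8th observation

# Same constructor interface as a canonical SingleTaskGP; per the summary, the
# likelihood is extended with a SparseOutlierGaussianLikelihood automatically.
model = RobustRelevancePursuitSingleTaskGP(train_X=train_X, train_Y=train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)

# fit_gpytorch_mll dispatches on the model type, so the standard fitting call
# also runs the Relevance Pursuit algorithm over the outlier support.
fit_gpytorch_mll(mll)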

File tree

4 files changed: +388 -42 lines


botorch/models/gp_regression.py

Lines changed: 1 addition & 0 deletions

@@ -186,6 +186,7 @@ def __init__(
                     noise=train_Yvar, batch_shape=self._aug_batch_shape
                 )
         else:
+            # This is used to check if the `model_list_to_batched` can be used
             self._is_custom_likelihood = True
         ExactGP.__init__(
             self, train_inputs=train_X, train_targets=train_Y, likelihood=likelihood
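For context on the comment added above: the flag is set only when a caller provides their own likelihood, and downstream code (e.g. the `model_list_to_batched` converter the comment refers to) can consult it to decide whether a set of models is safe to combine into a batched model. A minimal sketch of the flag as seen from the outside; the `getattr` fallback to False is an assumption, and only the attribute name and the True case come from the diff:

import torch
from botorch.models import SingleTaskGP
from gpytorch.likelihoods import GaussianLikelihood

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)

default_model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
custom_model = SingleTaskGP(
    train_X=train_X, train_Y=train_Y, likelihood=GaussianLikelihood()
)

# Only the model constructed with a user-supplied likelihood carries the flag.
print(getattr(default_model, "_is_custom_likelihood", False))  # False (assumed default)
print(custom_model._is_custom_likelihood)                      # True (set in the `else` branch above)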

botorch/models/relevance_pursuit.py

Lines changed: 48 additions & 10 deletions

@@ -20,7 +20,7 @@
 import math
 
 from abc import ABC, abstractmethod
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from copy import copy, deepcopy
 from functools import partial
 from typing import Any, cast, Optional

@@ -396,6 +396,10 @@ def optimize_mll(
         model_trace: list[Model] | None = None,
         reset_parameters: bool = RESET_PARAMETERS,
         reset_dense_parameters: bool = RESET_PARAMETERS,
+        # fit_gpytorch_mll kwargs
+        closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+        optimizer: Callable | None = None,
+        closure_kwargs: dict[str, Any] | None = None,
         optimizer_kwargs: dict[str, Any] | None = None,
     ):
         """Optimizes the marginal likelihood.

@@ -410,6 +414,10 @@ def optimize_mll(
             reset_dense_parameters: If True, re-initializes the dense parameters, e.g.
                 other GP hyper-parameters that are *not* part of the Relevance Pursuit
                 module, to the initial values provided by their associated constraints.
+            closure: A closure to use to compute the loss and the gradients, see
+                docstring of `fit_gpytorch_mll` for details.
+            optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
+            closure_kwargs: Additional arguments to pass to the `closure` function.
             optimizer_kwargs: A dictionary of keyword arguments for the optimizer.
 
         Returns:

@@ -430,7 +438,13 @@ def optimize_mll(
         # NOTE: this function should never force the dense representation, because some
         # models might never need it, and it would be inefficient.
         self.to_sparse()
-        fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs)
+        fit_gpytorch_mll(
+            mll,
+            optimizer_kwargs=optimizer_kwargs,
+            closure=closure,
+            optimizer=optimizer,
+            closure_kwargs=closure_kwargs,
+        )
         if model_trace is not None:
             # need to record the full model here, rather than just the sparse parameter
             # since other hyper-parameters are co-adapted to the sparse parameter.

@@ -443,11 +457,15 @@ def forward_relevance_pursuit(
     sparse_module: RelevancePursuitMixin,
     mll: ExactMarginalLogLikelihood,
     sparsity_levels: list[int] | None = None,
-    optimizer_kwargs: dict[str, Any] | None = None,
     reset_parameters: bool = RESET_PARAMETERS,
     reset_dense_parameters: bool = RESET_PARAMETERS,
     record_model_trace: bool = True,
     initial_support: list[int] | None = None,
+    # fit_gpytorch_mll kwargs
+    closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+    optimizer: Callable | None = None,
+    closure_kwargs: dict[str, Any] | None = None,
+    optimizer_kwargs: dict[str, Any] | None = None,
 ) -> tuple[RelevancePursuitMixin, Optional[list[Model]]]:
     """Forward Relevance Pursuit.
 

@@ -478,9 +496,6 @@ def forward_relevance_pursuit(
         sparse_module: The relevance pursuit module.
         mll: The marginal likelihood, containing the model to optimize.
         sparsity_levels: The sparsity levels to expand the support to.
-        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
-            By default, initializes the "options" sub-dictionary with `maxiter` and
-            `ftol`, `gtol` values, unless specified.
         reset_parameters: If true, initializes the sparse parameter to the all zeros
             after each iteration.
         reset_dense_parameters: If true, re-initializes the dense parameters, e.g.

@@ -489,6 +504,13 @@ def forward_relevance_pursuit(
         record_model_trace: If true, records the model state after every iteration.
         initial_support: The support with which to initialize the sparse module. By
             default, the support is initialized to the empty set.
+        closure: A closure to use to compute the loss and the gradients, see docstring
+            of `fit_gpytorch_mll` for details.
+        optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
+        closure_kwargs: Additional arguments to pass to the `closure` function.
+        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
+            By default, initializes the "options" sub-dictionary with `maxiter` and
+            `ftol`, `gtol` values, unless specified.
 
     Returns:
         The relevance pursuit module after forward relevance pursuit optimization, and

@@ -516,6 +538,10 @@ def optimize_mll(mll):
             model_trace=model_trace,
             reset_parameters=reset_parameters,
             reset_dense_parameters=reset_dense_parameters,
+            # These are the args of the canonical mll fit routine
+            closure=closure,
+            optimizer=optimizer,
+            closure_kwargs=closure_kwargs,
             optimizer_kwargs=optimizer_kwargs,
         )
 

@@ -548,11 +574,15 @@ def backward_relevance_pursuit(
     sparse_module: RelevancePursuitMixin,
     mll: ExactMarginalLogLikelihood,
     sparsity_levels: list[int] | None = None,
-    optimizer_kwargs: dict[str, Any] | None = None,
     reset_parameters: bool = RESET_PARAMETERS,
     reset_dense_parameters: bool = RESET_PARAMETERS,
     record_model_trace: bool = True,
     initial_support: list[int] | None = None,
+    # fit_gpytorch_mll kwargs
+    closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+    optimizer: Callable | None = None,
+    closure_kwargs: dict[str, Any] | None = None,
+    optimizer_kwargs: dict[str, Any] | None = None,
 ) -> tuple[RelevancePursuitMixin, Optional[list[Model]]]:
     """Backward Relevance Pursuit.
 

@@ -583,9 +613,6 @@ def backward_relevance_pursuit(
         sparse_module: The relevance pursuit module.
         mll: The marginal likelihood, containing the model to optimize.
         sparsity_levels: The sparsity levels to expand the support to.
-        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
-            By default, initializes the "options" sub-dictionary with `maxiter` and
-            `ftol`, `gtol` values, unless specified.
         reset_parameters: If true, initializes the sparse parameter to the all zeros
             after each iteration.
         reset_dense_parameters: If true, re-initializes the dense parameters, e.g.

@@ -594,6 +621,13 @@ def backward_relevance_pursuit(
         record_model_trace: If true, records the model state after every iteration.
         initial_support: The support with which to initialize the sparse module. By
             default, the support is initialized to the full set.
+        closure: A closure to use to compute the loss and the gradients, see docstring
+            of `fit_gpytorch_mll` for details.
+        optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
+        closure_kwargs: Additional arguments to pass to the `closure` function.
+        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
+            By default, initializes the "options" sub-dictionary with `maxiter` and
+            `ftol`, `gtol` values, unless specified.
 
     Returns:
         The relevance pursuit module after forward relevance pursuit optimization, and

@@ -623,6 +657,10 @@ def optimize_mll(mll):
             model_trace=model_trace,
             reset_parameters=reset_parameters,
             reset_dense_parameters=reset_dense_parameters,
+            # These are the args of the canonical mll fit routine
+            closure=closure,
+            optimizer=optimizer,
+            closure_kwargs=closure_kwargs,
             optimizer_kwargs=optimizer_kwargs,
         )
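The changes above thread the canonical `fit_gpytorch_mll` keyword arguments (`closure`, `optimizer`, `closure_kwargs`, in addition to the existing `optimizer_kwargs`) through `optimize_mll` and both relevance pursuit entry points. A minimal sketch of the resulting call pattern, assuming a `sparse_module` (a `RelevancePursuitMixin`) and an `mll` already exist; the choice of `fit_gpytorch_mll_scipy`, the sparsity levels, and the option values are illustrative, not prescribed by the diff:

from botorch.models.relevance_pursuit import backward_relevance_pursuit
from botorch.optim.fit import fit_gpytorch_mll_scipy

# `sparse_module` and `mll` are assumed to have been constructed elsewhere.
sparse_module, model_trace = backward_relevance_pursuit(
    sparse_module=sparse_module,
    mll=mll,
    sparsity_levels=[16, 8, 4, 2, 1, 0],
    record_model_trace=True,
    # New pass-through arguments, handed unchanged to fit_gpytorch_mll at each
    # sparsity level; leaving closure/closure_kwargs at None keeps the defaults.
    closure=None,
    optimizer=fit_gpytorch_mll_scipy,  # explicitly the default numerical optimizer
    closure_kwargs=None,
    optimizer_kwargs={"options": {"maxiter": 512}},
)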
