2020import math
2121
2222from abc import ABC , abstractmethod
23- from collections .abc import Callable
23+ from collections .abc import Callable , Sequence
2424from copy import copy , deepcopy
2525from functools import partial
2626from typing import Any , cast , Optional
@@ -396,6 +396,10 @@ def optimize_mll(
396396 model_trace : list [Model ] | None = None ,
397397 reset_parameters : bool = RESET_PARAMETERS ,
398398 reset_dense_parameters : bool = RESET_PARAMETERS ,
399+ # fit_gpytorch_mll kwargs
400+ closure : Callable [[], tuple [Tensor , Sequence [Tensor | None ]]] | None = None ,
401+ optimizer : Callable | None = None ,
402+ closure_kwargs : dict [str , Any ] | None = None ,
399403 optimizer_kwargs : dict [str , Any ] | None = None ,
400404 ):
401405 """Optimizes the marginal likelihood.
@@ -410,6 +414,10 @@ def optimize_mll(
410414 reset_dense_parameters: If True, re-initializes the dense parameters, e.g.
411415 other GP hyper-parameters that are *not* part of the Relevance Pursuit
412416 module, to the initial values provided by their associated constraints.
417+ closure: A closure to use to compute the loss and the gradients, see
418+ docstring of `fit_gpytorch_mll` for details.
419+ optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
420+ closure_kwargs: Additional arguments to pass to the `closure` function.
413421 optimizer_kwargs: A dictionary of keyword arguments for the optimizer.
414422
415423 Returns:
@@ -430,7 +438,13 @@ def optimize_mll(
430438 # NOTE: this function should never force the dense representation, because some
431439 # models might never need it, and it would be inefficient.
432440 self .to_sparse ()
433- fit_gpytorch_mll (mll , optimizer_kwargs = optimizer_kwargs )
441+ fit_gpytorch_mll (
442+ mll ,
443+ optimizer_kwargs = optimizer_kwargs ,
444+ closure = closure ,
445+ optimizer = optimizer ,
446+ closure_kwargs = closure_kwargs ,
447+ )
434448 if model_trace is not None :
435449 # need to record the full model here, rather than just the sparse parameter
436450 # since other hyper-parameters are co-adapted to the sparse parameter.
@@ -443,11 +457,15 @@ def forward_relevance_pursuit(
443457 sparse_module : RelevancePursuitMixin ,
444458 mll : ExactMarginalLogLikelihood ,
445459 sparsity_levels : list [int ] | None = None ,
446- optimizer_kwargs : dict [str , Any ] | None = None ,
447460 reset_parameters : bool = RESET_PARAMETERS ,
448461 reset_dense_parameters : bool = RESET_PARAMETERS ,
449462 record_model_trace : bool = True ,
450463 initial_support : list [int ] | None = None ,
464+ # fit_gpytorch_mll kwargs
465+ closure : Callable [[], tuple [Tensor , Sequence [Tensor | None ]]] | None = None ,
466+ optimizer : Callable | None = None ,
467+ closure_kwargs : dict [str , Any ] | None = None ,
468+ optimizer_kwargs : dict [str , Any ] | None = None ,
451469) -> tuple [RelevancePursuitMixin , Optional [list [Model ]]]:
452470 """Forward Relevance Pursuit.
453471
@@ -478,9 +496,6 @@ def forward_relevance_pursuit(
478496 sparse_module: The relevance pursuit module.
479497 mll: The marginal likelihood, containing the model to optimize.
480498 sparsity_levels: The sparsity levels to expand the support to.
481- optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
482- By default, initializes the "options" sub-dictionary with `maxiter` and
483- `ftol`, `gtol` values, unless specified.
484499 reset_parameters: If true, initializes the sparse parameter to the all zeros
485500 after each iteration.
486501 reset_dense_parameters: If true, re-initializes the dense parameters, e.g.
@@ -489,6 +504,13 @@ def forward_relevance_pursuit(
489504 record_model_trace: If true, records the model state after every iteration.
490505 initial_support: The support with which to initialize the sparse module. By
491506 default, the support is initialized to the empty set.
507+ closure: A closure to use to compute the loss and the gradients, see docstring
508+ of `fit_gpytorch_mll` for details.
509+ optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
510+ closure_kwargs: Additional arguments to pass to the `closure` function.
511+ optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
512+ By default, initializes the "options" sub-dictionary with `maxiter` and
513+ `ftol`, `gtol` values, unless specified.
492514
493515 Returns:
494516 The relevance pursuit module after forward relevance pursuit optimization, and
@@ -516,6 +538,10 @@ def optimize_mll(mll):
516538 model_trace = model_trace ,
517539 reset_parameters = reset_parameters ,
518540 reset_dense_parameters = reset_dense_parameters ,
541+ # These are the args of the canonical mll fit routine
542+ closure = closure ,
543+ optimizer = optimizer ,
544+ closure_kwargs = closure_kwargs ,
519545 optimizer_kwargs = optimizer_kwargs ,
520546 )
521547
@@ -548,11 +574,15 @@ def backward_relevance_pursuit(
548574 sparse_module : RelevancePursuitMixin ,
549575 mll : ExactMarginalLogLikelihood ,
550576 sparsity_levels : list [int ] | None = None ,
551- optimizer_kwargs : dict [str , Any ] | None = None ,
552577 reset_parameters : bool = RESET_PARAMETERS ,
553578 reset_dense_parameters : bool = RESET_PARAMETERS ,
554579 record_model_trace : bool = True ,
555580 initial_support : list [int ] | None = None ,
581+ # fit_gpytorch_mll kwargs
582+ closure : Callable [[], tuple [Tensor , Sequence [Tensor | None ]]] | None = None ,
583+ optimizer : Callable | None = None ,
584+ closure_kwargs : dict [str , Any ] | None = None ,
585+ optimizer_kwargs : dict [str , Any ] | None = None ,
556586) -> tuple [RelevancePursuitMixin , Optional [list [Model ]]]:
557587 """Backward Relevance Pursuit.
558588
@@ -583,9 +613,6 @@ def backward_relevance_pursuit(
583613 sparse_module: The relevance pursuit module.
584614 mll: The marginal likelihood, containing the model to optimize.
585615 sparsity_levels: The sparsity levels to expand the support to.
586- optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
587- By default, initializes the "options" sub-dictionary with `maxiter` and
588- `ftol`, `gtol` values, unless specified.
589616 reset_parameters: If true, initializes the sparse parameter to the all zeros
590617 after each iteration.
591618 reset_dense_parameters: If true, re-initializes the dense parameters, e.g.
@@ -594,6 +621,13 @@ def backward_relevance_pursuit(
594621 record_model_trace: If true, records the model state after every iteration.
595622 initial_support: The support with which to initialize the sparse module. By
596623 default, the support is initialized to the full set.
624+ closure: A closure to use to compute the loss and the gradients, see docstring
625+ of `fit_gpytorch_mll` for details.
626+ optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
627+ closure_kwargs: Additional arguments to pass to the `closure` function.
628+ optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
629+ By default, initializes the "options" sub-dictionary with `maxiter` and
630+ `ftol`, `gtol` values, unless specified.
597631
598632 Returns:
599633 The relevance pursuit module after forward relevance pursuit optimization, and
@@ -623,6 +657,10 @@ def optimize_mll(mll):
623657 model_trace = model_trace ,
624658 reset_parameters = reset_parameters ,
625659 reset_dense_parameters = reset_dense_parameters ,
660+ # These are the args of the canonical mll fit routine
661+ closure = closure ,
662+ optimizer = optimizer ,
663+ closure_kwargs = closure_kwargs ,
626664 optimizer_kwargs = optimizer_kwargs ,
627665 )
628666
0 commit comments