2020import math
2121
2222from abc import ABC , abstractmethod
23- from collections .abc import Callable
23+ from collections .abc import Callable , Sequence
2424from copy import copy , deepcopy
2525from functools import partial
2626from typing import Any , cast , Optional
3535
3636MLL_ITER = 10_000 # let's take convergence seriously
3737MLL_TOL = 1e-8
38- RESET_PARAMETERS = False
38+ RESET_PARAMETERS = True
39+ RESET_DENSE_PARAMETERS = False
3940
4041
4142class RelevancePursuitMixin (ABC ):
4243 """Mixin class to convert between the sparse and dense representations of the
43- relevance pursuit models ' sparse parameters, as well as to compute the generalized
44+ relevance pursuit modules ' sparse parameters, as well as to compute the generalized
4445 support acquisition and support deletion criteria.
4546 """
4647
@@ -395,7 +396,11 @@ def optimize_mll(
395396 mll : ExactMarginalLogLikelihood ,
396397 model_trace : list [Model ] | None = None ,
397398 reset_parameters : bool = RESET_PARAMETERS ,
398- reset_dense_parameters : bool = RESET_PARAMETERS ,
399+ reset_dense_parameters : bool = RESET_DENSE_PARAMETERS ,
400+ # fit_gpytorch_mll kwargs
401+ closure : Callable [[], tuple [Tensor , Sequence [Tensor | None ]]] | None = None ,
402+ optimizer : Callable | None = None ,
403+ closure_kwargs : dict [str , Any ] | None = None ,
399404 optimizer_kwargs : dict [str , Any ] | None = None ,
400405 ):
401406 """Optimizes the marginal likelihood.
@@ -410,6 +415,10 @@ def optimize_mll(
410415 reset_dense_parameters: If True, re-initializes the dense parameters, e.g.
411416 other GP hyper-parameters that are *not* part of the Relevance Pursuit
412417 module, to the initial values provided by their associated constraints.
418+ closure: A closure to use to compute the loss and the gradients, see
419+ docstring of `fit_gpytorch_mll` for details.
420+ optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
421+ closure_kwargs: Additional arguments to pass to the `closure` function.
413422 optimizer_kwargs: A dictionary of keyword arguments for the optimizer.
414423
415424 Returns:
@@ -419,7 +428,6 @@ def optimize_mll(
419428 # this might be beneficial because the parameters can
420429 # end up at a constraint boundary, which can anecdotally make
421430 # it more difficult to move the newly added parameters.
422- # should we only do this after expansion?
423431 with torch .no_grad ():
424432 self .sparse_parameter .zero_ ()
425433
@@ -430,7 +438,13 @@ def optimize_mll(
430438 # NOTE: this function should never force the dense representation, because some
431439 # models might never need it, and it would be inefficient.
432440 self .to_sparse ()
433- fit_gpytorch_mll (mll , optimizer_kwargs = optimizer_kwargs )
441+ mll = fit_gpytorch_mll (
442+ mll ,
443+ optimizer_kwargs = optimizer_kwargs ,
444+ closure = closure ,
445+ optimizer = optimizer ,
446+ closure_kwargs = closure_kwargs ,
447+ )
434448 if model_trace is not None :
435449 # need to record the full model here, rather than just the sparse parameter
436450 # since other hyper-parameters are co-adapted to the sparse parameter.
@@ -443,11 +457,15 @@ def forward_relevance_pursuit(
443457 sparse_module : RelevancePursuitMixin ,
444458 mll : ExactMarginalLogLikelihood ,
445459 sparsity_levels : list [int ] | None = None ,
446- optimizer_kwargs : dict [str , Any ] | None = None ,
447460 reset_parameters : bool = RESET_PARAMETERS ,
448- reset_dense_parameters : bool = RESET_PARAMETERS ,
461+ reset_dense_parameters : bool = RESET_DENSE_PARAMETERS ,
449462 record_model_trace : bool = True ,
450463 initial_support : list [int ] | None = None ,
464+ # fit_gpytorch_mll kwargs
465+ closure : Callable [[], tuple [Tensor , Sequence [Tensor | None ]]] | None = None ,
466+ optimizer : Callable | None = None ,
467+ closure_kwargs : dict [str , Any ] | None = None ,
468+ optimizer_kwargs : dict [str , Any ] | None = None ,
451469) -> tuple [RelevancePursuitMixin , Optional [list [Model ]]]:
452470 """Forward Relevance Pursuit.
453471
@@ -478,9 +496,6 @@ def forward_relevance_pursuit(
478496 sparse_module: The relevance pursuit module.
479497 mll: The marginal likelihood, containing the model to optimize.
480498 sparsity_levels: The sparsity levels to expand the support to.
481- optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
482- By default, initializes the "options" sub-dictionary with `maxiter` and
483- `ftol`, `gtol` values, unless specified.
484499 reset_parameters: If true, initializes the sparse parameter to all zeros
485500 after each iteration.
486501 reset_dense_parameters: If true, re-initializes the dense parameters, e.g.
@@ -489,6 +504,13 @@ def forward_relevance_pursuit(
489504 record_model_trace: If true, records the model state after every iteration.
490505 initial_support: The support with which to initialize the sparse module. By
491506 default, the support is initialized to the empty set.
507+ closure: A closure to use to compute the loss and the gradients, see docstring
508+ of `fit_gpytorch_mll` for details.
509+ optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
510+ closure_kwargs: Additional arguments to pass to the `closure` function.
511+ optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
512+ By default, initializes the "options" sub-dictionary with `maxiter` and
513+ `ftol`, `gtol` values, unless specified.
492514
493515 Returns:
494516 The relevance pursuit module after forward relevance pursuit optimization, and
@@ -510,14 +532,17 @@ def forward_relevance_pursuit(
510532
511533 model_trace = [] if record_model_trace else None
512534
513- def optimize_mll (mll ):
514- return sparse_module .optimize_mll (
515- mll = mll ,
516- model_trace = model_trace ,
517- reset_parameters = reset_parameters ,
518- reset_dense_parameters = reset_dense_parameters ,
519- optimizer_kwargs = optimizer_kwargs ,
520- )
535+ optimize_mll = partial (
536+ sparse_module .optimize_mll ,
537+ model_trace = model_trace ,
538+ reset_parameters = reset_parameters ,
539+ reset_dense_parameters = reset_dense_parameters ,
540+ # These are the args of the canonical mll fit routine
541+ closure = closure ,
542+ optimizer = optimizer ,
543+ closure_kwargs = closure_kwargs ,
544+ optimizer_kwargs = optimizer_kwargs ,
545+ )
521546
522547 # if sparsity levels contains the initial support, remove it
523548 if sparsity_levels [0 ] == len (sparse_module .support ):
@@ -548,11 +573,15 @@ def backward_relevance_pursuit(
548573 sparse_module : RelevancePursuitMixin ,
549574 mll : ExactMarginalLogLikelihood ,
550575 sparsity_levels : list [int ] | None = None ,
551- optimizer_kwargs : dict [str , Any ] | None = None ,
552576 reset_parameters : bool = RESET_PARAMETERS ,
553- reset_dense_parameters : bool = RESET_PARAMETERS ,
577+ reset_dense_parameters : bool = RESET_DENSE_PARAMETERS ,
554578 record_model_trace : bool = True ,
555579 initial_support : list [int ] | None = None ,
580+ # fit_gpytorch_mll kwargs
581+ closure : Callable [[], tuple [Tensor , Sequence [Tensor | None ]]] | None = None ,
582+ optimizer : Callable | None = None ,
583+ closure_kwargs : dict [str , Any ] | None = None ,
584+ optimizer_kwargs : dict [str , Any ] | None = None ,
556585) -> tuple [RelevancePursuitMixin , Optional [list [Model ]]]:
557586 """Backward Relevance Pursuit.
558587
@@ -583,9 +612,6 @@ def backward_relevance_pursuit(
583612 sparse_module: The relevance pursuit module.
584613 mll: The marginal likelihood, containing the model to optimize.
585614 sparsity_levels: The sparsity levels to which to reduce the support.
586- optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
587- By default, initializes the "options" sub-dictionary with `maxiter` and
588- `ftol`, `gtol` values, unless specified.
589615 reset_parameters: If true, initializes the sparse parameter to all zeros
590616 after each iteration.
591617 reset_dense_parameters: If true, re-initializes the dense parameters, e.g.
@@ -594,6 +620,13 @@ def backward_relevance_pursuit(
594620 record_model_trace: If true, records the model state after every iteration.
595621 initial_support: The support with which to initialize the sparse module. By
596622 default, the support is initialized to the full set.
623+ closure: A closure to use to compute the loss and the gradients, see docstring
624+ of `fit_gpytorch_mll` for details.
625+ optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
626+ closure_kwargs: Additional arguments to pass to the `closure` function.
627+ optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
628+ By default, initializes the "options" sub-dictionary with `maxiter` and
629+ `ftol`, `gtol` values, unless specified.
597630
598631 Returns:
599632 The relevance pursuit module after backward relevance pursuit optimization, and
@@ -623,6 +656,10 @@ def optimize_mll(mll):
623656 model_trace = model_trace ,
624657 reset_parameters = reset_parameters ,
625658 reset_dense_parameters = reset_dense_parameters ,
659+ # These are the args of the canonical mll fit routine
660+ closure = closure ,
661+ optimizer = optimizer ,
662+ closure_kwargs = closure_kwargs ,
626663 optimizer_kwargs = optimizer_kwargs ,
627664 )
628665
0 commit comments