@@ -20,7 +20,7 @@
 import math
 
 from abc import ABC, abstractmethod
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from copy import copy, deepcopy
 from functools import partial
 from typing import Any, cast, Optional
@@ -35,12 +35,13 @@
 
 MLL_ITER = 10_000  # let's take convergence seriously
 MLL_TOL = 1e-8
-RESET_PARAMETERS = False
+RESET_PARAMETERS = True
+RESET_DENSE_PARAMETERS = False
 
 
 class RelevancePursuitMixin(ABC):
     """Mixin class to convert between the sparse and dense representations of the
-    relevance pursuit models' sparse parameters, as well as to compute the generalized
+    relevance pursuit modules' sparse parameters, as well as to compute the generalized
     support acquisition and support deletion criteria.
     """
 
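To make the conversion the class docstring describes concrete, here is a minimal sketch of the sparse/dense round trip, assuming the module stores its active coefficients in `sparse_parameter` and the active indices in `support` (both names appear in this diff; the two helper functions themselves are hypothetical, not the mixin's API):

import torch
from torch import Tensor

def sparse_to_dense(sparse: Tensor, support: list[int], dim: int) -> Tensor:
    # scatter the active coefficients into a length-`dim` vector of zeros
    dense = torch.zeros(dim, dtype=sparse.dtype)
    dense[support] = sparse
    return dense

def dense_to_sparse(dense: Tensor, support: list[int]) -> Tensor:
    # keep only the coefficients whose indices are in the support
    return dense[support]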
@@ -251,19 +252,21 @@ def support_expansion(
         n: int = 1,
         modifier: Callable[[Tensor], Tensor] | None = None,
     ) -> bool:
-        """Computes the indices of the features that maximize the gradient of the sparse
+        """Computes the indices of the elements that maximize the gradient of the sparse
         parameter and that are not already in the support, and subsequently expands the
-        support to include the features if their gradient is positive.
+        support to include the elements if their gradient is positive.
 
         Args:
             mll: The marginal likelihood, containing the model to optimize.
                 NOTE: Virtually all of the rest of the code is not specific to the
                 marginal likelihood optimization, so we could generalize this to work
                 with any objective.
-            n: The number of features to select.
-            modifier: A function that modifies the gradient of the inactive parameters
+            n: The maximum number of elements to select. NOTE: The actual number of
+                elements that are added could be fewer if there are fewer than `n`
+                elements with a positive gradient.
+            modifier: A function that modifies the gradient of the inactive elements
                 before computing the support expansion criterion. This can be used
-                to select the maximum gradient magnitude for real-valued parameters
+                to select the maximum gradient magnitude for real-valued elements
                 whose gradients are not necessarily non-negative, using `modifier = torch.abs`.
 
         Returns:
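For intuition, a sketch of the expansion criterion this docstring describes (an illustration under stated assumptions, not the library's implementation): rank the inactive elements by their, optionally modified, gradient, and keep at most `n` indices whose criterion is positive.

from collections.abc import Callable
from torch import Tensor

def expansion_candidates(
    grad: Tensor,  # gradient of the objective w.r.t. the dense parameter
    support: list[int],
    n: int,
    modifier: Callable[[Tensor], Tensor] | None = None,
) -> list[int]:
    crit = modifier(grad) if modifier is not None else grad
    inactive = [i for i in range(len(crit)) if i not in support]
    # sort by decreasing criterion and keep at most n positive entries
    ranked = sorted(inactive, key=lambda i: crit[i].item(), reverse=True)
    return [i for i in ranked[:n] if crit[i] > 0]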
@@ -354,15 +357,15 @@ def support_contraction(
         n: int = 1,
         modifier: Callable[[Tensor], Tensor] | None = None,
     ) -> bool:
-        """Computes the indices of the features that have the smallest coefficients,
-        and subsequently contracts the exlude the features.
+        """Computes the indices of the elements with the smallest magnitude,
+        and subsequently contracts the support by excluding the elements.
 
         Args:
             mll: The marginal likelihood, containing the model to optimize.
                 NOTE: Virtually all of the rest of the code is not specific to the
                 marginal likelihood optimization, so we could generalize this to work
                 with any objective.
-            n: The number of features to select for removal.
+            n: The number of elements to select for removal.
             modifier: A function that modifies the parameter values before computing
                 the support contraction criterion.
 
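Analogously, a sketch of the contraction criterion: drop the `n` active elements whose, optionally modified, values are smallest in magnitude (defaulting to `torch.abs` is an assumption consistent with the docstring above, not confirmed by this diff).

from collections.abc import Callable
from torch import Tensor

def contraction_candidates(
    values: Tensor,  # current dense parameter values
    support: list[int],
    n: int,
    modifier: Callable[[Tensor], Tensor] | None = None,
) -> list[int]:
    crit = modifier(values) if modifier is not None else values.abs()
    # rank active indices by increasing criterion; the first n are removed
    return sorted(support, key=lambda i: crit[i].item())[:n]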
@@ -395,7 +398,11 @@ def optimize_mll(
         mll: ExactMarginalLogLikelihood,
         model_trace: list[Model] | None = None,
         reset_parameters: bool = RESET_PARAMETERS,
-        reset_dense_parameters: bool = RESET_PARAMETERS,
+        reset_dense_parameters: bool = RESET_DENSE_PARAMETERS,
+        # fit_gpytorch_mll kwargs
+        closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+        optimizer: Callable | None = None,
+        closure_kwargs: dict[str, Any] | None = None,
         optimizer_kwargs: dict[str, Any] | None = None,
     ):
         """Optimizes the marginal likelihood.
@@ -410,6 +417,10 @@
             reset_dense_parameters: If True, re-initializes the dense parameters, e.g.
                 other GP hyper-parameters that are *not* part of the Relevance Pursuit
                 module, to the initial values provided by their associated constraints.
+            closure: A closure to use to compute the loss and the gradients, see the
+                docstring of `fit_gpytorch_mll` for details.
+            optimizer: The numerical optimizer, see the docstring of `fit_gpytorch_mll`.
+            closure_kwargs: Additional arguments to pass to the `closure` function.
             optimizer_kwargs: A dictionary of keyword arguments for the optimizer.
 
         Returns:
@@ -419,7 +430,6 @@
         # this might be beneficial because the parameters can
         # end up at a constraint boundary, which can anecdotally make
         # it more difficult to move the newly added parameters.
-        # should we only do this after expansion?
         with torch.no_grad():
             self.sparse_parameter.zero_()
 
@@ -430,7 +440,13 @@
         # NOTE: this function should never force the dense representation, because some
         # models might never need it, and it would be inefficient.
         self.to_sparse()
-        fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs)
+        mll = fit_gpytorch_mll(
+            mll,
+            optimizer_kwargs=optimizer_kwargs,
+            closure=closure,
+            optimizer=optimizer,
+            closure_kwargs=closure_kwargs,
+        )
         if model_trace is not None:
             # need to record the full model here, rather than just the sparse parameter
             # since other hyper-parameters are co-adapted to the sparse parameter.
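With this change, a user-supplied loss closure or numerical optimizer is forwarded to `fit_gpytorch_mll` rather than being ignored. A hypothetical call, assuming `sparse_module` is a module carrying `RelevancePursuitMixin` and `mll` wraps the model that owns it (the option values are illustrative):

sparse_module.optimize_mll(
    mll,
    reset_parameters=True,  # zero the sparse parameter before refitting
    closure=None,  # None: fit_gpytorch_mll constructs its default loss closure
    optimizer_kwargs={"options": {"maxiter": MLL_ITER, "ftol": MLL_TOL}},
)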
@@ -443,11 +459,15 @@ def forward_relevance_pursuit(
     sparse_module: RelevancePursuitMixin,
     mll: ExactMarginalLogLikelihood,
     sparsity_levels: list[int] | None = None,
-    optimizer_kwargs: dict[str, Any] | None = None,
     reset_parameters: bool = RESET_PARAMETERS,
-    reset_dense_parameters: bool = RESET_PARAMETERS,
+    reset_dense_parameters: bool = RESET_DENSE_PARAMETERS,
     record_model_trace: bool = True,
     initial_support: list[int] | None = None,
+    # fit_gpytorch_mll kwargs
+    closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+    optimizer: Callable | None = None,
+    closure_kwargs: dict[str, Any] | None = None,
+    optimizer_kwargs: dict[str, Any] | None = None,
 ) -> tuple[RelevancePursuitMixin, Optional[list[Model]]]:
     """Forward Relevance Pursuit.
 
@@ -478,9 +498,6 @@ def forward_relevance_pursuit(
         sparse_module: The relevance pursuit module.
         mll: The marginal likelihood, containing the model to optimize.
         sparsity_levels: The sparsity levels to expand the support to.
-        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
-            By default, initializes the "options" sub-dictionary with `maxiter` and
-            `ftol`, `gtol` values, unless specified.
         reset_parameters: If true, initializes the sparse parameter to all zeros
             after each iteration.
         reset_dense_parameters: If true, re-initializes the dense parameters, e.g.
@@ -489,6 +506,13 @@ def forward_relevance_pursuit(
         record_model_trace: If true, records the model state after every iteration.
         initial_support: The support with which to initialize the sparse module. By
             default, the support is initialized to the empty set.
+        closure: A closure to use to compute the loss and the gradients, see the
+            docstring of `fit_gpytorch_mll` for details.
+        optimizer: The numerical optimizer, see the docstring of `fit_gpytorch_mll`.
+        closure_kwargs: Additional arguments to pass to the `closure` function.
+        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
+            By default, initializes the "options" sub-dictionary with `maxiter` and
+            `ftol`, `gtol` values, unless specified.
 
     Returns:
         The relevance pursuit module after forward relevance pursuit optimization, and
@@ -510,14 +534,17 @@ def forward_relevance_pursuit(
 
     model_trace = [] if record_model_trace else None
 
-    def optimize_mll(mll):
-        return sparse_module.optimize_mll(
-            mll=mll,
-            model_trace=model_trace,
-            reset_parameters=reset_parameters,
-            reset_dense_parameters=reset_dense_parameters,
-            optimizer_kwargs=optimizer_kwargs,
-        )
+    optimize_mll = partial(
+        sparse_module.optimize_mll,
+        model_trace=model_trace,
+        reset_parameters=reset_parameters,
+        reset_dense_parameters=reset_dense_parameters,
+        # These are the args of the canonical mll fit routine
+        closure=closure,
+        optimizer=optimizer,
+        closure_kwargs=closure_kwargs,
+        optimizer_kwargs=optimizer_kwargs,
+    )
 
     # if sparsity levels contains the initial support, remove it
     if sparsity_levels[0] == len(sparse_module.support):
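The switch from a nested `def` to `functools.partial` above is behavior-preserving: `optimize_mll(mll)` still resolves to `sparse_module.optimize_mll(mll, ...)` with the same keyword arguments bound. A self-contained illustration of that equivalence:

from functools import partial

def f(a, b=0, c=0):
    return a + b + c

g = partial(f, b=1, c=2)
assert g(3) == f(3, b=1, c=2) == 6  # bound kwargs behave like the closed-over ones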
@@ -548,11 +575,15 @@ def backward_relevance_pursuit(
     sparse_module: RelevancePursuitMixin,
     mll: ExactMarginalLogLikelihood,
     sparsity_levels: list[int] | None = None,
-    optimizer_kwargs: dict[str, Any] | None = None,
     reset_parameters: bool = RESET_PARAMETERS,
-    reset_dense_parameters: bool = RESET_PARAMETERS,
+    reset_dense_parameters: bool = RESET_DENSE_PARAMETERS,
     record_model_trace: bool = True,
     initial_support: list[int] | None = None,
+    # fit_gpytorch_mll kwargs
+    closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+    optimizer: Callable | None = None,
+    closure_kwargs: dict[str, Any] | None = None,
+    optimizer_kwargs: dict[str, Any] | None = None,
 ) -> tuple[RelevancePursuitMixin, Optional[list[Model]]]:
     """Backward Relevance Pursuit.
 
@@ -583,9 +614,6 @@ def backward_relevance_pursuit(
         sparse_module: The relevance pursuit module.
         mll: The marginal likelihood, containing the model to optimize.
         sparsity_levels: The sparsity levels to contract the support to.
-        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
-            By default, initializes the "options" sub-dictionary with `maxiter` and
-            `ftol`, `gtol` values, unless specified.
         reset_parameters: If true, initializes the sparse parameter to all zeros
             after each iteration.
         reset_dense_parameters: If true, re-initializes the dense parameters, e.g.
@@ -594,6 +622,13 @@ def backward_relevance_pursuit(
         record_model_trace: If true, records the model state after every iteration.
         initial_support: The support with which to initialize the sparse module. By
             default, the support is initialized to the full set.
+        closure: A closure to use to compute the loss and the gradients, see the
+            docstring of `fit_gpytorch_mll` for details.
+        optimizer: The numerical optimizer, see the docstring of `fit_gpytorch_mll`.
+        closure_kwargs: Additional arguments to pass to the `closure` function.
+        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
+            By default, initializes the "options" sub-dictionary with `maxiter` and
+            `ftol`, `gtol` values, unless specified.
 
     Returns:
         The relevance pursuit module after backward relevance pursuit optimization, and
@@ -623,6 +658,10 @@ def optimize_mll(mll):
             model_trace=model_trace,
             reset_parameters=reset_parameters,
             reset_dense_parameters=reset_dense_parameters,
+            # These are the args of the canonical mll fit routine
+            closure=closure,
+            optimizer=optimizer,
+            closure_kwargs=closure_kwargs,
             optimizer_kwargs=optimizer_kwargs,
         )
 
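Putting the two algorithms together, a hypothetical end-to-end use with the new pass-through arguments; the `sparse_module` and `mll` objects and the sparsity levels are illustrative, only the keyword names come from the signatures in this diff:

sparse_module, trace = forward_relevance_pursuit(
    sparse_module=sparse_module,
    mll=mll,
    sparsity_levels=[0, 1, 2, 4, 8],  # expand the support in these steps
)
sparse_module, trace = backward_relevance_pursuit(
    sparse_module=sparse_module,
    mll=mll,
    sparsity_levels=[8, 4, 2, 1, 0],  # contract the support in these steps
    optimizer_kwargs={"options": {"maxiter": MLL_ITER}},
)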