@@ -20,7 +20,7 @@
 import math

 from abc import ABC, abstractmethod
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from copy import copy, deepcopy
 from functools import partial
 from typing import Any, cast, Optional
@@ -35,12 +35,13 @@

 MLL_ITER = 10_000  # let's take convergence seriously
 MLL_TOL = 1e-8
-RESET_PARAMETERS = False
+RESET_PARAMETERS = True
+RESET_DENSE_PARAMETERS = False


 class RelevancePursuitMixin(ABC):
     """Mixin class to convert between the sparse and dense representations of the
-    relevance pursuit models' sparse parameters, as well as to compute the generalized
+    relevance pursuit modules' sparse parameters, as well as to compute the generalized
     support acquisition and support deletion criteria.
     """

@@ -251,19 +252,21 @@ def support_expansion(
         n: int = 1,
         modifier: Callable[[Tensor], Tensor] | None = None,
     ) -> bool:
-        """Computes the indices of the features that maximize the gradient of the sparse
+        """Computes the indices of the elements that maximize the gradient of the sparse
         parameter and that are not already in the support, and subsequently expands the
-        support to include the features if their gradient is positive.
+        support to include the elements if their gradient is positive.

         Args:
             mll: The marginal likelihood, containing the model to optimize.
                 NOTE: Virtually all of the rest of the code is not specific to the
                 marginal likelihood optimization, so we could generalize this to work
                 with any objective.
-            n: The number of features to select.
-            modifier: A function that modifies the gradient of the inactive parameters
+            n: The maximum number of elements to select. NOTE: The actual number of
+                elements that are added could be fewer if there are fewer than `n`
+                elements with a positive gradient.
+            modifier: A function that modifies the gradient of the inactive elements
                 before computing the support expansion criterion. This can be used
-                to select the maximum gradient magnitude for real-valued parameters
+                to select the maximum gradient magnitude for real-valued elements
                 whose gradients are not non-negative, using modifier = torch.abs.

         Returns:
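The expansion rule described in this docstring can be illustrated with a small standalone sketch (a hypothetical helper, not the module's implementation): apply the optional `modifier` to the gradient over the inactive elements, keep only strictly positive values, and take at most `n` of the largest.

```python
import torch
from torch import Tensor


def sketch_expansion_indices(grad: Tensor, n: int = 1, modifier=None) -> Tensor:
    # grad: gradient of the sparse parameter, restricted to the inactive elements
    values = modifier(grad) if modifier is not None else grad
    positive = values > 0  # only elements with a positive (modified) gradient qualify
    k = min(n, int(positive.sum()))
    if k == 0:
        return torch.empty(0, dtype=torch.long)
    # rank candidates by the (modified) gradient and keep the k largest
    return values.masked_fill(~positive, float("-inf")).topk(k).indices
```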
@@ -354,15 +357,15 @@ def support_contraction(
         n: int = 1,
         modifier: Callable[[Tensor], Tensor] | None = None,
     ) -> bool:
-        """Computes the indices of the features that have the smallest coefficients,
-        and subsequently contracts the exlude the features.
+        """Computes the indices of the elements with the smallest magnitude,
+        and subsequently contracts the support by excluding these elements.

         Args:
             mll: The marginal likelihood, containing the model to optimize.
                 NOTE: Virtually all of the rest of the code is not specific to the
                 marginal likelihood optimization, so we could generalize this to work
                 with any objective.
-            n: The number of features to select for removal.
+            n: The number of elements to select for removal.
             modifier: A function that modifies the parameter values before computing
                 the support contraction criterion.
@@ -395,7 +398,11 @@ def optimize_mll(
         mll: ExactMarginalLogLikelihood,
         model_trace: list[Model] | None = None,
         reset_parameters: bool = RESET_PARAMETERS,
-        reset_dense_parameters: bool = RESET_PARAMETERS,
+        reset_dense_parameters: bool = RESET_DENSE_PARAMETERS,
+        # fit_gpytorch_mll kwargs
+        closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+        optimizer: Callable | None = None,
+        closure_kwargs: dict[str, Any] | None = None,
         optimizer_kwargs: dict[str, Any] | None = None,
     ):
         """Optimizes the marginal likelihood.
@@ -410,6 +417,10 @@ def optimize_mll(
             reset_dense_parameters: If True, re-initializes the dense parameters, e.g.
                 other GP hyper-parameters that are *not* part of the Relevance Pursuit
                 module, to the initial values provided by their associated constraints.
+            closure: A closure to use to compute the loss and the gradients, see
+                docstring of `fit_gpytorch_mll` for details.
+            optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
+            closure_kwargs: Additional arguments to pass to the `closure` function.
             optimizer_kwargs: A dictionary of keyword arguments for the optimizer.

         Returns:
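As a usage sketch (the `sparse_module` and `mll` objects are assumed to already exist; only keyword arguments that appear in this diff are used), the new arguments are forwarded verbatim to `fit_gpytorch_mll`:

```python
# hypothetical call on a RelevancePursuitMixin instance
sparse_module.optimize_mll(
    mll=mll,
    reset_parameters=True,         # zero the sparse parameter before fitting
    reset_dense_parameters=False,  # keep the co-adapted GP hyper-parameters
    optimizer_kwargs={"options": {"maxiter": MLL_ITER, "ftol": MLL_TOL}},
)
```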
@@ -419,7 +430,6 @@ def optimize_mll(
             # this might be beneficial because the parameters can
             # end up at a constraint boundary, which can anecdotally make
             # it more difficult to move the newly added parameters.
-            # should we only do this after expansion?
             with torch.no_grad():
                 self.sparse_parameter.zero_()

@@ -430,7 +440,13 @@ def optimize_mll(
         # NOTE: this function should never force the dense representation, because some
         # models might never need it, and it would be inefficient.
         self.to_sparse()
-        fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs)
+        mll = fit_gpytorch_mll(
+            mll,
+            optimizer_kwargs=optimizer_kwargs,
+            closure=closure,
+            optimizer=optimizer,
+            closure_kwargs=closure_kwargs,
+        )
         if model_trace is not None:
             # need to record the full model here, rather than just the sparse parameter
             # since other hyper-parameters are co-adapted to the sparse parameter.
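The append itself falls outside this hunk; given the `deepcopy` import at the top of the file, a plausible sketch of the recording step (an assumption, not shown in the diff) is:

```python
if model_trace is not None:
    # snapshot the entire model, since the dense hyper-parameters
    # are co-adapted to the current sparse parameter
    model_trace.append(deepcopy(mll.model))
```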
@@ -443,11 +459,15 @@ def forward_relevance_pursuit(
     sparse_module: RelevancePursuitMixin,
     mll: ExactMarginalLogLikelihood,
     sparsity_levels: list[int] | None = None,
-    optimizer_kwargs: dict[str, Any] | None = None,
     reset_parameters: bool = RESET_PARAMETERS,
-    reset_dense_parameters: bool = RESET_PARAMETERS,
+    reset_dense_parameters: bool = RESET_DENSE_PARAMETERS,
     record_model_trace: bool = True,
     initial_support: list[int] | None = None,
+    # fit_gpytorch_mll kwargs
+    closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+    optimizer: Callable | None = None,
+    closure_kwargs: dict[str, Any] | None = None,
+    optimizer_kwargs: dict[str, Any] | None = None,
 ) -> tuple[RelevancePursuitMixin, Optional[list[Model]]]:
     """Forward Relevance Pursuit.

@@ -478,9 +498,6 @@ def forward_relevance_pursuit(
         sparse_module: The relevance pursuit module.
         mll: The marginal likelihood, containing the model to optimize.
         sparsity_levels: The sparsity levels to expand the support to.
-        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
-            By default, initializes the "options" sub-dictionary with `maxiter` and
-            `ftol`, `gtol` values, unless specified.
         reset_parameters: If true, initializes the sparse parameter to all zeros
             after each iteration.
         reset_dense_parameters: If true, re-initializes the dense parameters, e.g.
@@ -489,6 +506,13 @@ def forward_relevance_pursuit(
         record_model_trace: If true, records the model state after every iteration.
         initial_support: The support with which to initialize the sparse module. By
             default, the support is initialized to the empty set.
+        closure: A closure to use to compute the loss and the gradients, see docstring
+            of `fit_gpytorch_mll` for details.
+        optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
+        closure_kwargs: Additional arguments to pass to the `closure` function.
+        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
+            By default, initializes the "options" sub-dictionary with `maxiter` and
+            `ftol`, `gtol` values, unless specified.

     Returns:
         The relevance pursuit module after forward relevance pursuit optimization, and
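Putting the forward routine together, a hedged end-to-end sketch (the construction of `sparse_module` and `mll` is assumed; the keyword names are the ones introduced in this diff):

```python
sparse_module, model_trace = forward_relevance_pursuit(
    sparse_module=sparse_module,
    mll=mll,
    sparsity_levels=[1, 2, 4, 8],  # support sizes to expand through
    reset_parameters=True,
    record_model_trace=True,
    optimizer_kwargs={"options": {"maxiter": MLL_ITER, "ftol": MLL_TOL}},
)
```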
@@ -510,14 +534,17 @@ def forward_relevance_pursuit(

     model_trace = [] if record_model_trace else None

-    def optimize_mll(mll):
-        return sparse_module.optimize_mll(
-            mll=mll,
-            model_trace=model_trace,
-            reset_parameters=reset_parameters,
-            reset_dense_parameters=reset_dense_parameters,
-            optimizer_kwargs=optimizer_kwargs,
-        )
+    optimize_mll = partial(
+        sparse_module.optimize_mll,
+        model_trace=model_trace,
+        reset_parameters=reset_parameters,
+        reset_dense_parameters=reset_dense_parameters,
+        # These are the args of the canonical mll fit routine
+        closure=closure,
+        optimizer=optimizer,
+        closure_kwargs=closure_kwargs,
+        optimizer_kwargs=optimizer_kwargs,
+    )

     # if sparsity levels contains the initial support, remove it
     if sparsity_levels[0] == len(sparse_module.support):
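The switch from a nested `def` to `functools.partial` is behavior-preserving: both forms bind the fixed keyword arguments once and leave `mll` as the only positional argument. A minimal illustration:

```python
from functools import partial


def fit(mll, reset_parameters=False):
    return mll, reset_parameters


optimize = partial(fit, reset_parameters=True)
assert optimize("mll") == ("mll", True)  # same as fit("mll", reset_parameters=True)
```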
@@ -548,11 +575,15 @@ def backward_relevance_pursuit(
     sparse_module: RelevancePursuitMixin,
     mll: ExactMarginalLogLikelihood,
     sparsity_levels: list[int] | None = None,
-    optimizer_kwargs: dict[str, Any] | None = None,
     reset_parameters: bool = RESET_PARAMETERS,
-    reset_dense_parameters: bool = RESET_PARAMETERS,
+    reset_dense_parameters: bool = RESET_DENSE_PARAMETERS,
     record_model_trace: bool = True,
     initial_support: list[int] | None = None,
+    # fit_gpytorch_mll kwargs
+    closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]] | None = None,
+    optimizer: Callable | None = None,
+    closure_kwargs: dict[str, Any] | None = None,
+    optimizer_kwargs: dict[str, Any] | None = None,
 ) -> tuple[RelevancePursuitMixin, Optional[list[Model]]]:
     """Backward Relevance Pursuit.

@@ -583,9 +614,6 @@ def backward_relevance_pursuit(
         sparse_module: The relevance pursuit module.
         mll: The marginal likelihood, containing the model to optimize.
         sparsity_levels: The sparsity levels to contract the support to.
-        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
-            By default, initializes the "options" sub-dictionary with `maxiter` and
-            `ftol`, `gtol` values, unless specified.
         reset_parameters: If true, initializes the sparse parameter to all zeros
             after each iteration.
         reset_dense_parameters: If true, re-initializes the dense parameters, e.g.
@@ -594,6 +622,13 @@ def backward_relevance_pursuit(
         record_model_trace: If true, records the model state after every iteration.
         initial_support: The support with which to initialize the sparse module. By
             default, the support is initialized to the full set.
+        closure: A closure to use to compute the loss and the gradients, see docstring
+            of `fit_gpytorch_mll` for details.
+        optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`.
+        closure_kwargs: Additional arguments to pass to the `closure` function.
+        optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer.
+            By default, initializes the "options" sub-dictionary with `maxiter` and
+            `ftol`, `gtol` values, unless specified.

     Returns:
         The relevance pursuit module after backward relevance pursuit optimization, and
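A mirrored sketch for the backward direction (same assumptions as the forward example): by default the support starts at the full set and is contracted step by step.

```python
sparse_module, model_trace = backward_relevance_pursuit(
    sparse_module=sparse_module,
    mll=mll,
    reset_parameters=True,
    record_model_trace=True,
)
```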
@@ -623,6 +658,10 @@ def optimize_mll(mll):
             model_trace=model_trace,
             reset_parameters=reset_parameters,
             reset_dense_parameters=reset_dense_parameters,
+            # These are the args of the canonical mll fit routine
+            closure=closure,
+            optimizer=optimizer,
+            closure_kwargs=closure_kwargs,
             optimizer_kwargs=optimizer_kwargs,
         )
