2020import  math 
2121
2222from  abc  import  ABC , abstractmethod 
23- from  collections .abc  import  Callable 
23+ from  collections .abc  import  Callable ,  Sequence 
2424from  copy  import  copy , deepcopy 
2525from  functools  import  partial 
2626from  typing  import  Any , cast , Optional 
3535
3636MLL_ITER  =  10_000   # let's take convergence seriously 
3737MLL_TOL  =  1e-8 
38- RESET_PARAMETERS  =  False 
38+ RESET_PARAMETERS  =  True 
39+ RESET_DENSE_PARAMETERS  =  False 
3940
4041
4142class  RelevancePursuitMixin (ABC ):
4243    """Mixin class to convert between the sparse and dense representations of the 
43-     relevance pursuit models ' sparse parameters, as well as to compute the generalized 
44+     relevance pursuit modules ' sparse parameters, as well as to compute the generalized 
4445    support acquisition and support deletion criteria. 
4546    """ 
4647
@@ -395,7 +396,11 @@ def optimize_mll(
395396        mll : ExactMarginalLogLikelihood ,
396397        model_trace : list [Model ] |  None  =  None ,
397398        reset_parameters : bool  =  RESET_PARAMETERS ,
398-         reset_dense_parameters : bool  =  RESET_PARAMETERS ,
399+         reset_dense_parameters : bool  =  RESET_DENSE_PARAMETERS ,
400+         # fit_gpytorch_mll kwargs 
401+         closure : Callable [[], tuple [Tensor , Sequence [Tensor  |  None ]]] |  None  =  None ,
402+         optimizer : Callable  |  None  =  None ,
403+         closure_kwargs : dict [str , Any ] |  None  =  None ,
399404        optimizer_kwargs : dict [str , Any ] |  None  =  None ,
400405    ):
401406        """Optimizes the marginal likelihood. 
@@ -410,6 +415,10 @@ def optimize_mll(
410415            reset_dense_parameters: If True, re-initializes the dense parameters, e.g. 
411416                other GP hyper-parameters that are *not* part of the Relevance Pursuit 
412417                module, to the initial values provided by their associated constraints. 
418+             closure: A closure to use to compute the loss and the gradients, see 
419+                 docstring of `fit_gpytorch_mll` for details. 
420+             optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`. 
421+             closure_kwargs: Additional arguments to pass to the `closure` function. 
413422            optimizer_kwargs: A dictionary of keyword arguments for the optimizer. 
414423
415424        Returns: 
@@ -419,7 +428,6 @@ def optimize_mll(
419428            # this might be beneficial because the parameters can 
420429            # end up at a constraint boundary, which can anecdotally make 
421430            # it more difficult to move the newly added parameters. 
422-             # should we only do this after expansion? 
423431            with  torch .no_grad ():
424432                self .sparse_parameter .zero_ ()
425433
@@ -430,7 +438,13 @@ def optimize_mll(
430438        # NOTE: this function should never force the dense representation, because some 
431439        # models might never need it, and it would be inefficient. 
432440        self .to_sparse ()
433-         fit_gpytorch_mll (mll , optimizer_kwargs = optimizer_kwargs )
441+         mll  =  fit_gpytorch_mll (
442+             mll ,
443+             optimizer_kwargs = optimizer_kwargs ,
444+             closure = closure ,
445+             optimizer = optimizer ,
446+             closure_kwargs = closure_kwargs ,
447+         )
434448        if  model_trace  is  not   None :
435449            # need to record the full model here, rather than just the sparse parameter 
436450            # since other hyper-parameters are co-adapted to the sparse parameter. 
@@ -443,11 +457,15 @@ def forward_relevance_pursuit(
443457    sparse_module : RelevancePursuitMixin ,
444458    mll : ExactMarginalLogLikelihood ,
445459    sparsity_levels : list [int ] |  None  =  None ,
446-     optimizer_kwargs : dict [str , Any ] |  None  =  None ,
447460    reset_parameters : bool  =  RESET_PARAMETERS ,
448-     reset_dense_parameters : bool  =  RESET_PARAMETERS ,
461+     reset_dense_parameters : bool  =  RESET_DENSE_PARAMETERS ,
449462    record_model_trace : bool  =  True ,
450463    initial_support : list [int ] |  None  =  None ,
464+     # fit_gpytorch_mll kwargs 
465+     closure : Callable [[], tuple [Tensor , Sequence [Tensor  |  None ]]] |  None  =  None ,
466+     optimizer : Callable  |  None  =  None ,
467+     closure_kwargs : dict [str , Any ] |  None  =  None ,
468+     optimizer_kwargs : dict [str , Any ] |  None  =  None ,
451469) ->  tuple [RelevancePursuitMixin , Optional [list [Model ]]]:
452470    """Forward Relevance Pursuit. 
453471
@@ -478,9 +496,6 @@ def forward_relevance_pursuit(
478496        sparse_module: The relevance pursuit module. 
479497        mll: The marginal likelihood, containing the model to optimize. 
480498        sparsity_levels: The sparsity levels to expand the support to. 
481-         optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer. 
482-             By default, initializes the "options" sub-dictionary with `maxiter` and 
483-             `ftol`, `gtol` values, unless specified. 
484499        reset_parameters: If true, initializes the sparse parameter to the all zeros 
485500            after each iteration. 
486501        reset_dense_parameters: If true, re-initializes the dense parameters, e.g. 
@@ -489,6 +504,13 @@ def forward_relevance_pursuit(
489504        record_model_trace: If true, records the model state after every iteration. 
490505        initial_support: The support with which to initialize the sparse module. By 
491506            default, the support is initialized to the empty set. 
507+         closure: A closure to use to compute the loss and the gradients, see docstring 
508+             of `fit_gpytorch_mll` for details. 
509+         optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`. 
510+         closure_kwargs: Additional arguments to pass to the `closure` function. 
511+         optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer. 
512+             By default, initializes the "options" sub-dictionary with `maxiter` and 
513+             `ftol`, `gtol` values, unless specified. 
492514
493515    Returns: 
494516        The relevance pursuit module after forward relevance pursuit optimization, and 
@@ -510,14 +532,17 @@ def forward_relevance_pursuit(
510532
511533    model_trace  =  [] if  record_model_trace  else  None 
512534
513-     def  optimize_mll (mll ):
514-         return  sparse_module .optimize_mll (
515-             mll = mll ,
516-             model_trace = model_trace ,
517-             reset_parameters = reset_parameters ,
518-             reset_dense_parameters = reset_dense_parameters ,
519-             optimizer_kwargs = optimizer_kwargs ,
520-         )
535+     optimize_mll  =  partial (
536+         sparse_module .optimize_mll ,
537+         model_trace = model_trace ,
538+         reset_parameters = reset_parameters ,
539+         reset_dense_parameters = reset_dense_parameters ,
540+         # These are the args of the canonical mll fit routine 
541+         closure = closure ,
542+         optimizer = optimizer ,
543+         closure_kwargs = closure_kwargs ,
544+         optimizer_kwargs = optimizer_kwargs ,
545+     )
521546
522547    # if sparsity levels contains the initial support, remove it 
523548    if  sparsity_levels [0 ] ==  len (sparse_module .support ):
@@ -548,11 +573,15 @@ def backward_relevance_pursuit(
548573    sparse_module : RelevancePursuitMixin ,
549574    mll : ExactMarginalLogLikelihood ,
550575    sparsity_levels : list [int ] |  None  =  None ,
551-     optimizer_kwargs : dict [str , Any ] |  None  =  None ,
552576    reset_parameters : bool  =  RESET_PARAMETERS ,
553-     reset_dense_parameters : bool  =  RESET_PARAMETERS ,
577+     reset_dense_parameters : bool  =  RESET_DENSE_PARAMETERS ,
554578    record_model_trace : bool  =  True ,
555579    initial_support : list [int ] |  None  =  None ,
580+     # fit_gpytorch_mll kwargs 
581+     closure : Callable [[], tuple [Tensor , Sequence [Tensor  |  None ]]] |  None  =  None ,
582+     optimizer : Callable  |  None  =  None ,
583+     closure_kwargs : dict [str , Any ] |  None  =  None ,
584+     optimizer_kwargs : dict [str , Any ] |  None  =  None ,
556585) ->  tuple [RelevancePursuitMixin , Optional [list [Model ]]]:
557586    """Backward Relevance Pursuit. 
558587
@@ -583,9 +612,6 @@ def backward_relevance_pursuit(
583612        sparse_module: The relevance pursuit module. 
584613        mll: The marginal likelihood, containing the model to optimize. 
585614        sparsity_levels: The sparsity levels to expand the support to. 
586-         optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer. 
587-             By default, initializes the "options" sub-dictionary with `maxiter` and 
588-             `ftol`, `gtol` values, unless specified. 
589615        reset_parameters: If true, initializes the sparse parameter to the all zeros 
590616            after each iteration. 
591617        reset_dense_parameters: If true, re-initializes the dense parameters, e.g. 
@@ -594,6 +620,13 @@ def backward_relevance_pursuit(
594620        record_model_trace: If true, records the model state after every iteration. 
595621        initial_support: The support with which to initialize the sparse module. By 
596622            default, the support is initialized to the full set. 
623+         closure: A closure to use to compute the loss and the gradients, see docstring 
624+             of `fit_gpytorch_mll` for details. 
625+         optimizer: The numerical optimizer, see docstring of `fit_gpytorch_mll`. 
626+         closure_kwargs: Additional arguments to pass to the `closure` function. 
627+         optimizer_kwargs: A dictionary of keyword arguments to pass to the optimizer. 
628+             By default, initializes the "options" sub-dictionary with `maxiter` and 
629+             `ftol`, `gtol` values, unless specified. 
597630
598631    Returns: 
599632        The relevance pursuit module after forward relevance pursuit optimization, and 
@@ -623,6 +656,10 @@ def optimize_mll(mll):
623656            model_trace = model_trace ,
624657            reset_parameters = reset_parameters ,
625658            reset_dense_parameters = reset_dense_parameters ,
659+             # These are the args of the canonical mll fit routine 
660+             closure = closure ,
661+             optimizer = optimizer ,
662+             closure_kwargs = closure_kwargs ,
626663            optimizer_kwargs = optimizer_kwargs ,
627664        )
628665
0 commit comments