diff --git a/src/mssm/models.py b/src/mssm/models.py index b702aa8..2291cd6 100644 --- a/src/mssm/models.py +++ b/src/mssm/models.py @@ -8,6 +8,7 @@ from .src.python.gamm_solvers import solve_gamm_sparse,mp,repeat,tqdm,cpp_cholP,apply_eigen_perm,compute_Linv,solve_gamm_sparse2 from .src.python.terms import TermType,GammTerm,i,f,fs,irf,l,li,ri,rs from .src.python.penalties import PenType +from .src.python.utils import sample_MVN ##################################### Base Class ##################################### @@ -329,12 +330,87 @@ def fit(self,maxiter=50,conv_tol=1e-7,extend_lambda=True,control_lambda=True,exc ##################################### Prediction ##################################### - def predict(self, use_terms, n_dat,alpha=0.05,ci=False): + def sample_post(self,n_ps,use_post=None,deviations=False,seed=None): + """ + Obtain ``n_ps`` samples from posterior [\boldsymbol{\beta} - \hat{\boldsymbol{\beta}}] | y,\lambda ~ N(0,V), + where V is [X.T@X + S_\lambda]^{-1}*scale (see Wood, 2017; section 6.10). To obtain samples for \boldsymbol{\beta}, + simply set ``deviations`` to false. + + see ``sample_MVN`` in ``mssm.src.python.utils.py`` for more details. + + References: + - Wood, S. N. (2017). Generalized Additive Models: An Introduction with R, Second Edition (2nd ed.). + + Parameters: + + :param use_post: The indices corresponding to coefficients for which to actually obtain samples. + """ + if deviations: + post = sample_MVN(n_ps,0,self.__scale,P=None,L=None,LI=self.lvi,use=use_post,seed=seed) + else: + post = sample_MVN(n_ps,self.__coef,self.__scale,P=None,L=None,LI=self.lvi,use=use_post,seed=seed) + + return post + + def __adjust_CI(self,n_ps,b,predi_mat,use_terms,alpha,seed): + """ + Adjusts point-wise CI to behave like whole-interval (based on Wood, 2017; section 6.10.2 and Simpson, 2016): + + self.coef +- b gives point-wise interval, so for interval to be whole-interval + 1-alpha% of posterior samples should fall completely within these boundaries. + + From section 6.10 in Wood (2017) we have that *coef | y, lambda ~ N(coef,V), where V is self.lvi.T @ self.lvi * self.__scale. + Implication is that deviations [*coef - coef] | y, lambda ~ N(0,V). In line with definition above, 1-alpha% of + predi_mat@[*coef - coef] should fall within [b,-b]. Wood (2017) suggests to find a so that [a*b,a*-b] achieves this. + + To do this, we find a for every predi_mat@[*coef - coef] and then select the final one so that 1-alpha% of samples had an equal or lower + one. The consequence: 1-alpha% of samples drawn should fall completely within the modified boundaries. + + References: + - Wood, S. N. (2017). Generalized Additive Models: An Introduction with R, Second Edition (2nd ed.). + - Simpson, G. (2016). Simultaneous intervals for smooths revisited. + """ + use_post = None + if not use_terms is None: + # If we have many random factor levels, but want to make predictions only + # for fixed effects, it's wasteful to sample all coefficients from posterior. + # The code below performs a selection of the coefficients to be sampled. + use_post = predi_mat.sum(axis=0) != 0 + use_post = np.arange(0,predi_mat.shape[1])[use_post] + + # Sample deviations [*coef - coef] from posterior of GAMM + post = self.sample_post(n_ps,use_post,deviations=True,seed=seed) + + # To make computations easier take the abs of predi_mat@[*coef - coef], because [b,-b] is symmetric we can + # simply check whether abs(predi_mat@[*coef - coef]) < b by computing abs(predi_mat@[*coef - coef])/b. The max of + # this ratio, over rows of predi_mat, is a for this sample. If a<=1, no extension is necessary for this series. + if use_post is None: + fpost = np.abs(predi_mat@post) + else: + fpost = np.abs(predi_mat[:,use_post]@post) + + # Compute ratio between abs(predi_mat@[*coef - coef])/b for every sample. + fpost = fpost / b[:,None] + + # Then compute max of this ratio, over rows of predi_mat, for every sample. np.max(fpost,axis=0) now is a vector + # with n_ps elements, holding for each sample the multiplicative adjustment a, necessary to ensure that predi_mat@[*coef - coef] + # falls completely between [a*b,a*-b]. + # The final multiplicative adjustment bmadq is selected from this vector to be high enough so that in 1-(alpha) simulations + # we have an equal or lower a. + bmadq = np.quantile(np.max(fpost,axis=0),1-alpha) + + # Then adjust b + b *= bmadq + + return b + + def predict(self, use_terms, n_dat,alpha=0.05,ci=False,whole_interval=False,n_ps=10000,seed=None): """ Make a prediction using the fitted model for new data ``n_dat`` using only the terms indexed by ``use_terms``. References: - Wood, S. N. (2017). Generalized Additive Models: An Introduction with R, Second Edition (2nd ed.). + - Simpson, G. (2016). Simultaneous intervals for smooths revisited. Parameters: @@ -401,11 +477,18 @@ def predict(self, use_terms, n_dat,alpha=0.05,ci=False): c = predi_mat @ self.lvi.T @ self.lvi * self.__scale @ predi_mat.T c = c.diagonal() b = scp.stats.norm.ppf(1-(alpha/2)) * np.sqrt(c) + + # Whole-interval CI (section 6.10.2 in Wood, 2017), the same idea was also + # explored by Simpson (2016) who performs very similar computations to compute + # such intervals. See __adjust_CI function. + if whole_interval: + b = self.__adjust_CI(n_ps,b,predi_mat,use_terms,alpha,seed) + return pred,predi_mat,b return pred,predi_mat,None - def predict_diff(self,dat1,dat2,use_terms,alpha=0.05): + def predict_diff(self,dat1,dat2,use_terms,alpha=0.05,whole_interval=False,n_ps=10000,seed=None): """ Get the difference in the predictions for two datasets. Useful to compare a smooth estimated for one level of a factor to the smooth estimated for another level of a factor. In that case, ``dat1`` and @@ -413,6 +496,7 @@ def predict_diff(self,dat1,dat2,use_terms,alpha=0.05): References: - Wood, S. N. (2017). Generalized Additive Models: An Introduction with R, Second Edition (2nd ed.). + - Simpson, G. (2016). Simultaneous intervals for smooths revisited. - ``get_difference`` function from ``itsadug`` R-package: https://rdrr.io/cran/itsadug/man/get_difference.html Parameters: @@ -447,6 +531,12 @@ def predict_diff(self,dat1,dat2,use_terms,alpha=0.05): c = c.diagonal() b = scp.stats.norm.ppf(1-(alpha/2)) * np.sqrt(c) + # Whole-interval CI (section 6.10.2 in Wood, 2017), the same idea was also + # explored by Simpson (2016) who performs very similar computations to compute + # such intervals. See __adjust_CI function. + if whole_interval: + b = self.__adjust_CI(n_ps,b,pmat_diff,use_terms,alpha,seed) + return diff,b class sMsGAMM(MSSM): diff --git a/src/mssm/src/python/formula.py b/src/mssm/src/python/formula.py index 9057f1c..e597b77 100644 --- a/src/mssm/src/python/formula.py +++ b/src/mssm/src/python/formula.py @@ -113,7 +113,7 @@ def reparam(X,S,cov,option=1,n_bins=30,QR=False,identity=False,scale=False): # Now decompose X = Q @ R if QR: _,R = scp.linalg.qr(X.toarray(),mode='economic') - R = scp.sparse.csc_array(R) + R = scp.sparse.csr_array(R) else: XX = (X.T @ X).tocsc() diff --git a/src/mssm/src/python/utils.py b/src/mssm/src/python/utils.py new file mode 100644 index 0000000..fedb730 --- /dev/null +++ b/src/mssm/src/python/utils.py @@ -0,0 +1,91 @@ +import numpy as np +import scipy as scp +import math +import warnings +from ..python.gamm_solvers import cpp_backsolve_tr + +def sample_MVN(n,mu,scale,P,L,LI=None,use=None,seed=None): + """ + Draw n samples x from multivariate normal with mean mu and covariance matrix Sigma so that Sigma/scale = LI.T@LI, LI = L^{-1}, and + finally L@L.T = {Sigma/scale}^{-1}. In other words, L*(1/scale)^{0.5} is the cholesky for the precision matrix corresponding to Sigma. + Notably, L (and LI) have actually be computed for P@[X.T@X+S_\lambda]@P.T (see Wood \& Fasiolo, 2017), hence for sampling we need to correct + for permutation matrix ``P``. if ``LI`` is provided, then ``P`` can be omitted and is assumed to have been applied to ``LI already`` + + Used to sample the uncorrected posterior \beta|y,\lambda ~ N(\boldsymbol{\beta},(X.T@X+S_\lambda)^{-1}\phi) for a GAMM (see Wood, 2017). + + Based on section 7.4 in Gentle (2009), assuming Sigma is p*p and covariance matrix of uncorrected posterior: + x = mu + P.T@LI.T*scale^{0.5}@z where z_i ~ N(0,1) for all i = 1,...,p + + Notably, we can rely on the fact of equivalence that: + L.T*(1/scale)^{0.5} @ P@x = z + + ...and then first solve for y in: + L.T*(1/scale)^{0.5} @ y = z + + ...followed by computing: + y = P@x + x = P.T@y + + + The latter allows to avoid forming L^{-1} (which unlike L might not benefit from the sparsity preserving permutation P). Hence, if ``LI is None``, + ``L`` will be used for sampling. + + Often we care only about a handfull of elements in mu (usually the first ones corresponding to "fixed effects'" in a GAMM). In that case we + can generate x only for this sub-set of interest by only using a row-block of L/LI (all columns remain). Argument ``use`` can be a Numpy array + containg the indices of elements in mu that should be sampled. Because this only works efficiently when ``LI`` is available an error is raised + when ``not use is None and LI is None``. + + If ``mu`` is set to any integer (i.e., not a Numpy array/list) it is treated as 0. + + References: + - Wood, S. N. (2017). Generalized Additive Models: An Introduction with R, Second Edition (2nd ed.). + - Gentle, J. (2009). Computational Statistics. + """ + if L is None and LI is None: + raise ValueError("Either ``L`` or ``LI`` have to be provided.") + + if not L is None and not LI is None: + warnings.warn("Both ``L`` and ``LI`` were provided, will rely on ``LI``.") + + if not L is None and LI is None and P is None: + raise ValueError("When sampling with ``L`` ``P`` must be provided.") + + if not use is None and LI is None: + raise ValueError("If ``use`` is not None ``LI`` must be provided.") + + # Correct for scale + if not LI is None: + Cs = LI.T*math.sqrt(scale) + else: + Cs = L.T*math.sqrt(1/scale) + + # Sample from N(0,1) + z = scp.stats.norm.rvs(size=Cs.shape[1]*n,random_state=seed).reshape(Cs.shape[1],n) + + # Sample with L + if LI is None: + z = cpp_backsolve_tr(Cs.tocsc(),scp.sparse.csc_array(z)).toarray() # actually y + + if isinstance(mu,int): + return P.T@z + + return mu[:,None] + P.T@z + + else: + # Sample with LI + if not P is None: + Cs = P.T@Cs + + if not use is None: + Cs = Cs[use,:] + + if isinstance(mu,int): + return Cs@z + + if not use is None: + mus = mu[use,None] + else: + mus = mu[:,None] + + return mus + Cs@z +