-
-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 4922cfb
Showing
19 changed files
with
607 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# History files | ||
.Rhistory | ||
.Rapp.history | ||
*.Rbuildignore | ||
|
||
# Example code in package build process | ||
*-Ex.R | ||
|
||
# RStudio files | ||
.Rproj.user/ | ||
.Rproj.user | ||
|
||
# produced vignettes | ||
vignettes/*.html | ||
vignettes/*.pdf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
Package: loo | ||
Type: Package | ||
Title: Leave-one-out cross-validation, WAIC, very good importance sampling | ||
Version: 0.1 | ||
Date: 2015-06-13 | ||
Author: Aki Vehtari, Andrew Gelman, Jonah Gabry | ||
Maintainer: Jonah Gabry <jsg2201@columbia.edu> | ||
Description: We compute leave-one-out cross-validation (LOO) using very good importance sampling (VGIS), a new procedure for regularizing importance weights. As a byproduct of our calculations, we also obtain approximate standard errors for estimated predictive errors and for comparing predictive errors between two models. | ||
License: What license is it under? | ||
LazyData: TRUE | ||
Imports: matrixStats, | ||
parallel | ||
Suggests: rstan |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Generated by roxygen2 (4.1.1): do not edit by hand | ||
|
||
export(log_lik) | ||
export(loo_and_waic) | ||
export(loo_and_waic_diff) | ||
export(vgisloo) | ||
export(vgislw) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
gpdfit <- function(x) {
  # Estimate the parameters (k, sigma) of a generalized Pareto distribution
  # from the sample `x`. Returns list(k = ..., sigma = ...).
  n_obs <- length(x)
  sorted <- sort.int(x, method = "quick")
  prior_weight <- 3
  # note: original paper used 20 + floor(sqrt(n)) grid points
  n_grid <- 80 + floor(sqrt(n_obs))

  # grid of candidate values, anchored at the largest value and at the
  # order statistic near the first quarter of the sorted sample
  largest <- sorted[n_obs]
  quarter_value <- sorted[floor(n_obs / 4 + 0.5)]
  candidates <- 1 / largest +
    (1 - sqrt(n_grid / seq_min_half(n_grid))) / prior_weight / quarter_value

  # score each candidate with lx(), then turn the scores into weights
  scores <- vapply1m(n_grid, function(j) n_obs * lx(candidates[j], sorted))
  weights <- vapply1m(n_grid, function(j) 1 / sum(exp(scores - scores[j])))

  # weighted average of the candidates, then the implied k and sigma
  b_hat <- sum(candidates * weights)
  k <- mean.default(log(1 - b_hat * sorted))
  list(k = k, sigma = -k / b_hat)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# use matrixStats::colVars instead (much faster b/c written in C++) | ||
# colVars <- function(X) { | ||
# # column variances | ||
# N <- dim(X)[[1]] | ||
# C <- dim(X)[[2]] | ||
# Xbar <- matrix(.colMeans(X, N, C), N, C, byrow = TRUE) | ||
# var <- .colMeans((X - Xbar)^2, N, C) * N / (N - 1) | ||
# var | ||
# } | ||
|
||
nlist <- function(...) {
  # Build a list from `...`; any element supplied without a name is named
  # after the expression that was passed for it.
  values <- list(...)
  given <- names(values)
  arg_text <- as.character(match.call())[-1]

  if (is.null(given)) {
    # nothing was named explicitly: use the argument expressions throughout
    names(values) <- arg_text
    return(values)
  }

  unnamed <- !nzchar(given)
  if (any(unnamed)) {
    names(values)[unnamed] <- arg_text[unnamed]
  }
  values
}
|
||
unlist_lapply <- function(x, fun) {
  # Apply `fun` to every element of `x` and flatten the results into an
  # unnamed vector.
  results <- lapply(x, fun)
  unlist(results, use.names = FALSE)
}
|
||
seq_min_half <- function(n) {
  # Midpoints 0.5, 1.5, ..., n - 0.5
  whole <- seq_len(n)
  whole - 0.5
}
|
||
vapply1m <- function(m, fun) {
  # Evaluate `fun` at 1, 2, ..., m and collect the results in a numeric
  # vector; each call must return a single number.
  # seq_len() (rather than 1:m) yields an empty sequence when m is 0, so
  # `fun` is never called with the indices 1 and 0 in that case.
  vapply(seq_len(m), fun, FUN.VALUE = 0)
}
|
||
lx <- function(a, x) {
  # For a scalar `a` and sample `x`: with k = mean(log(1 - a * x)),
  # return log(-a / k) - k - 1.
  mean_log <- mean.default(log(1 - a * x))
  log_ratio <- log(-a / mean_log)
  log_ratio - mean_log - 1
}
|
||
qgpd <- function(p, xi=1, mu=0, beta=1, lower.tail=TRUE){
  # Generalized Pareto inverse-cdf (formula from Wikipedia).
  #   p          probabilities
  #   xi         shape parameter
  #   mu         location parameter
  #   beta       scale parameter
  #   lower.tail if FALSE, `p` is treated as an upper-tail probability
  # Returns mu + beta * ((1 - p)^(-xi) - 1) / xi, using the limiting form
  # mu - beta * log(1 - p) wherever xi is exactly 0 (the general formula
  # evaluates to 0/0 = NaN there).
  if (!lower.tail) p <- 1 - p
  shape_term <- ((1 - p)^(-xi) - 1) / xi
  zero_shape <- rep_len(!is.na(xi) & xi == 0, length(shape_term))
  if (any(zero_shape)) {
    limit_term <- rep_len(-log1p(-p), length(shape_term))
    shape_term[zero_shape] <- limit_term[zero_shape]
  }
  mu + beta * shape_term
}
|
||
sumlogs <- function(x, dimen=1) {
  # log_sum_exp: log(sum(exp(x))) computed without overflow.
  #   x      numeric vector or matrix of values on the log scale
  #   dimen  for a matrix: 1 sums within each column, anything else sums
  #          within each row
  # Returns a single number for a vector, otherwise one value per
  # column (dimen == 1) or per row.
  #
  # For a matrix the shift subtracted before exponentiating is the maximum
  # of each column/row rather than the maximum of the whole matrix, so a
  # column/row far below the global maximum no longer underflows to -Inf.
  # A non-finite shift is replaced by 0 so that, e.g., an all -Inf input
  # gives -Inf instead of NaN.
  is_matrix <- length(dim(x)) == 2L
  if (!is_matrix) {
    shift <- max(x)
    if (!is.finite(shift)) shift <- 0
    centered <- x - shift
    if (is.null(dim(x))) return(shift + log(sum(exp(centered))))
    # arrays that are not 2-d keep a single overall shift
    if (dimen == 1) return(shift + log(colSums(exp(centered))))
    return(shift + log(rowSums(exp(centered))))
  }

  if (dimen == 1) {
    shift <- apply(x, 2L, max)
    shift[!is.finite(shift)] <- 0
    return(shift + log(colSums(exp(sweep(x, 2L, shift)))))
  }
  shift <- apply(x, 1L, max)
  shift[!is.finite(shift)] <- 0
  # a vector of length nrow(x) recycles down the columns, i.e. by row
  shift + log(rowSums(exp(x - shift)))
}
|
||
escape_check <- function(x, escape_if_greater_than = 700) {
  # Emit a message when the gap between the two largest values of `x`, or
  # between the two smallest, exceeds `escape_if_greater_than`.
  #   x                       numeric vector or matrix (log-weights)
  #   escape_if_greater_than  threshold on the absolute gap
  # Always returns invisible(NULL); the message is the only signal.
  x_sorted <- sort(x, method = "quick")
  # sort() drops NAs, so index by the length of the sorted values
  L <- length(x_sorted)
  if (L < 2) {
    # fewer than two values: there is no gap to inspect
    return(invisible(NULL))
  }
  gaps <- abs(c(x_sorted[2] - x_sorted[1], x_sorted[L] - x_sorted[L - 1]))
  if (any(gaps > escape_if_greater_than, na.rm = TRUE)) {
    message(paste0("Failed. Difference between largest and second largest",
                   " log-weights is more than ", escape_if_greater_than, "."))
  }
  invisible(NULL)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#' Convenience function for extracting log-likelihood from
#' a \code{stanfit} object.
#'
#' @export
#' @param stanfit a \code{stanfit} (\pkg{rstan}) object.
#' @param parameter_name a character string naming the parameter (generated
#'   quantity) in the Stan model corresponding to the log-likelihood.
#' @return the element named \code{parameter_name} of the list returned by
#'   \code{rstan::extract}. If \pkg{rstan} is not installed, a message is
#'   printed and \code{NULL} is returned invisibly.
#' @seealso \code{\link[rstan]{stanfit-class}}
log_lik <- function(stanfit, parameter_name = "log_lik") {
  # rstan is only a suggested package, so check for it at call time
  if (requireNamespace("rstan", quietly = TRUE)) {
    draws <- rstan::extract(stanfit, parameter_name)
    return(draws[[parameter_name]])
  }
  message("Please install the rstan package")
  invisible(NULL)
}
|
||
# NOTE(review): top-level assignment, evaluated when this file is sourced.
# `xx` is not referenced anywhere else in the visible sources; this looks
# like leftover debugging code -- confirm before removing.
xx <- requireNamespace("rstan", quietly = TRUE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#' LOO and WAIC
#'
#' @export
#' @param log_lik an nsims by nobs matrix, typically (but not restricted to be)
#'   the object returned by \code{rstan::extract(stanfit, "log_lik")$log_lik}.
#' @return a list of class \code{"loo"} with the totals \code{elpd_loo},
#'   \code{p_loo}, \code{elpd_waic}, \code{p_waic}, \code{looic} and
#'   \code{waic}, the matching \code{se_*} elements, the matrix
#'   \code{pointwise} (one row per observation, one column per quantity) and
#'   \code{pareto_k}, the tail indices returned by \code{\link{vgisloo}}.
#'
#' @details Leave-one-out cross-validation (LOO) and the widely applicable
#' information criterion (WAIC) are methods for estimating pointwise out-of-sample
#' prediction accuracy from a fitted Bayesian model using the log-likelihood
#' evaluated at the posterior simulations of the parameter values. LOO and WAIC
#' have various advantages over simpler estimates of predictive error such as
#' AIC and DIC but are less used in practice because they involve additional
#' computational steps. Here we lay out fast and stable computations for LOO and
#' WAIC that can be performed using existing simulation draws. We compute LOO
#' using very good importance sampling (VGIS), a new procedure for regularizing
#' importance weights. As a byproduct of our calculations, we also obtain
#' approximate standard errors for estimated predictive errors and for comparing
#' predictive errors between two models.
#'
loo_and_waic <- function(log_lik) {
  # log_lik should be a matrix with nrow = nsims and ncol = nobs
  if (!is.matrix(log_lik)) stop("'log_lik' should be a matrix")
  S <- nrow(log_lik)
  N <- ncol(log_lik)

  # log of the column means of exp(log_lik), computed on the log scale:
  # subtracting each column's maximum before exponentiating keeps very
  # negative log-likelihoods from underflowing to 0 (and the log to -Inf)
  col_max <- apply(log_lik, 2L, max)
  col_max[!is.finite(col_max)] <- 0
  lpd <- col_max + log(colSums(exp(sweep(log_lik, 2L, col_max)))) - log(S)

  # LOO quantities from the smoothed importance weights
  loo <- vgisloo(log_lik)
  elpd_loo <- loo$loos
  p_loo <- lpd - elpd_loo
  looic <- -2 * elpd_loo

  # WAIC quantities from the per-observation variance of the log-likelihood
  p_waic <- matrixStats::colVars(log_lik)
  elpd_waic <- lpd - p_waic
  waic <- -2 * elpd_waic

  pointwise <- nlist(elpd_loo, p_loo, elpd_waic, p_waic, looic, waic)
  nms <- names(pointwise)
  total <- unlist_lapply(pointwise, sum)
  se <- sqrt(N * unlist_lapply(pointwise, var))

  output <- as.list(c(total, se))
  names(output) <- c(nms, paste0("se_", nms))
  output$pointwise <- do.call("cbind", pointwise)
  output$pareto_k <- loo$ks
  class(output) <- "loo"
  output
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#' Compare
#' @export
#' @param loo1,loo2 lists returned from \code{\link{loo_and_waic}}.
#' @return a list with \code{elpd_loo_diff} and \code{elpd_waic_diff}, the
#'   sums of the pointwise differences (\code{loo2} minus \code{loo1}), and
#'   \code{lpd_loo_diff} and \code{lpd_waic_diff}, the standard deviation of
#'   those pointwise differences multiplied by the square root of the number
#'   of data points.
#'
loo_and_waic_diff <- function(loo1, loo2){
  pw1 <- loo1$pointwise
  pw2 <- loo2$pointwise
  N1 <- nrow(pw1)
  N2 <- nrow(pw2)
  if (N1 != N2) {
    stop(paste("Models being compared should have the same number of data points.",
               "Found N1 =", N1, "and N2 =", N2))
  }

  # pointwise differences, second model minus first
  d_loo <- pw2[, "elpd_loo"] - pw1[, "elpd_loo"]
  d_waic <- pw2[, "elpd_waic"] - pw1[, "elpd_waic"]
  root_n <- sqrt(N1)

  list(
    elpd_loo_diff = sum(d_loo),
    lpd_loo_diff = root_n * sd(d_loo),
    elpd_waic_diff = sum(d_waic),
    lpd_waic_diff = root_n * sd(d_waic)
  )
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
#' Leave-one-out cross-validation and WAIC | ||
#' | ||
#' @description Leave-one-out cross-validation and WAIC | ||
#' | ||
#' @section 1: After fitting a Bayesian model we often want to measure its | ||
#' predictive accuracy, for its own sake or for purposes of model comparison, | ||
#' selection, or averaging (Geisser and Eddy, 1979, Hoeting et al., 1999, | ||
#' Vehtari and Lampinen, 2002, Ando and Tsay, 2010, Vehtari and Ojanen, 2012). | ||
#' Cross-validation and information criteria are two approaches for estimating | ||
#' out-of-sample predictive accuracy using within-sample fits (Akaike, 1973, | ||
#' Stone, 1977). In this article we consider computations using the | ||
#' log-likelihood evaluated at the usual posterior simulations of the | ||
#' parameters. Computation time for the predictive accuracy measures should be | ||
#' negligible compared to the cost of fitting the model and obtaining posterior | ||
#' draws in the first place. Exact cross-validation requires re-fitting the | ||
#' model with different training sets. Approximate leave-one-out | ||
#' cross-validation (LOO) can be computed easily using importance sampling | ||
#' (Gelfand, Dey, and Chang, 1992, Gelfand, 1996) but the resulting estimate is | ||
#' noisy, as the variance of the importance weights can be large or even | ||
#' infinite (Peruggia, 1997, Epifani et al., 2008). Here we propose a novel | ||
#' approach that provides a more accurate and reliable estimate using importance | ||
#' weights that are smoothed using a Pareto distribution fit to the upper tail | ||
#' of the distribution of importance weights. WAIC (the widely applicable or | ||
#' Watanabe-Akaike information criterion; Watanabe, 2010) can be viewed as an | ||
#' improvement on the deviance information criterion (DIC) for Bayesian models. | ||
#' DIC has gained popularity in recent years in part through its implementation | ||
#' in the graphical modeling package BUGS (Spiegelhalter, Best, et al., 2002; | ||
#' Spiegelhalter, Thomas, et al., 1994, 2003), but it is known to have some | ||
#' problems, arising in part from it not being fully Bayesian in that it is | ||
#' based on a point estimate (van der Linde, 2005, Plummer, 2008). For example, | ||
#' DIC can produce negative estimates of the effective number of parameters in a | ||
#' model and it is not defined for singular models. WAIC is fully Bayesian and | ||
#' closely approximates Bayesian cross-validation. Unlike DIC, WAIC is invariant | ||
#' to parametrization and also works for singular models. WAIC is asymptotically | ||
#' equal to LOO, and can thus be used as an approximation of LOO. In the finite | ||
#' case, WAIC often gives similar estimates as LOO, but for influential | ||
#' observations WAIC underestimates the effect of leaving out one observation. | ||
#' | ||
#' @section 2: One advantage of AIC and DIC is their computational simplicity. | ||
#' In the present paper, we quickly review LOO and WAIC and then present fast | ||
#' and stable computations that can be performed directly on posterior | ||
#' simulations, thus allowing these newer tools to enter routine statistical | ||
#' practice. We compute LOO using very good importance sampling (VGIS), a new | ||
#' procedure for regularizing importance weights (Vehtari and Gelman, 2015). As | ||
#' a byproduct of our calculations, we also obtain approximate standard errors | ||
#' for estimated predictive errors and for comparing predictive errors | ||
#' between two models. | ||
#' | ||
#' @docType package | ||
#' @name loo | ||
#' | ||
NULL | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
#' Very good importance sampling | ||
#' | ||
#' @export | ||
#' @param log_lik,wcp,wtrunc arguments. | ||
#' @return a list. | ||
#' @details The distribution of the importance weights used in LOO may have a | ||
#' long right tail. We use the empirical Bayes estimate of Zhang and Stephens | ||
#' (2009) to fit a generalized Pareto distribution to the tail (20% largest | ||
#' importance ratios). By examining the shape parameter k of the fitted Pareto | ||
#' distribution, we are able to obtain a sample-based estimate of the existence | ||
#' of the moments (Koopman et al, 2009). This extends the diagnostic approach | ||
#' of Peruggia (1997) and Epifani et al. (2008) to be used routinely with | ||
#' IS-LOO for any model with factorising likelihood. Epifani et al. (2008) | ||
#' show that when estimating the leave-one-out predictive density, the central | ||
#' limit theorem holds if the variance of the weight distribution is finite. | ||
#' These results can be extended by using the generalized central limit | ||
#' theorem for stable distributions. Thus, even if the variance of the | ||
#' importance weight distribution is infinite, if the mean exists the | ||
#' estimate’s accuracy improves when additional draws are obtained. When the | ||
#' tail of the weight distribution is long, a direct use of importance | ||
#' sampling is sensitive to one or few largest values. By fitting a | ||
#' generalized Pareto distribution to the upper tail of the importance | ||
#' weights, we smooth these values. The procedure goes as follows: | ||
#' | ||
#' \enumerate{ | ||
#' \item Fit the generalized Pareto distribution to the 20% largest importance | ||
#' ratios \eqn{r_s} as computed in (6). (The computation is done separately for each | ||
#' held-out data point \eqn{i}.) In simulation experiments with thousands to tens | ||
#' of thousands of simulation draws, we have found the fit is not sensitive to | ||
#' the specific cutoff value (for a consistent estimation the proportion of | ||
#' the samples above the cutoff should get smaller when the number of draws | ||
#' increases). | ||
#' | ||
#' \item Stabilize the importance ratios by replacing the \eqn{M} largest ratios | ||
#' by the expected values of the order statistics of the fitted generalized | ||
#' Pareto distribution \deqn{G((z - 0.5)/M), z = 1,...,M,} | ||
#' where \eqn{M} is the number of simulation draws used to fit the Pareto (in this | ||
#' case, \eqn{M = 0.2*S}) and \eqn{G} is the inverse-CDF of the generalized | ||
#' Pareto distribution. | ||
#' | ||
#' \item To guarantee finite variance of the estimate, truncate the smoothed | ||
#' ratios with \deqn{S^{3/4}\bar{w},} where \eqn{\bar{w}} is the average of | ||
#' the smoothed weights. | ||
#' } | ||
#' | ||
#' The above steps must be performed for each data point \eqn{i}, thus | ||
#' resulting in a vector of weights \eqn{w_{i}^{s}, s = 1,...,S}, for each | ||
#' \eqn{i}, which in general should be better behaved than the raw importance | ||
#' ratios \eqn{r_{i}^{s}} from which they were constructed. | ||
#' | ||
#' The results can then be combined to compute desired LOO estimates. | ||
#' | ||
|
||
vgisloo <- function(log_lik, wcp=20, wtrunc=3/4) {
  # smoothed log weights (and tail indices) computed from the negated
  # log-likelihood
  smoothed <- vgislw(-log_lik, wcp, wtrunc)

  # pointwise LOO values: log-sum-exp of log-likelihood plus log weight
  pointwise <- sumlogs(log_lik + smoothed$lw)

  list(loo = sum(pointwise), loos = pointwise, ks = smoothed$k)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
#' Doing very good importance sampling
#'
#' @export
#' @param lw a matrix of log weights; each column is smoothed separately.
#' @param wcp the percentage of each column (those values above the
#'   \code{1 - wcp/100} quantile) treated as the right tail and fitted with a
#'   generalized Pareto distribution.
#' @param wtrunc truncation exponent: when positive, smoothed log weights are
#'   capped at \code{wtrunc * log(n) - log(n)} plus the log of the sum of the
#'   weights, where \code{n} is the number of rows. Use a value \code{<= 0} to
#'   skip truncation.
#' @param cores number of cores passed to \code{parallel::mclapply}. Forced to
#'   1 on Windows and when the number of cores is unknown (\code{NA}).
#' @return a list with \code{lw}, a matrix of smoothed and renormalized log
#'   weights with one column per column of the input, and \code{k}, the
#'   vector of fitted tail indices (one per column).
#'
vgislw <- function(lw, wcp=20, wtrunc=3/4,
                   cores = parallel::detectCores()) {
  # mclapply cannot fork on Windows, and detectCores() may return NA, which
  # mclapply rejects as mc.cores
  if (.Platform$OS.type == "windows" || anyNA(cores)) cores <- 1

  loop_fn <- function(i) {
    x <- lw[, i]

    # divide log weights into body and right tail
    n <- length(x)
    cutoff <- quantile(x, 1 - wcp / 100, names = FALSE)
    x_gt_cut <- x > cutoff
    x2 <- x[x_gt_cut]
    n2 <- length(x2)

    # store order of tail samples
    x2si <- order(x2)

    # fit generalized Pareto distribution to the right tail samples
    fit <- gpdfit(exp(x2) - exp(cutoff))
    k <- fit$k
    sigma <- fit$sigma

    # compute ordered statistic for the fit
    qq <- qgpd(seq_min_half(n2) / n2, xi = k, beta = sigma) + exp(cutoff)

    # remap back to the original order
    slq <- rep.int(0, n2)
    slq[x2si] <- log(qq)

    # join body (left untouched) and GPD smoothed tail
    qx <- x
    qx[x_gt_cut] <- slq
    if (wtrunc > 0) {
      # truncate too large weights
      lwtrunc <- wtrunc * log(n) - log(n) + sumlogs(qx)
      qx[qx > lwtrunc] <- lwtrunc
    }

    # renormalize weights
    lwx <- qx - sumlogs(qx)

    # return log weights and tail index k
    list(lwx, k)
  }

  # seq_len() gives an empty index set for a matrix with no columns
  out <- parallel::mclapply(seq_len(ncol(lw)), loop_fn, mc.cores = cores)

  # return log weights and tail indices k
  list(lw = do.call(cbind, lapply(out, `[[`, 1L)),
       k = do.call(c, lapply(out, `[[`, 2L)))
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
.onAttach <- function(...) {
  # Announce the installed package version when the package is attached
  packageStartupMessage(paste("loo version", utils::packageVersion("loo")))
}
Oops, something went wrong.