forked from betanalpha/knitr_case_studies
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstan_utility.R
110 lines (96 loc) · 3.58 KB
/
stan_utility.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# Check transitions that ended with a divergence
#
# Stacks the post-warmup sampler diagnostics of every chain into one matrix,
# counts the rows whose divergent__ flag is set, and prints that count
# together with the total number of rows and the percentage. If any
# divergence was seen, a second line suggesting a larger adapt_delta is
# printed.
#
# fit: a fitted model object accepted by get_sampler_params()
check_div <- function(fit) {
  pooled <- do.call(rbind, get_sampler_params(fit, inc_warmup = FALSE))
  div_flags <- pooled[, 'divergent__']
  n_total <- length(div_flags)
  n_div <- sum(div_flags)
  summary_line <- sprintf('%s of %s iterations ended with a divergence (%s%%)',
                          n_div, n_total, 100 * n_div / n_total)
  print(summary_line)
  if (n_div > 0) {
    print(' Try running with larger adapt_delta to remove the divergences')
  }
}
# Check transitions that ended prematurely due to maximum tree depth limit
#
# Stacks the post-warmup treedepth__ values of every chain and counts the
# transitions whose tree depth reached max_depth. Prints that count with the
# total number of transitions and the percentage, plus a hint when any
# transition saturated.
#
# fit:       a fitted model object accepted by get_sampler_params()
# max_depth: the tree-depth limit used when sampling; should match the
#            max_treedepth control setting (the default of 10 here presumably
#            mirrors rstan's default — confirm if sampling used another value)
check_treedepth <- function(fit, max_depth = 10) {
  sampler_params <- get_sampler_params(fit, inc_warmup = FALSE)
  treedepths <- do.call(rbind, sampler_params)[, 'treedepth__']
  # Vectorized count. Unlike an sapply()-built mask it is well-defined (0)
  # when there are no draws at all.
  n <- sum(treedepths >= max_depth)
  N <- length(treedepths)
  print(sprintf('%s of %s iterations saturated the maximum tree depth of %s (%s%%)',
                n, N, max_depth, 100 * n / N))
  if (n > 0) {
    # rstan's control argument is max_treedepth, so name that in the advice
    print(' Run again with max_treedepth set to a larger value to avoid saturation')
  }
}
# Checks the energy Bayesian fraction of missing information (E-BFMI)
#
# For each chain, estimates E-BFMI as the mean squared change of energy__
# between successive post-warmup iterations divided by the variance of
# energy__. Prints one line per chain whose estimate is below 0.2 or is not a
# finite number, and then a summary line.
#
# fit: a fitted model object accepted by get_sampler_params()
check_energy <- function(fit) {
  sampler_params <- get_sampler_params(fit, inc_warmup = FALSE)
  no_warning <- TRUE
  # seq_along() so that an empty list of chains performs no iterations
  for (n in seq_along(sampler_params)) {
    energies <- sampler_params[[n]][, 'energy__']
    numer <- sum(diff(energies)^2) / length(energies)
    denom <- var(energies)
    bfmi <- numer / denom
    # A constant-energy chain gives 0 / 0 = NaN and a single draw gives NA.
    # Flag those explicitly instead of letting if (NA < 0.2) raise an error.
    if (!is.finite(bfmi) || bfmi < 0.2) {
      print(sprintf('Chain %s: E-BFMI = %s', n, bfmi))
      no_warning <- FALSE
    }
  }
  if (no_warning) {
    print('E-BFMI indicated no pathological behavior')
  } else {
    print(' E-BFMI below 0.2 indicates you may need to reparameterize your model')
  }
}
# Checks the effective sample size per iteration
#
# For every row of the summary table, divides n_eff by the number of
# draws of the first extracted quantity and prints a warning when the ratio
# is below 0.001. Rows whose n_eff is missing (NA/NaN, e.g. for a constant
# quantity) are reported instead of being compared. A final line summarizes
# the result.
#
# fit: a fitted model object whose summary(fit, probs = ...)$summary has an
#      'n_eff' column and which extract() accepts
check_n_eff <- function(fit) {
  fit_summary <- summary(fit, probs = c(0.5))$summary
  # NOTE(review): assumes extract() pools chains along the first dimension
  # (rstan's permuted = TRUE default), so this is the total draw count
  iter <- dim(extract(fit)[[1]])[[1]]
  # Select by column name rather than position
  n_effs <- fit_summary[, 'n_eff']
  par_names <- rownames(fit_summary)
  no_warning <- TRUE
  for (n in seq_along(n_effs)) {
    n_eff <- n_effs[[n]]
    if (is.na(n_eff)) {
      # Comparing NA/NaN would make if() fail, so report it directly
      print(sprintf('n_eff for parameter %s is %s!', par_names[n], n_eff))
      no_warning <- FALSE
    } else {
      ratio <- n_eff / iter
      if (ratio < 0.001) {
        print(sprintf('n_eff / iter for parameter %s is %s!',
                      par_names[n], ratio))
        no_warning <- FALSE
      }
    }
  }
  if (no_warning) {
    print('n_eff / iter looks reasonable for all parameters')
  } else {
    print(' n_eff / iter below 0.001 indicates that the effective sample size has likely been overestimated')
  }
}
# Checks the potential scale reduction factors
#
# Prints a warning for every row of the summary table whose Rhat exceeds 1.1
# or is not a finite number (NA, NaN or +/-Inf), followed by a summary line.
#
# fit: a fitted model object whose summary(fit, probs = ...)$summary has an
#      'Rhat' column
check_rhat <- function(fit) {
  fit_summary <- summary(fit, probs = c(0.5))$summary
  # Select by column name rather than position
  rhats <- fit_summary[, 'Rhat']
  par_names <- rownames(fit_summary)
  no_warning <- TRUE
  for (n in seq_along(rhats)) {
    rhat <- rhats[[n]]
    # Test finiteness first: it catches NA, NaN and Inf, and guarantees the
    # comparison never yields NA (which would make if() raise an error)
    if (!is.finite(rhat) || rhat > 1.1) {
      print(sprintf('Rhat for parameter %s is %s!', par_names[n], rhat))
      no_warning <- FALSE
    }
  }
  if (no_warning) {
    print('Rhat looks reasonable for all parameters')
  } else {
    print(' Rhat above 1.1 indicates that the chains very likely have not mixed')
  }
}
# Runs the five diagnostic checks on `fit`, one after another
#
# Order: effective sample size, Rhat, divergences, tree depth (with its
# default max_depth), then E-BFMI. Each check prints its own report; the
# value of the final E-BFMI check is what this function returns.
check_all_diagnostics <- function(fit) {
  leading_checks <- list(check_n_eff, check_rhat, check_div, check_treedepth)
  for (run_check in leading_checks) {
    run_check(fit)
  }
  check_energy(fit)
}
# Returns parameter draws separated into divergent and non-divergent transitions
#
# Pools the post-warmup draws of every chain into one data frame: one row per
# iteration, chains stacked in order, one column per quantity returned by
# extract(fit, permuted = FALSE). Appends each row's divergent__ flag from
# get_sampler_params() as a `divergent` column and splits the rows on it.
#
# fit: a fitted model object accepted by extract() and get_sampler_params()
# Returns an unnamed list: [[1]] divergent rows, [[2]] non-divergent rows.
partition_div <- function(fit) {
  nom_params <- extract(fit, permuted = FALSE)
  dims <- dim(nom_params)
  # Reshape the (iteration, chain, quantity) array directly. Column-major
  # storage makes this identical to rbind-ing each chain's slice. Unlike
  # nom_params[, n, ], it never drops a dimension, so a model with a single
  # quantity keeps one column instead of being transposed.
  draws <- matrix(nom_params,
                  nrow = dims[1] * dims[2], ncol = dims[3],
                  dimnames = list(NULL, dimnames(nom_params)[[3]]))
  params <- as.data.frame(draws)
  sampler_params <- get_sampler_params(fit, inc_warmup = FALSE)
  params$divergent <- do.call(rbind, sampler_params)[, 'divergent__']
  div_params <- params[params$divergent == 1, ]
  nondiv_params <- params[params$divergent == 0, ]
  list(div_params, nondiv_params)
}