Skip to content

Commit

Permalink
added filter_traj option to bpfilter
Browse files Browse the repository at this point in the history
  • Loading branch information
ionides committed Nov 30, 2023
1 parent 51ec228 commit 767499e
Show file tree
Hide file tree
Showing 9 changed files with 144 additions and 32 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: spatPomp
Type: Package
Title: Inference for Spatiotemporal Partially Observed Markov Processes
Version: 0.33.0
Date: 2023-07-28
Version: 0.33.1
Date: 2023-11-30
Authors@R: c(
person("Kidus", "Asfaw", email = "kidusasfaw1990@gmail.com", role = c("aut")),
person("Edward", "Ionides", email = "ionides@umich.edu",role = c("cre","aut")),
Expand Down
79 changes: 71 additions & 8 deletions R/bpfilter.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
##' an integer vector of neighboring units.
##' @param save_states logical. If True, the state-vector for each particle and
##' block is saved.
##' @param filter_traj logical; if \code{TRUE}, a filtered trajectory is returned for the state variables and parameters.
##' @param \dots If a \code{params} argument is specified, \code{bpfilter} will estimate the likelihood at that parameter set instead of at \code{coef(object)}.
##'
##' @examples
Expand Down Expand Up @@ -74,7 +75,8 @@ setClass(
cond.loglik="numeric",
block.cond.loglik="array",
loglik="numeric",
saved.states="list"
saved.states="list",
filter.traj="array"
),
prototype=prototype(
block_list = list(),
Expand All @@ -83,7 +85,8 @@ setClass(
cond.loglik=as.double(NA),
block.cond.loglik=array(data=numeric(0),dim=c(0,0)),
loglik=as.double(NA),
saved.states=list()
saved.states=list(),
filter.traj=array(data=numeric(0),dim=c(0,0,0))
)
)

Expand Down Expand Up @@ -125,10 +128,11 @@ setMethod(
setMethod(
"bpfilter",
signature=signature(object="spatPomp"),
function (object, Np, block_size, block_list, save_states, ..., verbose=getOption("verbose", FALSE)) {
function (object, Np, block_size, block_list, save_states, filter_traj, ..., verbose=getOption("verbose", FALSE)) {
ep = paste0("in ",sQuote("bpfilter"),": ")

if(missing(save_states)) save_states <- FALSE
if(missing(filter_traj)) filter_traj <- FALSE

if(missing(block_list) && missing(block_size))
stop(ep,sQuote("block_list"), " or ", sQuote("block_size"), " must be specified to the call",call.=FALSE)
Expand Down Expand Up @@ -156,6 +160,7 @@ setMethod(
Np=Np,
block_list=block_list,
save_states=save_states,
filter_traj=filter_traj,
...,
verbose=verbose)
}
Expand All @@ -168,10 +173,11 @@ setMethod(
setMethod(
"bpfilter",
signature=signature(object="bpfilterd_spatPomp"),
function (object, Np, block_size, block_list, save_states, ..., verbose=getOption("verbose", FALSE)) {
function (object, Np, block_size, block_list, save_states, filter_traj, ..., verbose=getOption("verbose", FALSE)) {
ep = paste0("in ",sQuote("bpfilter"),": ")

if(missing(save_states)) save_states <- FALSE
if(missing(filter_traj)) filter_traj <- FALSE

if (!missing(block_list) & !missing(block_size)){
stop(ep,"Exactly one of ",sQuote("block_size"), " and ", sQuote("block_list"), " can be provided, but not both.",call.=FALSE)
Expand All @@ -197,16 +203,18 @@ setMethod(
Np=Np,
block_list=block_list,
save_states=save_states,
filter_traj=filter_traj,
...,
verbose=verbose)
}
)

bpfilter.internal <- function (object, Np, block_list, save_states, ..., verbose, .gnsi = TRUE) {
bpfilter.internal <- function (object, Np, block_list, save_states, filter_traj, ..., verbose, .gnsi = TRUE) {
ep <- paste0("in ",sQuote("bpfilter"),": ")
verbose <- as.logical(verbose)
p_object <- pomp(object,...)
save_states <- as.logical(save_states)
filter_traj <- as.logical(filter_traj)
object <- new("spatPomp",p_object,
unit_covarnames = object@unit_covarnames,
shared_covarnames = object@shared_covarnames,
Expand Down Expand Up @@ -253,6 +261,7 @@ bpfilter.internal <- function (object, Np, block_list, save_states, ..., verbose
## returns an nvars by nsim matrix
init.x <- rinit(object,params=params,nsim=Np[1L],.gnsi=gnsi)
statenames <- rownames(init.x)
nvars <- nrow(init.x)
x <- init.x

# create array to store weights per particle per block_list
Expand All @@ -265,6 +274,32 @@ bpfilter.internal <- function (object, Np, block_list, save_states, ..., verbose
saved.states <- list()
}

## set up storage for saving samples from filtering distributions

## bpfilter has a logical "saved.states"
## pomp::pfilter has an option to save weighted filter states
## this is not implemented yet in spatPomp::bpfilter
## weighted states require interacting with the block structure, which
## saved.states does not need
## stsav and wtsav are included as a stub in case the functionality is added later
## stsav <- save.states %in% c("unweighted","TRUE")
## wtsav <- save.states == "weighted"
stsav <- FALSE
wtsav <- FALSE
if (stsav || wtsav || filter_traj) {
xparticles <- matrix(vector(mode="list"),nrow=ntimes,ncol=nblocks)
## if (wtsav) xweights <- xparticles
}
if (filter_traj) {
pedigree <- matrix(vector(mode="list"),nrow=ntimes+1,ncol=nblocks)
}

if (filter_traj) {
filt.t <- array(data=numeric(1),dim=c(nvars,1,ntimes+1),
dimnames=list(name=statenames,rep=1,time=NULL))
} else {
filt.t <- array(data=numeric(0),dim=c(0,0,0))
}

for (nt in seq_len(ntimes)) { ## main loop
## advance the state variables according to the process model
Expand Down Expand Up @@ -316,13 +351,13 @@ bpfilter.internal <- function (object, Np, block_list, save_states, ..., verbose
us = object@unit_statenames
statenames = paste0(rep(us,length(block)),rep(block,each=length(us)))
tempX = X[statenames,,,drop = FALSE]
xx <- tryCatch( #resampling with cross pollination
xx <- tryCatch( #block resampling
.Call(
"bpfilter_computations",
x=tempX,
params=params,
Np=Np[nt+1],
trackancestry=FALSE,
trackancestry=filter_traj,
doparRS=FALSE,
weights=weights[i,]
),
Expand All @@ -332,13 +367,40 @@ bpfilter.internal <- function (object, Np, block_list, save_states, ..., verbose
)
x[statenames,] <- xx$states
params <- xx$params
if (filter_traj) pedigree[nt,i][[1]] <- xx$ancestry
if (stsav || filter_traj) {
xparticles[nt,i][[1]] <- xx$states
dimnames(xparticles[nt,i][[1]]) <- list(name=statenames,.id=NULL)
}

}
if (save_states) saved.states[[nt]] <- x
log_weights = max_log_d + log(weights)
block_log_weights <- apply(log_weights,1,logmeanexp)
loglik[nt] = sum(block_log_weights)
block.loglik[,nt] <- block_log_weights
} ## end of main loop

if (filter_traj) { ## select a single trajectory
# sample sequentially for each block
for(i in seq(nblocks)){
block <- block_list[[i]]
us = object@unit_statenames
block_statenames = paste0(rep(us,length(block)),rep(block,each=length(us)))
b <- sample.int(n=ncol(weights),size=1L,replace=TRUE)
filt.t[block_statenames,1L,ntimes+1] <- xparticles[ntimes,i][[1]][,b]
for (nt in seq.int(from=ntimes-1,to=1L,by=-1L)) {
b <- pedigree[nt+1,][[1]][b]
filt.t[block_statenames,1L,nt+1] <- xparticles[nt,i][[1]][,b]
}
if (times[2L] > times[1L]) {
b <- pedigree[1L,i][[1]][b]
filt.t[block_statenames,1L,1L] <- init.x[block_statenames,b]
}
}
if (times[2L] <= times[1L]) filt.t <- filt.t[,,-1L,drop=FALSE]
}

new(
"bpfilterd_spatPomp",
object,
Expand All @@ -347,6 +409,7 @@ bpfilter.internal <- function (object, Np, block_list, save_states, ..., verbose
cond.loglik=loglik,
block.cond.loglik=block.loglik,
loglik=sum(loglik),
saved.states=saved.states
saved.states=saved.states,
filter.traj=filt.t
)
}
4 changes: 4 additions & 0 deletions man/bpfilter.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 5 additions & 4 deletions tests/README
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@

The unit tests here are designed to check for unintended consequences of code changes. The goal is to have 100% coverage and to raise a flag when computations are changed.

1. Longer tests are needed to check whether the code provides numerically correct answers in situations where these are available. Some of those validations are carried out by Asfaw et al (2021). This code, and other examples, may be added later as package vignettes.
1. Longer tests are needed to check whether the code provides numerically correct answers in situations where these are available. Some of those validations are carried out by Asfaw et al (2023). This code, and other examples, may be added later as package vignettes.

2. For code-generated code (e.g., R code which writes C functions that are then compiled) covr checks whether the generating code was run, and it can test whether the generatd code has changed. However, it does not check whether the generated code was run. In the context of spatPomp, full testing necessitates running all the compiled code.

3. Additional checks are carried out via a flag which defaults to extended=FALSE.
3. Additional checks are carried out in the tests/xtests directory, and are not run by default.

A call to igirf using the moment-based guide function can test compiled code for eunit_measure, munit_measure, vunit_measure, dunit_measure, runit_measure, rprocess, skeleton, rinit and partrans.

22-08-06: covr ran on an intel Mac but threw a compiler error on an M1 Mac. Not sure yet if this is an architecture thing, or some other issue.
22-08-06: covr ran on an intel Mac but threw a compiler error on an M1 Mac. This is apparently an architecture issue.

References

Asfaw, K., Park, J., Ho, A., King, A. A., and Ionides, E. L. (2021). Partially observed Markov processes with spatial structure via the R package spatPomp. (https://arxiv.org/abs/2101.01157)
Asfaw, K., Park, J., King, A. A., and Ionides, E. L. (2023). Partially observed Markov processes with spatial structure via the R package spatPomp. (https://arxiv.org/abs/2101.01157)


13 changes: 13 additions & 0 deletions tests/bm.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,30 @@ b_bpfilter_repeat <- bpfilter(b_bpfilter)
paste("check bpfilter on bpfilterd_spatPomp: ",
logLik(b_bpfilter)==logLik(b_bpfilter_repeat))

set.seed(5)
b_bpfilter_filter_traj <- bpfilter(b_bpfilter,filter_traj=TRUE)
paste("bpfilter filter trajectory final particle: ")
round(b_bpfilter_filter_traj@filter.traj[,1,N+1],3)


set.seed(5)
b_bpfilter_save_states <- bpfilter(b_bpfilter,save_states=TRUE)
paste("bpfilter final particles: ")
round(b_bpfilter_save_states@saved.states[[N]],3)

##
## enkf tested on bm
##

set.seed(5)
b_enkf <- enkf(b_model, Np = Np)
paste("bm enkf loglik: ",round(logLik(b_enkf),10))

##
## girf tested on bm, both moment and bootstrap methods
##

set.seed(0)
b_girf_mom <- girf(b_model,Np = floor(Np/2),lookahead = 1,
Nguide = floor(Np/2),
kind = 'moment',Ninter=2)
Expand Down
25 changes: 23 additions & 2 deletions tests/bm.Rout.save
Original file line number Diff line number Diff line change
Expand Up @@ -132,23 +132,44 @@ executing %dopar% sequentially: no parallel backend registered
+ logLik(b_bpfilter)==logLik(b_bpfilter_repeat))
[1] "check bpfilter on bpfilterd_spatPomp: TRUE"
>
> set.seed(5)
> b_bpfilter_filter_traj <- bpfilter(b_bpfilter,filter_traj=TRUE)
> paste("bpfilter filter trajectory final particle: ")
[1] "bpfilter filter trajectory final particle: "
> round(b_bpfilter_filter_traj@filter.traj[,1,N+1],3)
X1 X2
2.349 0.796
>
>
> set.seed(5)
> b_bpfilter_save_states <- bpfilter(b_bpfilter,save_states=TRUE)
> paste("bpfilter final particles: ")
[1] "bpfilter final particles: "
> round(b_bpfilter_save_states@saved.states[[N]],3)
.id
name [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
X1 2.611 2.236 2.186 2.186 2.349 1.077 1.077 1.155 1.442 1.063
X2 0.796 0.442 0.442 0.231 0.231 0.231 0.648 0.648 1.337 1.019
>
> ##
> ## enkf tested on bm
> ##
>
> set.seed(5)
> b_enkf <- enkf(b_model, Np = Np)
> paste("bm enkf loglik: ",round(logLik(b_enkf),10))
[1] "bm enkf loglik: -11.038124567"
[1] "bm enkf loglik: -10.8897941231"
>
> ##
> ## girf tested on bm, both moment and bootstrap methods
> ##
>
> set.seed(0)
> b_girf_mom <- girf(b_model,Np = floor(Np/2),lookahead = 1,
+ Nguide = floor(Np/2),
+ kind = 'moment',Ninter=2)
> paste("bm girf loglik, moment guide: ",round(logLik(b_girf_mom),10))
[1] "bm girf loglik, moment guide: -13.2689915662"
[1] "bm girf loglik, moment guide: -14.9828908185"
>
> set.seed(0)
> b_girf_boot <- girf(b_model,Np = floor(Np/2),lookahead = 1,
Expand Down
6 changes: 6 additions & 0 deletions tests/xtests/README
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

## This directory is for extra tests of correctness
## requiring addition Monte Carlo intensity, which we do not
## want to run every time we do the unit tests to check nothing has
## been broken.

18 changes: 18 additions & 0 deletions tests/xtests/he10.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
## extra tests of correctness requiring addition Monte Carlo intensity

set.seed(42)
library(spatPomp)


print("Test he10 with towns_selected argument")
h4 <- he10(U=4,towns_selected=c(1,2,11,12),
basic_params = c(
alpha =0.99, iota=0, R0=30,
cohort=0.5, amplitude=0.3, gamma=52,
sigma=52, mu=0.02, sigmaSE=0.05,
rho=0.5, psi=0.1, g=800,
S_0=0.036, E_0=0.00007, I_0=0.00006
))
s4 <- simulate(h4,seed=27)
obs(s4)[,1:2]

18 changes: 2 additions & 16 deletions tests/xtests.R → tests/xtests/pfilter.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
## extra tests of correctness requiring addition Monte Carlo intensity

XTEST <- FALSE
set.seed(42)
library(spatPomp)

if(XTEST){

print("Test PF/KF consistency for bm")
bb <- bm(U=4,N=10)
bb_pf_loglik <- round(logLik(pfilter(bb,1000)),10)
Expand All @@ -30,19 +27,8 @@ if(XTEST){
print(paste("bm3 pfilter loglik: ",bb3_pf_loglik))
print(paste("bm3 kalman filter loglik: ",round(bm2_kalman_logLik(bb3),10)))

print("Test he10 with towns_selected argument")
h4 <- he10(U=4,towns_selected=c(1,2,11,12),
basic_params = c(
alpha =0.99, iota=0, R0=30,
cohort=0.5, amplitude=0.3, gamma=52,
sigma=52, mu=0.02, sigmaSE=0.05,
rho=0.5, psi=0.1, g=800,
S_0=0.036, E_0=0.00007, I_0=0.00006
))
s4 <- simulate(h4,seed=27)
obs(s4)[,1:2]

}





0 comments on commit 767499e

Please sign in to comment.