Skip to content

Commit

Permalink
finalize jsdgam wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas Clark committed Oct 30, 2024
1 parent 4fbf3aa commit 0a738fd
Show file tree
Hide file tree
Showing 104 changed files with 2,160 additions and 637 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,5 @@ jobs:
shell: Rscript {0}

- name: Test coverage
run: covr::codecov()
run: covr::codecov(line_exclusions = "R/stan_utils.R")
shell: Rscript {0}
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ export(gp)
export(hindcast)
export(hypotheses)
export(irf)
export(jsdgam)
export(lfo_cv)
export(lognormal)
export(loo)
Expand Down
9 changes: 5 additions & 4 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# mvgam 1.1.4 (development version; not yet on CRAN)
## New functionalities
* Added a `stability.mvgam` method to compute stability metrics from models fit with Vector Autoregressive dynamics (#21 and #76)
* Added function `jsdgam()` to estimate Joint Species Distribution Models in which both the latent factors and the observation model components can include any of mvgam's complex linear predictor effects. See `?mvgam::jsdgam` for details
* Added a `stability.mvgam()` method to compute stability metrics from models fit with Vector Autoregressive dynamics (#21 and #76)
* Added functionality to estimate hierarchical error correlations when using multivariate latent process models and when the data are nested among levels of a relevant grouping factor (#75); see `?mvgam::AR` for an example
* Added `ZMVN()` error models for estimating Zero-Mean Multivariate Normal errors; convenient for working with non time-series data where latent residuals are expected to be correlated (such as when fitting Joint Species Distribution Models); see `?mvgam::ZMVN` for examples
* Added a `fevd.mvgam` method to compute forecast error variance decompositions from models fit with Vector Autoregressive dynamics (#21 and #76)
* Added a `fevd.mvgam()` method to compute forecast error variance decompositions from models fit with Vector Autoregressive dynamics (#21 and #76)

## Bug fixes
* Fixed a minor bug in the way `trend_map` recognises levels of the `series` factor
Expand All @@ -14,8 +15,8 @@
* Allow intercepts to be included in process models when `trend_formula` is supplied. This breaks the assumption that the process has to be zero-centred, adding more modelling flexibility but also potentially inducing nonidentifiabilities with respect to any observation model intercepts. Thoughtful priors are a must for these models
* Added `standata.mvgam_prefit`, `stancode.mvgam` and `stancode.mvgam_prefit` methods for better alignment with 'brms' workflows
* Added 'gratia' to *Enhancements* to allow popular methods such as `draw()` to be used for 'mvgam' models if 'gratia' is already installed
* Added an `ensemble.mvgam_forecast` method to generate evenly weighted combinations of probabilistic forecast distributions
* Added an `irf.mvgam` method to compute Generalized and Orthogonalized Impulse Response Functions (IRFs) from models fit with Vector Autoregressive dynamics
* Added an `ensemble.mvgam_forecast()` method to generate evenly weighted combinations of probabilistic forecast distributions
* Added an `irf.mvgam()` method to compute Generalized and Orthogonalized Impulse Response Functions (IRFs) from models fit with Vector Autoregressive dynamics

## Deprecations
* The `drift` argument has been deprecated. It is now recommended for users to include parametric fixed effects of "time" in their respective GAM formulae to capture any expected drift effects
Expand Down
32 changes: 22 additions & 10 deletions R/forecast.mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,24 @@ forecast.mvgam = function(object,
type <- match.arg(arg = type, choices = c("link", "response",
"trend", "expected",
"detection", "latent_N"))

if(inherits(object, 'jsdgam')){
orig_trend_model <- attr(object$model_data, 'prepped_trend_model')
} else {
orig_trend_model <- object$trend_model
}

data_train <- validate_series_time(object$obs_data,
trend_model = object$trend_model)
trend_model = orig_trend_model)
n_series <- NCOL(object$ytimes)

# Check whether a forecast has already been computed
forecasts_exist <- FALSE
if(!is.null(object$test_data) && !missing(data_test)){
object$test_data <- validate_series_time(object$test_data,
trend_model = object$trend_model)
trend_model = orig_trend_model)
data_test <- validate_series_time(data_test,
trend_model = object$trend_model)
trend_model = orig_trend_model)
if(max(data_test$index..time..index) <=
max(object$test_data$index..time..index)){
forecasts_exist <- TRUE
Expand Down Expand Up @@ -126,7 +133,7 @@ forecast.mvgam = function(object,

if(is.null(object$test_data)){
data_test <- validate_series_time(data_test, name = 'newdata',
trend_model = object$trend_model)
trend_model = orig_trend_model)
data.frame(series = object$obs_data$series,
time = object$obs_data$time) %>%
dplyr::group_by(series) %>%
Expand Down Expand Up @@ -176,7 +183,7 @@ forecast.mvgam = function(object,
data_test$y <- rep(NA, NROW(data_test))
}
data_test <- validate_series_time(data_test, name = 'newdata',
trend_model = object$trend_model)
trend_model = orig_trend_model)
}

# Generate draw-specific forecasts
Expand All @@ -198,7 +205,7 @@ forecast.mvgam = function(object,

# Extract hindcasts
data_train <- validate_series_time(object$obs_data,
trend_model = object$trend_model)
trend_model = orig_trend_model)
ends <- seq(0, dim(mcmc_chains(object$model_output, 'ypred'))[2],
length.out = NCOL(object$ytimes) + 1)
starts <- ends + 1
Expand Down Expand Up @@ -330,12 +337,12 @@ forecast.mvgam = function(object,
} else {
# If forecasts already exist, simply extract them
data_test <- validate_series_time(object$test_data,
trend_model = object$trend_model)
trend_model = orig_trend_model)
last_train <- max(object$obs_data$index..time..index) -
(min(object$obs_data$index..time..index) - 1)

data_train <- validate_series_time(object$obs_data,
trend_model = object$trend_model)
trend_model = orig_trend_model)
ends <- seq(0, dim(mcmc_chains(object$model_output, 'ypred'))[2],
length.out = NCOL(object$ytimes) + 1)
starts <- ends + 1
Expand Down Expand Up @@ -593,8 +600,13 @@ forecast_draws = function(object,

# Check arguments
validate_pos_integer(n_cores)
if(inherits(object, 'jsdgam')){
orig_trend_model <- attr(object$model_data, 'prepped_trend_model')
} else {
orig_trend_model <- object$trend_model
}
data_test <- validate_series_time(data_test, name = 'newdata',
trend_model = object$trend_model)
trend_model = orig_trend_model)
data_test <- sort_data(data_test)
n_series <- NCOL(object$ytimes)
use_lv <- object$use_lv
Expand Down Expand Up @@ -695,7 +707,7 @@ forecast_draws = function(object,

# No need to compute in parallel if there was no trend model
nmix_notrend <- FALSE
if(!inherits(object$trend_model, 'mvgam_trend') &
if(!inherits(orig_trend_model, 'mvgam_trend') &
object$family == 'nmix'){
nmix_notrend <- TRUE
}
Expand Down
16 changes: 10 additions & 6 deletions R/index-mvgam.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,13 @@ variables.mvgam = function(x, ...){
# Linear predictor parameters
observation_linpreds <- data.frame(orig_name = parnames[grepl('mus[',
parnames,
fixed = TRUE)],
fixed = TRUE) &
!grepl('trend_mus[',
parnames,
fixed = TRUE)],
alias = NA)

if(!is.null(x$trend_call)){
if(!is.null(x$trend_call) & !inherits(x, 'jsdgam')){
trend_linpreds <- data.frame(orig_name = parnames[grepl('trend_mus[',
parnames,
fixed = TRUE)],
Expand All @@ -71,7 +74,7 @@ variables.mvgam = function(x, ...){
mgcv_names <- names(coef(x$mgcv_model))
observation_betas <- data.frame(orig_name = b_names, alias = mgcv_names)

if(!is.null(x$trend_call)){
if(!is.null(x$trend_call) & !inherits(x, 'jsdgam')){
b_names <- colnames(mcmc_chains(x$model_output, 'b_trend'))
mgcv_names <- gsub('series', 'trend',
paste0(names(coef(x$trend_mgcv_model)), '_trend'))
Expand All @@ -97,7 +100,7 @@ variables.mvgam = function(x, ...){
}

trend_re_params <- NULL
if(!is.null(x$trend_call)){
if(!is.null(x$trend_call) & !inherits(x, 'jsdgam')){
if(any(unlist(purrr::map(x$trend_mgcv_model$smooth, inherits, 'random.effect')))){
re_labs <- unlist(lapply(purrr::map(x$trend_mgcv_model$smooth, 'term'),
paste, collapse = ','))[
Expand Down Expand Up @@ -125,7 +128,7 @@ variables.mvgam = function(x, ...){
observation_smoothpars <- NULL
}

if(any(grepl('rho_trend[', parnames, fixed = TRUE))){
if(any(grepl('rho_trend[', parnames, fixed = TRUE)) & !inherits(x, 'jsdgam')){
trend_smoothpars <- data.frame(orig_name = parnames[grepl('rho_trend[',
parnames,
fixed = TRUE)],
Expand All @@ -136,7 +139,8 @@ variables.mvgam = function(x, ...){

# Trend state parameters
if(any(grepl('trend[', parnames, fixed = TRUE) &
!grepl('_trend[', parnames, fixed = TRUE))){
!grepl('_trend[', parnames, fixed = TRUE)) &
!inherits(x, 'jsdgam')){
trend_states <- grepl('trend[', parnames, fixed = TRUE) &
!grepl('_trend[', parnames, fixed = TRUE)
trends <- data.frame(orig_name = parnames[trend_states],
Expand Down
Loading

0 comments on commit 0a738fd

Please sign in to comment.