Skip to content

Commit

Permalink
Merge pull request #147 from 4DModeller/modelParse
Browse files Browse the repository at this point in the history
1. Adds map plot of model output to priors app
  • Loading branch information
gareth-j authored Oct 9, 2023
2 parents f33dd74 + 0d3c936 commit b4300b2
Show file tree
Hide file tree
Showing 12 changed files with 251 additions and 49 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Clearer documentation on types expected by the `mesh_builder` tool - [PR #101](https://github.com/4DModeller/fdmr/pull/101)
- Checks on the data types being passed into the `mesh_builder` tool - [PR #101](https://github.com/4DModeller/fdmr/pull/101)
- Ability to plot either polygon or point data on Leaflet map of `mesh_builder` tool - [PR #101](https://github.com/4DModeller/fdmr/pull/101)
- The ability to plot model predictions on a `leaflet` map in the our [Interactive priors Shiny app](https://4dmodeller.github.io/fdmr/articles/priors_app.html) - [PR #147](https://github.com/4DModeller/fdmr/pull/147)

### Fixed

Expand Down
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ Imports:
shinybusy,
curl,
promises,
fmesher
fmesher,
purrr
Suggests:
bookdown,
knitr,
Expand All @@ -64,4 +65,3 @@ Suggests:
rcmdcheck
Config/testthat/edition: 3
VignetteBuilder: knitr

3 changes: 2 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# Generated by roxygen2: do not edit by hand

export(clear_caches)
export(create_prediction_field)
export(get_tmpdir)
export(get_tutorial_datapath)
export(interactive_priors)
export(latlong_to_utm)
export(load_tutorial_data)
export(mesh_builder)
export(mesh_checker)
export(mesh_to_spatial)
export(model_builder)
export(numbers_only)
export(parse_model_output)
export(plot_barchart)
Expand Down
57 changes: 56 additions & 1 deletion R/model_parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@ parse_model_output_bru <- function(model_output, measurement_data) {
fitted_mean_post <- model_output$summary.fitted.values$mean[seq_len(nrow(measurement_data))]
fitted_sd_post <- model_output$summary.fitted.values$sd[seq_len(nrow(measurement_data))]

random_effect_fields <- model_output$summary.random$f$mean
mean_post <- model_output$summary.random$f$mean
sd_post <- model_output$summary.random$f$sd
fixed_mean <- model_output$summary.fixed$mean

dic <- model_output$dic$dic
pars <- model_output$marginals.hyperpar

parsed_output <- list(
fitted_mean_post = fitted_mean_post,
fitted_sd_post = fitted_sd_post,
random_effect_fields = random_effect_fields,
mean_post = mean_post,
sd_post = sd_post,
fixed_mean = fixed_mean,
Expand All @@ -43,3 +44,57 @@ parse_model_output <- function(model_output, measurement_data, model_type = "inl
return(parse_model_output_bru(model_output = model_output, measurement_data = measurement_data))
}
}


#' Create a prediction field from the parsed model output and the mesh
#'
#' @param mesh INLA mesh
#' @param plot_type Type of plot to create, "predicted_mean_fields" etc
#' @param data_dist Type of data, "poisson" etc
#' @param var_a Data for variable a, required for "predicted_mean_fields" and "random_effect_fields"
#' @param var_b Data for variable b, required for "predicted_mean_fields"
#'
#' @return data.frame
#' @export
create_prediction_field <- function(mesh,
plot_type = "predicted_mean_fields",
data_dist = "poisson",
var_a = NULL,
var_b = NULL) {
valid_plots <- c("predicted_mean_fields", "random_effect_fields")
if (!(plot_type %in% valid_plots)) {
stop("Invalid plot type, select from ", valid_plots)
}

valid_data_dists <- c("poisson", "gaussian")
if (!(data_dist %in% valid_data_dists)) {
stop("Invalid data type, select from ", valid_data_dists)
}

if (plot_type == "predicted_mean_fields" && is.null(var_b)) {
stop("var_b must be provided for predicted_mean_fields plot")
}

mod_proj <- fmesher::fm_evaluator(mesh)
xy_grid <- base::expand.grid(mod_proj$x, mod_proj$y)
A_proj <- INLA::inla.spde.make.A(mesh = mesh, loc = as.matrix(xy_grid))

if (plot_type == "predicted_mean_fields") {
if (data_dist == "poisson") {
z <- base::exp(base::as.numeric(A_proj %*% var_a[1:mesh$n]) + base::sum(var_b))
} else {
z <- base::as.numeric(A_proj %*% var_a[1:mesh$n]) + base::sum(var_b)
}
} else {
# We get an error here as we only have 265 items
# z <- var_a[1:mesh$n]
z <- base::as.numeric(A_proj %*% var_a[1:mesh$n])
}

base::data.frame(x = xy_grid[, 1], y = xy_grid[, 2], z = z)
}


create_raster <- function(dataframe, crs) {
raster::rasterFromXYZ(dataframe, crs = crs)
}
172 changes: 149 additions & 23 deletions R/shiny_priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,33 @@
#'
#' @return shiny::app
#' @keywords internal
priors_shiny <- function(spatial_data,
measurement_data,
time_variable,
mesh,
log_folder = NULL) {
model_builder_shiny <- function(spatial_data,
measurement_data,
time_variable,
mesh,
log_folder = NULL) {
future::plan(future::multisession())

got_coords <- has_coords(spatial_data = spatial_data)
if (!got_coords) {
stop("Please make sure you have set coordinates on spatial_data using sp::coordinates.")
}

spatial_crs <- sp::proj4string(spatial_data)
mesh_crs <- mesh$crs$input

if (is.na(mesh_crs) && is.na(spatial_crs)) {
warning("Cannot read CRS from mesh or spatial_data, using default CRS = +proj=longlat +datum=WGS84")
crs <- "+proj=longlat +datum=WGS84"
} else if (is.na(mesh_crs)) {
crs <- spatial_crs
} else {
crs <- mesh_crs
}

brewer_palettes <- RColorBrewer::brewer.pal.info
default_colours <- rownames(brewer_palettes[brewer_palettes$cat == "seq", ])

# Text for priors help
prior_range_text <- "A length 2 vector, with (range0, Prange) specifying that P(ρ < ρ_0)=p_ρ,
where ρ is the spatial range of the random field."
Expand Down Expand Up @@ -128,8 +143,17 @@ priors_shiny <- function(spatial_data,
type = "tabs",
shiny::tabPanel(
"Features",
shiny::selectInput(inputId = "model_var", label = "Model variable", choices = features),
shiny::selectInput(inputId = "exposure_param", label = "Exposure param", choices = features),
shiny::fluidRow(
shiny::column(
6,
shiny::selectInput(inputId = "model_var", label = "Model variable", choices = features),
shiny::selectInput(inputId = "exposure_param", label = "Exposure (time variable)", choices = features),
),
shiny::column(
6,
shiny::selectInput(inputId = "data_dist", label = "Data distribution", choices = c("Poisson", "Gaussian")),
)
),
shiny::checkboxGroupInput(inputId = "features", label = "Features", choices = features),
shiny::checkboxInput(inputId = "f_func", label = "Add f()", value = FALSE),
shiny::actionButton(inputId = "clear", label = "Clear"),
Expand Down Expand Up @@ -157,9 +181,34 @@ priors_shiny <- function(spatial_data,
shiny::selectInput(inputId = "plot_type", label = "Plot type:", choices = plot_choices, selected = plot_choices[1]),
shiny::plotOutput(outputId = "plot_model_out")
),
shiny::tabPanel(
"Map",
shiny::fluidRow(
shiny::column(
6,
shiny::selectInput(inputId = "map_plot_type", label = "Plot type", choices = c("Predicted mean fields", "Random effect fields"), selected = "Predicted mean fields"),
shiny::selectInput(inputId = "select_run_map", label = "Select run:", choices = c())
),
shiny::column(
6,
shiny::selectInput(
inputId = "colour_category",
label = "Palette type",
choices = c("Sequential", "Diverging", "Qualitative", "Viridis"),
selected = "Viridis"
),
shiny::selectInput(
inputId = "colour_scheme",
label = "Colour Scheme",
choices = default_colours,
),
)
),
leaflet::leafletOutput(outputId = "map_out")
),
shiny::tabPanel(
"Code",
shiny::selectInput(inputId = "select_run", label = "Select run:", choices = c()),
shiny::selectInput(inputId = "select_run_code", label = "Select run:", choices = c()),
shiny::verbatimTextOutput(outputId = "code_out")
),
shiny::tabPanel(
Expand Down Expand Up @@ -201,11 +250,9 @@ priors_shiny <- function(spatial_data,
})

shiny::observe({
shiny::updateSelectInput(session = session, inputId = "select_run", choices = run_names())
})

shiny::observeEvent(input$features, {
print(paste0("You have chosen: ", input$features))
shiny::updateSelectInput(session = session, inputId = "select_run_map", choices = run_names())
shiny::updateSelectInput(session = session, inputId = "select_run_code", choices = run_names())
shiny::updateSelectInput(session, inputId = "colour_scheme", label = "Colours", choices = category_colours())
})

shiny::observeEvent(input$clear, {
Expand Down Expand Up @@ -273,20 +320,30 @@ priors_shiny <- function(spatial_data,
formula_str()
})

data_distribution <- shiny::reactive({
tolower(input$data_dist)
})

shiny::observeEvent(input$run_model, ignoreNULL = TRUE, {
exposure_param_local <- input$exposure_param
formula_local <- inla_formula()
measurement_data_local <- measurement_data

data_dist_local <- data_distribution()
family_control <- NULL
if (data_dist_local == "poisson") {
family_control <- list(link = "log")
}

promise <- promises::future_promise(
{
# Without loading INLA here we get errors
require("INLA")
inlabru::bru(formula_local,
data = measurement_data_local,
family = "poisson",
family = data_dist_local,
E = measurement_data_local[[exposure_param_local]],
control.family = list(link = "log"),
control.family = family_control,
options = list(
verbose = FALSE
)
Expand All @@ -300,8 +357,10 @@ priors_shiny <- function(spatial_data,
function(model_output) {
# Run the model
run_no(run_no() + 1)
model_vals$model_outputs[[run_no()]] <- model_output
model_vals$parsed_outputs[[run_no()]] <- parse_model_output(
run_label <- paste0("Run-", run_no())

model_vals$model_outputs[[run_label]] <- model_output
model_vals$parsed_outputs[[run_label]] <- parse_model_output(
model_output = model_output,
measurement_data = measurement_data
)
Expand All @@ -316,7 +375,6 @@ priors_shiny <- function(spatial_data,
"pg_ar1" = input$pg_ar1
)

run_label <- paste0("Run-", run_no())
model_vals$run_params[[run_label]] <- run_params

if (write_logs) {
Expand Down Expand Up @@ -355,6 +413,69 @@ priors_shiny <- function(spatial_data,
rownames = TRUE
)

category_colours <- shiny::reactive({
if (input$colour_category == "Viridis") {
colours <- c("viridis", "magma", "inferno", "plasma")
} else {
palettes_mapping <- list("Sequential" = "seq", "Diverging" = "div", "Qualitative" = "qual")
chosen_cat <- palettes_mapping[input$colour_category]
colours <- rownames(subset(RColorBrewer::brewer.pal.info, category %in% chosen_cat))
}
colours
})

colour_scheme <- shiny::reactive({
input$colour_scheme
})


prediction_field <- shiny::reactive({
if (length(model_vals$parsed_outputs) == 0) {
return()
}

data <- model_vals$parsed_outputs[[input$select_run_map]]
if (input$map_plot_type == "Predicted mean fields") {
create_prediction_field(
mesh = mesh,
plot_type = "predicted_mean_fields",
data_dist = data_distribution(),
var_a = data[["mean_post"]],
var_b = data[["fixed_mean"]]
)
} else {
create_prediction_field(
mesh = mesh,
plot_type = "random_effect_fields",
data_dist = data_distribution(),
var_a = data[["mean_post"]]
)
}
})

z_values <- shiny::reactive({
prediction_field()[["z"]]
})

map_raster <- shiny::reactive({
raster::rasterFromXYZ(prediction_field(), crs = crs)
})

map_colours <- shiny::reactive({
leaflet::colorNumeric(palette = colour_scheme(), domain = z_values(), reverse = FALSE)
})

output$map_out <- leaflet::renderLeaflet({
if (is.null(map_raster())) {
return()
}

leaflet::leaflet() %>%
leaflet::addTiles(group = "OSM") %>%
leaflet::addRasterImage(map_raster(), colors = map_colours(), opacity = 0.9, group = "Raster") %>%
leaflet::addLegend(position = "topright", pal = map_colours(), values = z_values())
})

model_plot <- shiny::eventReactive(input$plot_type, ignoreNULL = FALSE, {
if (length(model_vals$parsed_outputs) == 0) {
return()
Expand Down Expand Up @@ -401,7 +522,12 @@ priors_shiny <- function(spatial_data,
return()
}

params <- model_vals$run_params[[input$select_run]]
params <- model_vals$run_params[[input$select_run_code]]

family_control_str <- "NULL"
if (data_distribution() == "poisson") {
family_control_str <- "list(link = 'log'),"
}

paste0(
"spde <- INLA::inla.spde2.pcmatern(
Expand All @@ -416,9 +542,9 @@ priors_shiny <- function(spatial_data,
)", "\n\n",
paste0("model_output <- inlabru::bru(formula,
data = measurement_data,
family = 'poisson',
family = '", data_distribution(), "',
E = measurement_data[[", input$exposure_param, "]],
control.family = list(link = 'log'),
control.family = ", family_control_str, "
options = list(
verbose = FALSE
)
Expand All @@ -441,6 +567,6 @@ priors_shiny <- function(spatial_data,
#'
#' @return shiny::app
#' @export
interactive_priors <- function(spatial_data, measurement_data, time_variable, mesh, log_folder = NULL) {
shiny::runApp(priors_shiny(spatial_data = spatial_data, measurement_data = measurement_data, time_variable = time_variable, mesh = mesh, log_folder = log_folder))
model_builder <- function(spatial_data, measurement_data, time_variable, mesh, log_folder = NULL) {
shiny::runApp(model_builder_shiny(spatial_data = spatial_data, measurement_data = measurement_data, time_variable = time_variable, mesh = mesh, log_folder = log_folder))
}
2 changes: 1 addition & 1 deletion _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ reference:
contents:
- plot_interactive_map
- mesh_builder
- interactive_priors
- model_builder
- title: Parsing model output
desc: Functions to help parse model output and extract useful information
contents:
Expand Down
Loading

0 comments on commit b4300b2

Please sign in to comment.