Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1. Adds map plot of model output to priors app #147

Merged
merged 19 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Clearer documentation on types expected by the `mesh_builder` tool - [PR #101](https://github.com/4DModeller/fdmr/pull/101)
- Checks on the data types being passed into the `mesh_builder` tool - [PR #101](https://github.com/4DModeller/fdmr/pull/101)
- Ability to plot either polygon or point data on Leaflet map of `mesh_builder` tool - [PR #101](https://github.com/4DModeller/fdmr/pull/101)
- The ability to plot model predictions on a `leaflet` map in the our [Interactive priors Shiny app](https://4dmodeller.github.io/fdmr/articles/priors_app.html) - [PR #147](https://github.com/4DModeller/fdmr/pull/147)

### Fixed

Expand Down
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ Imports:
shinybusy,
curl,
promises,
fmesher
fmesher,
purrr
Suggests:
bookdown,
knitr,
Expand All @@ -64,4 +65,3 @@ Suggests:
rcmdcheck
Config/testthat/edition: 3
VignetteBuilder: knitr

3 changes: 2 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# Generated by roxygen2: do not edit by hand

export(clear_caches)
export(create_prediction_field)
export(get_tmpdir)
export(get_tutorial_datapath)
export(interactive_priors)
export(latlong_to_utm)
export(load_tutorial_data)
export(mesh_builder)
export(mesh_checker)
export(mesh_to_spatial)
export(model_builder)
export(numbers_only)
export(parse_model_output)
export(plot_barchart)
Expand Down
57 changes: 56 additions & 1 deletion R/model_parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@ parse_model_output_bru <- function(model_output, measurement_data) {
fitted_mean_post <- model_output$summary.fitted.values$mean[seq_len(nrow(measurement_data))]
fitted_sd_post <- model_output$summary.fitted.values$sd[seq_len(nrow(measurement_data))]

random_effect_fields <- model_output$summary.random$f$mean
mean_post <- model_output$summary.random$f$mean
sd_post <- model_output$summary.random$f$sd
fixed_mean <- model_output$summary.fixed$mean

dic <- model_output$dic$dic
pars <- model_output$marginals.hyperpar

parsed_output <- list(
fitted_mean_post = fitted_mean_post,
fitted_sd_post = fitted_sd_post,
random_effect_fields = random_effect_fields,
mean_post = mean_post,
sd_post = sd_post,
fixed_mean = fixed_mean,
Expand All @@ -43,3 +44,57 @@ parse_model_output <- function(model_output, measurement_data, model_type = "inl
return(parse_model_output_bru(model_output = model_output, measurement_data = measurement_data))
}
}


#' Create a prediction field from the parsed model output and the mesh
#'
#' @param mesh INLA mesh
#' @param plot_type Type of plot to create, "predicted_mean_fields" etc
#' @param data_dist Type of data, "poisson" etc
#' @param var_a Data for variable a, required for "predicted_mean_fields" and "random_effect_fields"
#' @param var_b Data for variable b, required for "predicted_mean_fields"
#'
#' @return data.frame
#' @export
create_prediction_field <- function(mesh,
plot_type = "predicted_mean_fields",
data_dist = "poisson",
var_a = NULL,
var_b = NULL) {
valid_plots <- c("predicted_mean_fields", "random_effect_fields")
if (!(plot_type %in% valid_plots)) {
stop("Invalid plot type, select from ", valid_plots)
}

valid_data_dists <- c("poisson", "gaussian")
if (!(data_dist %in% valid_data_dists)) {
stop("Invalid data type, select from ", valid_data_dists)
}

if (plot_type == "predicted_mean_fields" && is.null(var_b)) {
stop("var_b must be provided for predicted_mean_fields plot")
}

mod_proj <- fmesher::fm_evaluator(mesh)
xy_grid <- base::expand.grid(mod_proj$x, mod_proj$y)
A_proj <- INLA::inla.spde.make.A(mesh = mesh, loc = as.matrix(xy_grid))

if (plot_type == "predicted_mean_fields") {
if (data_dist == "poisson") {
z <- base::exp(base::as.numeric(A_proj %*% var_a[1:mesh$n]) + base::sum(var_b))
} else {
z <- base::as.numeric(A_proj %*% var_a[1:mesh$n]) + base::sum(var_b)
}
} else {
# We get an error here as we only have 265 items
# z <- var_a[1:mesh$n]
z <- base::as.numeric(A_proj %*% var_a[1:mesh$n])
}

base::data.frame(x = xy_grid[, 1], y = xy_grid[, 2], z = z)
}


create_raster <- function(dataframe, crs) {
raster::rasterFromXYZ(dataframe, crs = crs)
}
172 changes: 149 additions & 23 deletions R/shiny_priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,33 @@
#'
#' @return shiny::app
#' @keywords internal
priors_shiny <- function(spatial_data,
measurement_data,
time_variable,
mesh,
log_folder = NULL) {
model_builder_shiny <- function(spatial_data,
measurement_data,
time_variable,
mesh,
log_folder = NULL) {
future::plan(future::multisession())

got_coords <- has_coords(spatial_data = spatial_data)
if (!got_coords) {
stop("Please make sure you have set coordinates on spatial_data using sp::coordinates.")
}

spatial_crs <- sp::proj4string(spatial_data)
mesh_crs <- mesh$crs$input

if (is.na(mesh_crs) && is.na(spatial_crs)) {
warning("Cannot read CRS from mesh or spatial_data, using default CRS = +proj=longlat +datum=WGS84")
crs <- "+proj=longlat +datum=WGS84"
} else if (is.na(mesh_crs)) {
crs <- spatial_crs
} else {
crs <- mesh_crs
}

brewer_palettes <- RColorBrewer::brewer.pal.info
default_colours <- rownames(brewer_palettes[brewer_palettes$cat == "seq", ])

# Text for priors help
prior_range_text <- "A length 2 vector, with (range0, Prange) specifying that P(ρ < ρ_0)=p_ρ,
where ρ is the spatial range of the random field."
Expand Down Expand Up @@ -128,8 +143,17 @@ priors_shiny <- function(spatial_data,
type = "tabs",
shiny::tabPanel(
"Features",
shiny::selectInput(inputId = "model_var", label = "Model variable", choices = features),
shiny::selectInput(inputId = "exposure_param", label = "Exposure param", choices = features),
shiny::fluidRow(
shiny::column(
6,
shiny::selectInput(inputId = "model_var", label = "Model variable", choices = features),
shiny::selectInput(inputId = "exposure_param", label = "Exposure (time variable)", choices = features),
),
shiny::column(
6,
shiny::selectInput(inputId = "data_dist", label = "Data distribution", choices = c("Poisson", "Gaussian")),
)
),
shiny::checkboxGroupInput(inputId = "features", label = "Features", choices = features),
shiny::checkboxInput(inputId = "f_func", label = "Add f()", value = FALSE),
shiny::actionButton(inputId = "clear", label = "Clear"),
Expand Down Expand Up @@ -157,9 +181,34 @@ priors_shiny <- function(spatial_data,
shiny::selectInput(inputId = "plot_type", label = "Plot type:", choices = plot_choices, selected = plot_choices[1]),
shiny::plotOutput(outputId = "plot_model_out")
),
shiny::tabPanel(
"Map",
shiny::fluidRow(
shiny::column(
6,
shiny::selectInput(inputId = "map_plot_type", label = "Plot type", choices = c("Predicted mean fields", "Random effect fields"), selected = "Predicted mean fields"),
shiny::selectInput(inputId = "select_run_map", label = "Select run:", choices = c())
),
shiny::column(
6,
shiny::selectInput(
inputId = "colour_category",
label = "Palette type",
choices = c("Sequential", "Diverging", "Qualitative", "Viridis"),
selected = "Viridis"
),
shiny::selectInput(
inputId = "colour_scheme",
label = "Colour Scheme",
choices = default_colours,
),
)
),
leaflet::leafletOutput(outputId = "map_out")
),
shiny::tabPanel(
"Code",
shiny::selectInput(inputId = "select_run", label = "Select run:", choices = c()),
shiny::selectInput(inputId = "select_run_code", label = "Select run:", choices = c()),
shiny::verbatimTextOutput(outputId = "code_out")
),
shiny::tabPanel(
Expand Down Expand Up @@ -201,11 +250,9 @@ priors_shiny <- function(spatial_data,
})

shiny::observe({
shiny::updateSelectInput(session = session, inputId = "select_run", choices = run_names())
})

shiny::observeEvent(input$features, {
print(paste0("You have chosen: ", input$features))
shiny::updateSelectInput(session = session, inputId = "select_run_map", choices = run_names())
shiny::updateSelectInput(session = session, inputId = "select_run_code", choices = run_names())
shiny::updateSelectInput(session, inputId = "colour_scheme", label = "Colours", choices = category_colours())
})

shiny::observeEvent(input$clear, {
Expand Down Expand Up @@ -273,20 +320,30 @@ priors_shiny <- function(spatial_data,
formula_str()
})

data_distribution <- shiny::reactive({
tolower(input$data_dist)
})

shiny::observeEvent(input$run_model, ignoreNULL = TRUE, {
exposure_param_local <- input$exposure_param
formula_local <- inla_formula()
measurement_data_local <- measurement_data

data_dist_local <- data_distribution()
family_control <- NULL
if (data_dist_local == "poisson") {
family_control <- list(link = "log")
}

promise <- promises::future_promise(
{
# Without loading INLA here we get errors
require("INLA")
inlabru::bru(formula_local,
data = measurement_data_local,
family = "poisson",
family = data_dist_local,
E = measurement_data_local[[exposure_param_local]],
control.family = list(link = "log"),
control.family = family_control,
options = list(
verbose = FALSE
)
Expand All @@ -300,8 +357,10 @@ priors_shiny <- function(spatial_data,
function(model_output) {
# Run the model
run_no(run_no() + 1)
model_vals$model_outputs[[run_no()]] <- model_output
model_vals$parsed_outputs[[run_no()]] <- parse_model_output(
run_label <- paste0("Run-", run_no())

model_vals$model_outputs[[run_label]] <- model_output
model_vals$parsed_outputs[[run_label]] <- parse_model_output(
model_output = model_output,
measurement_data = measurement_data
)
Expand All @@ -316,7 +375,6 @@ priors_shiny <- function(spatial_data,
"pg_ar1" = input$pg_ar1
)

run_label <- paste0("Run-", run_no())
model_vals$run_params[[run_label]] <- run_params

if (write_logs) {
Expand Down Expand Up @@ -355,6 +413,69 @@ priors_shiny <- function(spatial_data,
rownames = TRUE
)

category_colours <- shiny::reactive({
if (input$colour_category == "Viridis") {
colours <- c("viridis", "magma", "inferno", "plasma")
} else {
palettes_mapping <- list("Sequential" = "seq", "Diverging" = "div", "Qualitative" = "qual")
chosen_cat <- palettes_mapping[input$colour_category]
colours <- rownames(subset(RColorBrewer::brewer.pal.info, category %in% chosen_cat))
}
colours
})

colour_scheme <- shiny::reactive({
input$colour_scheme
})


prediction_field <- shiny::reactive({
if (length(model_vals$parsed_outputs) == 0) {
return()
}

data <- model_vals$parsed_outputs[[input$select_run_map]]
if (input$map_plot_type == "Predicted mean fields") {
create_prediction_field(
mesh = mesh,
plot_type = "predicted_mean_fields",
data_dist = data_distribution(),
var_a = data[["mean_post"]],
var_b = data[["fixed_mean"]]
)
} else {
create_prediction_field(
mesh = mesh,
plot_type = "random_effect_fields",
data_dist = data_distribution(),
var_a = data[["mean_post"]]
)
}
})

z_values <- shiny::reactive({
prediction_field()[["z"]]
})

map_raster <- shiny::reactive({
raster::rasterFromXYZ(prediction_field(), crs = crs)
})

map_colours <- shiny::reactive({
leaflet::colorNumeric(palette = colour_scheme(), domain = z_values(), reverse = FALSE)
})

output$map_out <- leaflet::renderLeaflet({
if (is.null(map_raster())) {
return()
}

leaflet::leaflet() %>%
leaflet::addTiles(group = "OSM") %>%
leaflet::addRasterImage(map_raster(), colors = map_colours(), opacity = 0.9, group = "Raster") %>%
leaflet::addLegend(position = "topright", pal = map_colours(), values = z_values())
})

model_plot <- shiny::eventReactive(input$plot_type, ignoreNULL = FALSE, {
if (length(model_vals$parsed_outputs) == 0) {
return()
Expand Down Expand Up @@ -401,7 +522,12 @@ priors_shiny <- function(spatial_data,
return()
}

params <- model_vals$run_params[[input$select_run]]
params <- model_vals$run_params[[input$select_run_code]]

family_control_str <- "NULL"
if (data_distribution() == "poisson") {
family_control_str <- "list(link = 'log'),"
}

paste0(
"spde <- INLA::inla.spde2.pcmatern(
Expand All @@ -416,9 +542,9 @@ priors_shiny <- function(spatial_data,
)", "\n\n",
paste0("model_output <- inlabru::bru(formula,
data = measurement_data,
family = 'poisson',
family = '", data_distribution(), "',
E = measurement_data[[", input$exposure_param, "]],
control.family = list(link = 'log'),
control.family = ", family_control_str, "
options = list(
verbose = FALSE
)
Expand All @@ -441,6 +567,6 @@ priors_shiny <- function(spatial_data,
#'
#' @return shiny::app
#' @export
interactive_priors <- function(spatial_data, measurement_data, time_variable, mesh, log_folder = NULL) {
shiny::runApp(priors_shiny(spatial_data = spatial_data, measurement_data = measurement_data, time_variable = time_variable, mesh = mesh, log_folder = log_folder))
model_builder <- function(spatial_data, measurement_data, time_variable, mesh, log_folder = NULL) {
shiny::runApp(model_builder_shiny(spatial_data = spatial_data, measurement_data = measurement_data, time_variable = time_variable, mesh = mesh, log_folder = log_folder))
}
2 changes: 1 addition & 1 deletion _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ reference:
contents:
- plot_interactive_map
- mesh_builder
- interactive_priors
- model_builder
- title: Parsing model output
desc: Functions to help parse model output and extract useful information
contents:
Expand Down
Loading
Loading