Skip to content

Commit

Permalink
Merge pull request #158 from 4DModeller/Iss111/ModelOutViz
Browse files Browse the repository at this point in the history
2. Adding `model_viewer` Shiny app
  • Loading branch information
gareth-j authored Oct 9, 2023
2 parents 66fb24b + 2757208 commit 35e30b9
Show file tree
Hide file tree
Showing 18 changed files with 393 additions and 50 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Checks on the data types being passed into the `mesh_builder` tool - [PR #101](https://github.com/4DModeller/fdmr/pull/101)
- Ability to plot either polygon or point data on Leaflet map of `mesh_builder` tool - [PR #101](https://github.com/4DModeller/fdmr/pull/101)
- The ability to plot model predictions on a `leaflet` map in the our [Interactive priors Shiny app](https://4dmodeller.github.io/fdmr/articles/priors_app.html) - [PR #147](https://github.com/4DModeller/fdmr/pull/147)
- A new Shiny app to parse and plot INLA model output, letting users easily view model parameters and predictions on a map - [PR #158](https://github.com/4DModeller/fdmr/pull/158)

### Fixed

Expand Down
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

export(clear_caches)
export(create_prediction_field)
export(get_tmpdir)
export(get_tutorial_datapath)
export(latlong_to_utm)
export(load_tutorial_data)
export(mesh_builder)
export(mesh_checker)
export(mesh_to_spatial)
export(model_builder)
export(model_viewer)
export(numbers_only)
export(parse_model_output)
export(plot_barchart)
Expand Down
13 changes: 3 additions & 10 deletions R/model_parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ create_prediction_field <- function(mesh,
stop("Invalid plot type, select from ", valid_plots)
}

data_dist <- tolower(data_dist)
valid_data_dists <- c("poisson", "gaussian")
if (!(data_dist %in% valid_data_dists)) {
stop("Invalid data type, select from ", valid_data_dists)
Expand All @@ -80,11 +81,8 @@ create_prediction_field <- function(mesh,
A_proj <- INLA::inla.spde.make.A(mesh = mesh, loc = as.matrix(xy_grid))

if (plot_type == "predicted_mean_fields") {
if (data_dist == "poisson") {
z <- base::exp(base::as.numeric(A_proj %*% var_a[1:mesh$n]) + base::sum(var_b))
} else {
z <- base::as.numeric(A_proj %*% var_a[1:mesh$n]) + base::sum(var_b)
}
z <- base::as.numeric(A_proj %*% var_a[1:mesh$n]) + base::sum(var_b)
if (data_dist == "poisson") z <- base::exp(z)
} else {
# We get an error here as we only have 265 items
# z <- var_a[1:mesh$n]
Expand All @@ -93,8 +91,3 @@ create_prediction_field <- function(mesh,

base::data.frame(x = xy_grid[, 1], y = xy_grid[, 2], z = z)
}


create_raster <- function(dataframe, crs) {
raster::rasterFromXYZ(dataframe, crs = crs)
}
8 changes: 5 additions & 3 deletions R/shiny_mapper.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ raster_mapping_app <- function(raster_data = NULL, polygon_data = NULL, date_for
inputId = "colour_scheme",
label = "Color Scheme",
choices = default_colours,
selected = "viridis"
),
shiny::sliderInput(
inputId = "raster_opacity",
Expand Down Expand Up @@ -137,6 +138,10 @@ raster_mapping_app <- function(raster_data = NULL, polygon_data = NULL, date_for
colours
})

colour_scheme <- shiny::reactive({
input$colour_scheme
})

raster_opacity <- shiny::reactive({
input$raster_opacity
})
Expand All @@ -145,9 +150,6 @@ raster_mapping_app <- function(raster_data = NULL, polygon_data = NULL, date_for
input$polygon_opacity
})

colour_scheme <- shiny::reactive({
input$colour_scheme
})

colour_palette <- shiny::reactive({
if (is.null(palette)) {
Expand Down
197 changes: 197 additions & 0 deletions R/shiny_modelviewer.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
#' Parse inlabru model output
#'
#' @param model_output INLA model output
#' @param mesh INLA mesh
#' @param measurement_data Measurement data
#' @param data_distribution Type of data, Poisson or Gaussian
#'
#' @importFrom magrittr %>%
#'
#' @return shiny::app
#' @keywords internal
model_viewer_shiny <- function(model_output, mesh, measurement_data, data_distribution) {
busy_spinner <- get_busy_spinner()

crs <- mesh$crs$input
if (is.null(crs) || is.na(crs)) {
warning("Cannot read CRS from mesh, using default CRS = +proj=longlat +datum=WGS84")
crs <- "+proj=longlat +datum=WGS84"
}

data_distribution <- stringr::str_to_title(data_distribution)
if (!(data_distribution %in% c("Poisson", "Gaussian"))) {
stop("data_distribution must be one of Poisson or Gaussian")
}

parsed_model_output <- parse_model_output(
model_output = model_output,
measurement_data = measurement_data
)

# The comparison plotting functions expect a list of lists
parsed_modeloutput_plots <- list(parsed_model_output)

brewer_palettes <- RColorBrewer::brewer.pal.info
default_colours <- rownames(brewer_palettes[brewer_palettes$cat == "seq", ])

plot_choices <- c("Range", "Stdev", "AR(1)", "Boxplot", "Density", "DIC")

ui <- shiny::fluidPage(
busy_spinner,
shiny::headerPanel(title = "Model viewer"),
shiny::tabsetPanel(
type = "tabs",
shiny::tabPanel(
"Plots",
shiny::h2("Plot output"),
shiny::selectInput(inputId = "plot_type", label = "Plot type:", choices = plot_choices, selected = plot_choices[1]),
shiny::plotOutput(outputId = "plot_model_out")
),
shiny::tabPanel(
"Map",
shiny::fluidRow(
shiny::column(
6,
shiny::selectInput(inputId = "map_plot_type", label = "Plot type", choices = c("Predicted mean fields", "Random effect fields"), selected = "Predicted mean fields"),
shiny::selectInput(inputId = "map_data_type", label = "Data type", choices = c("Poisson", "Gaussian"), selected = data_distribution),
),
shiny::column(
6,
shiny::selectInput(
inputId = "colour_category",
label = "Palette type",
choices = c("Sequential", "Diverging", "Qualitative", "Viridis"),
selected = "Viridis"
),
shiny::selectInput(
inputId = "colour_scheme",
label = "Color Scheme",
choices = default_colours,
),
)
),
leaflet::leafletOutput(outputId = "map_out")
),
shiny::tabPanel(
"Help",
shiny::h3("Help"),
)
)
)

# Define server logic required to draw a histogram
server <- function(input, output, session) {
category_colours <- shiny::reactive({
if (input$colour_category == "Viridis") {
colours <- c("viridis", "magma", "inferno", "plasma")
} else {
palettes_mapping <- list("Sequential" = "seq", "Diverging" = "div", "Qualitative" = "qual")
chosen_cat <- palettes_mapping[input$colour_category]
colours <- rownames(subset(RColorBrewer::brewer.pal.info, category %in% chosen_cat))
}
colours
})

colour_scheme <- shiny::reactive({
input$colour_scheme
})

shiny::observe({
shiny::updateSelectInput(session, inputId = "colour_scheme", label = "Colours", choices = category_colours())
})


prediction_field <- shiny::reactive({
data_dist <- tolower(input$map_data_type)
if (input$map_plot_type == "Predicted mean fields") {
create_prediction_field(
mesh = mesh,
plot_type = "predicted_mean_fields",
data_dist = data_dist,
var_a = parsed_model_output[["mean_post"]],
var_b = parsed_model_output[["fixed_mean"]]
)
} else {
create_prediction_field(
mesh = mesh,
plot_type = "random_effect_fields",
data_dist = data_dist,
var_a = parsed_model_output[["mean_post"]]
)
}
})

z_values <- shiny::reactive({
prediction_field()[["z"]]
})

map_raster <- shiny::reactive({
raster::rasterFromXYZ(prediction_field(), crs = crs)
})

map_colours <- shiny::reactive({
leaflet::colorNumeric(palette = colour_scheme(), domain = z_values(), reverse = FALSE)
})

output$map_out <- leaflet::renderLeaflet({
if (is.null(map_raster())) {
return()
}

leaflet::leaflet() %>%
leaflet::addTiles(group = "OSM") %>%
leaflet::addRasterImage(map_raster(), colors = map_colours(), opacity = 0.9, group = "Raster") %>%
leaflet::addLegend(position = "topright", pal = map_colours(), values = z_values())
})

model_plot <- shiny::eventReactive(input$plot_type, ignoreNULL = FALSE, {
if (input$plot_type == "Range") {
return(plot_line_comparison(
data = parsed_modeloutput_plots,
to_plot = "Range for f",
title = "Range"
))
} else if (input$plot_type == "Stdev") {
return(plot_line_comparison(
data = parsed_modeloutput_plots,
to_plot = "Stdev for f",
title = "Marginal standard deviation"
))
} else if (input$plot_type == "AR(1)") {
return(plot_line_comparison(
data = parsed_modeloutput_plots,
to_plot = "GroupRho for f",
title = "AR(1)"
))
} else if (input$plot_type == "Boxplot") {
return(plot_priors_boxplot(data = parsed_modeloutput_plots))
} else if (input$plot_type == "Density") {
return(plot_priors_density(
data = parsed_modeloutput_plots,
measurement_data = measurement_data
))
} else if (input$plot_type == "DIC") {
return(plot_dic(data = parsed_modeloutput_plots))
}
})

output$plot_model_out <- shiny::renderPlot({
model_plot()
})
}

shiny::shinyApp(ui = ui, server = server)
}

#' Mesh building shiny app. Creates and visualises a mesh from some spatial data.
#'
#' @param model_output INLA model output
#' @param mesh INLA mesh
#' @param measurement_data Measurement data
#' @param data_distribution Type of data, Poisson or Gaussian
#'
#' @return shiny::app
#' @export
model_viewer <- function(model_output, mesh, measurement_data, data_distribution = "Poisson") {
shiny::runApp(model_viewer_shiny(model_output = model_output, mesh = mesh, measurement_data = measurement_data, data_distribution = data_distribution))
}
Loading

0 comments on commit 35e30b9

Please sign in to comment.