Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion R/available_forecasts.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ get_forecast_counts <- function(data,
collapse = c("quantile", "sample_id")) {

data <- as_forecast(data)
forecast_unit <- attr(data, "forecast_unit")
forecast_unit <- get_forecast_unit(data)
data <- na.omit(data)

if (is.null(by)) {
Expand Down
15 changes: 1 addition & 14 deletions R/get_-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -152,28 +152,15 @@ get_metrics <- function(scores) {
#' [get_protected_columns()] as well as the names of the metrics that were
#' specified during scoring, if any.
#' @inheritParams validate_forecast
#' @param check_conflict Whether or not to check whether there is a conflict
#' between a stored attribute and the inferred forecast unit. When you create
#' a forecast object, the forecast unit is stored as an attribute. If you
#' later change the columns of the data, the forecast unit as inferred from the
#' data might change compared to the stored attribute. Should this result in a
#' warning? Defaults to FALSE.
#' @return A character vector with the column names that define the unit of
#' a single forecast
#' @export
#' @keywords check-forecasts
get_forecast_unit <- function(data, check_conflict = FALSE) {
get_forecast_unit <- function(data) {
# check whether there is a conflict in the forecast_unit and if so warn
protected_columns <- get_protected_columns(data)
protected_columns <- c(protected_columns, attr(data, "metric_names"))

forecast_unit <- setdiff(colnames(data), unique(protected_columns))

conflict <- check_attribute_conflict(data, "forecast_unit", forecast_unit)
if (check_conflict && !is.logical(conflict)) {
warning(conflict)
}

return(forecast_unit)
}

Expand Down
10 changes: 7 additions & 3 deletions R/score.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ score.default <- function(data, ...) {
}

#' @importFrom stats na.omit
#' @importFrom data.table setattr
#' @rdname score
#' @export
score.forecast_binary <- function(data, metrics = rules_binary(), ...) {
Expand All @@ -97,6 +98,7 @@ score.forecast_binary <- function(data, metrics = rules_binary(), ...) {

#' @importFrom Metrics se ae ape
#' @importFrom stats na.omit
#' @importFrom data.table setattr
#' @rdname score
#' @export
score.forecast_point <- function(data, metrics = rules_point(), ...) {
Expand All @@ -115,12 +117,13 @@ score.forecast_point <- function(data, metrics = rules_point(), ...) {
}

#' @importFrom stats na.omit
#' @importFrom data.table setattr
#' @rdname score
#' @export
score.forecast_sample <- function(data, metrics = rules_sample(), ...) {
data <- validate_forecast(data)
data <- na.omit(data)
forecast_unit <- attr(data, "forecast_unit")
forecast_unit <- get_forecast_unit(data)
metrics <- validate_metrics(metrics)

# transpose the forecasts that belong to the same forecast unit
Expand Down Expand Up @@ -151,14 +154,15 @@ score.forecast_sample <- function(data, metrics = rules_sample(), ...) {
return(data[])
}


#' @importFrom stats na.omit
#' @importFrom data.table `:=` as.data.table rbindlist %like%
#' @importFrom data.table `:=` as.data.table rbindlist %like% setattr
#' @rdname score
#' @export
score.forecast_quantile <- function(data, metrics = rules_quantile(), ...) {
data <- validate_forecast(data)
data <- na.omit(data)
forecast_unit <- attr(data, "forecast_unit")
forecast_unit <- get_forecast_unit(data)
metrics <- validate_metrics(metrics)

# transpose the forecasts that belong to the same forecast unit
Expand Down
12 changes: 3 additions & 9 deletions R/validate.R
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ validate_forecast.forecast_sample <- function(data, ...) {
#' @inheritParams get_forecast_counts
#' @return returns the input, with a few new attributes that hold additional
#' information, messages and warnings
#' @importFrom data.table ':=' is.data.table setattr
#' @importFrom data.table ':=' is.data.table
#' @importFrom checkmate assert_data_table
#' @export
#' @keywords internal_input_check
Expand All @@ -139,20 +139,14 @@ validate_general <- function(data) {
assert(check_data_columns(data))
data <- assure_model_column(data)

# assign forecast type and unit as an attribute and make sure there is no clash
forecast_type <- get_forecast_type(data)
setattr(data, "forecast_type", forecast_type)

forecast_unit <- get_forecast_unit(data, check_conflict = TRUE)
setattr(data, "forecast_unit", forecast_unit)

# check that there aren't any duplicated forecasts
forecast_unit <- get_forecast_unit(data)
assert(check_duplicates(data, forecast_unit = forecast_unit))

# check that the number of forecasts per sample / quantile is the same
number_quantiles_samples <- check_number_per_forecast(data, forecast_unit)
if (!is.logical(number_quantiles_samples)) {
setattr(data, "warnings", number_quantiles_samples)
warning(number_quantiles_samples)
}

# check whether there are any NA values
Expand Down
9 changes: 1 addition & 8 deletions man/get_forecast_unit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 0 additions & 11 deletions tests/testthat/test-get_-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,6 @@ test_that("get_forecast_unit() works as expected", {
c("location", "target_end_date", "target_type", "location_name",
"forecast_date", "model", "horizon")
)

data <- as_forecast(na.omit(example_quantile))
ex <- data[, location := NULL]
expect_warning(
get_forecast_unit(ex, check_conflict = TRUE),
"Object has an attribute `forecast_unit`, but it looks different from what's expected based on the data.
Existing: forecast_date, horizon, location, location_name, model, target_end_date, target_type
Expected: forecast_date, horizon, location_name, model, target_end_date, target_type
Running `as_forecast()` again might solve the problem",
fixed = TRUE
)
})


Expand Down