forked from business-science/modeltime
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodeltime-table.R
108 lines (100 loc) · 2.58 KB
/
modeltime-table.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#' Scale forecast analysis with a Modeltime Table
#'
#' Designed to perform forecasts at scale using models created with
#' `modeltime`, `parsnip`, `workflows`, and regression modeling extensions
#' in the `tidymodels` ecosystem.
#'
#' @param ... Fitted `parsnip` model or `workflow` objects
#' @param .l A list containing fitted `parsnip` model or `workflow` objects
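#'
#' @return A `mdl_time_tbl` (a tibble subclass) with one row per model and
#'   columns `.model_id`, `.model`, and `.model_desc`.
#'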
#' @details
#'
#' `modeltime_table()`:
#'
#' 1. Creates a table of models
#' 2. Validates that all objects are models (parsnip `model_fit`, `workflow`, or
#' `mdl_time_ensemble` objects) and all models have been fitted (trained)
#' 3. Provides an ID and Description of the models
#'
#' `as_modeltime_table()`:
#'
#' Converts a `list` of models to a modeltime table. Useful when programmatically creating
#' Modeltime Tables from models stored in a `list`.
#'
#' @examples
#' library(tidyverse)
#' library(lubridate)
#' library(timetk)
#' library(parsnip)
#' library(rsample)
#'
#' # Data
#' m750 <- m4_monthly %>% filter(id == "M750")
#'
#' # Split Data 90/10
#' splits <- initial_time_split(m750, prop = 0.9)
#'
#' # --- MODELS ---
#'
#' # Model 1: auto_arima ----
#' model_fit_arima <- arima_reg() %>%
#' set_engine(engine = "auto_arima") %>%
#' fit(value ~ date, data = training(splits))
#'
#'
#' # ---- MODELTIME TABLE ----
#'
#' # Make a Modeltime Table
#' models_tbl <- modeltime_table(
#' model_fit_arima
#' )
#'
#' # Can also convert a list of models
#' list(model_fit_arima) %>%
#' as_modeltime_table()
#'
#' # ---- CALIBRATE ----
#'
#' calibration_tbl <- models_tbl %>%
#' modeltime_calibrate(new_data = testing(splits))
#'
#' # ---- ACCURACY ----
#'
#' calibration_tbl %>%
#' modeltime_accuracy()
#'
#' # ---- FORECAST ----
#'
#' calibration_tbl %>%
#' modeltime_forecast(
#' new_data = testing(splits),
#' actual_data = m750
#' )
#'
#' @export
#' @name modeltime_table
modeltime_table <- function(...) {
    # Collect every fitted model supplied through `...` into a plain list;
    # validation and table construction are delegated to as_modeltime_table().
    model_list <- list(...)
    as_modeltime_table(model_list)
}
#' @export
print.mdl_time_tbl <- function(x, ...) {
    # Print method for Modeltime Tables: emits a header line, then delegates
    # to the next class's print method (e.g. the tibble printer) on a copy
    # with the `mdl_time_tbl` class removed, so dispatch does not recurse.
    cat("# Modeltime Table\n")

    # Strip the class from a copy only; `x` itself must keep its class so
    # it can be returned unchanged below.
    y <- x
    class(y) <- class(y)[class(y) != "mdl_time_tbl"]
    print(y, ...)

    # S3 print methods return their argument invisibly, which keeps
    # `mdl_time_tbl` objects intact when printing inside a pipe or assignment.
    invisible(x)
}
#' @export
#' @rdname modeltime_table
as_modeltime_table <- function(.l) {
    # Build a Modeltime Table (class `mdl_time_tbl`) from a list of fitted
    # models: one row per model with `.model_id`, `.model`, `.model_desc`.

    # Object classes accepted as "models" in a Modeltime Table.
    model_classes <- c("model_fit", "workflow", "mdl_time_ensemble")

    # Fail fast with actionable messages. A single fitted model is itself a
    # list, so without this check it would be split into its components
    # (one row per element) and only rejected later with a confusing error.
    if (inherits(.l, model_classes)) {
        stop(
            "`.l` is a single model, not a list of models. ",
            "Use `modeltime_table()` or wrap it with `list()`.",
            call. = FALSE
        )
    }
    if (!is.list(.l)) {
        stop(
            "`.l` must be a list of fitted models (`model_fit`, `workflow`, ",
            "or `mdl_time_ensemble` objects).",
            call. = FALSE
        )
    }

    # One row per model in a list-column, plus a sequential `.model_id`.
    ret <- tibble::tibble(.model = .l) %>%
        tibble::rowid_to_column(var = ".model_id")

    # CHECKS
    validate_model_classes(ret, accept_classes = model_classes)
    validate_models_are_trained(ret)

    # CREATE MODELTIME OBJECT
    ret <- ret %>%
        dplyr::mutate(
            .model_desc = purrr::map_chr(.model, .f = get_model_description)
        )

    class(ret) <- c("mdl_time_tbl", class(ret))
    ret
}