Commit b208872

bcjaeger, hfrick, and topepo authored
Updates for new aorsf engine for rand_forest() (#828)
* rand_forest_aorsf included in docs
* make engine arg tunable
* add `fit_xy()` method for `rand_forest()` so that it can error for the `aorsf` engine
* Update `models.tsv`
* Update engine docs for `aorsf`
* re-document to update `.md` and `.Rd`
* add note on case weights
* update status this should be removed once the changes in aorsf are on CRAN

Co-authored-by: Hannah Frick <hannah@rstudio.com>
Co-authored-by: Max Kuhn <mxkuhn@gmail.com>
1 parent bdc2854 commit b208872

9 files changed: 264 additions and 1 deletion

DESCRIPTION

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 Package: parsnip
 Title: A Common API to Modeling and Analysis Functions
-Version: 1.0.2.9003
+Version: 1.0.2.9004
 Authors@R: c(
     person("Max", "Kuhn", , "max@rstudio.com", role = c("aut", "cre")),
     person("Davis", "Vaughan", , "davis@rstudio.com", role = "aut"),

NAMESPACE

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ S3method(fit,model_spec)
 S3method(fit_xy,decision_tree)
 S3method(fit_xy,gen_additive_mod)
 S3method(fit_xy,model_spec)
+S3method(fit_xy,rand_forest)
 S3method(glance,model_fit)
 S3method(has_multi_predict,default)
 S3method(has_multi_predict,model_fit)

R/rand_forest.R

Lines changed: 22 additions & 0 deletions
@@ -163,3 +163,25 @@ check_args.rand_forest <- function(object) {
   # move translate checks here?
   invisible(object)
 }
+
+# ------------------------------------------------------------------------------
+
+#' @export
+fit_xy.rand_forest <- function(object,
+                               x,
+                               y,
+                               case_weights = NULL,
+                               control = parsnip::control_parsnip(),
+                               ...) {
+
+  if (object$mode == "censored regression" && object$engine == "aorsf") {
+    # CRAN aorsf::orsf() requires two variables on the left-hand side of the formula,
+    # either in as `Surv(time, status) ~ .` or as `time + status ~ .`
+    # see https://github.com/ropensci/aorsf/issues/11
+    rlang::abort("For the `'aorsf'` engine, please use the formula interface via `fit()`.")
+  }
+
+  # call parsnip::fit_xy.model_spec()
+  res <- NextMethod()
+  res
+}
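
Review note: the new method follows a common S3 pattern, where a class-specific method rejects one unsupported case and otherwise hands off to the general method with `NextMethod()`. The sketch below reproduces that pattern in base R only. The generic `fit_xy_demo()`, the classes `rand_forest_demo` and `model_spec_demo`, and the helper `new_spec_demo()` are hypothetical stand-ins and are not part of parsnip.

```r
# Hypothetical generic standing in for parsnip's fit_xy()
fit_xy_demo <- function(object, ...) UseMethod("fit_xy_demo")

# General method, playing the role of fit_xy.model_spec()
fit_xy_demo.model_spec_demo <- function(object, ...) {
  paste("fitted with engine", object$engine)
}

# Specific method: reject one mode/engine pair, otherwise delegate
fit_xy_demo.rand_forest_demo <- function(object, ...) {
  if (object$mode == "censored regression" && object$engine == "aorsf") {
    stop("For the 'aorsf' engine, please use the formula interface.",
         call. = FALSE)
  }
  # Dispatch to the next class in the class vector (model_spec_demo)
  NextMethod()
}

# Hypothetical constructor for a minimal model specification
new_spec_demo <- function(mode, engine) {
  structure(
    list(mode = mode, engine = engine),
    class = c("rand_forest_demo", "model_spec_demo")
  )
}

# A supported engine falls through to the general method
ok <- fit_xy_demo(new_spec_demo("censored regression", "partykit"))

# The guarded engine raises an error; capture its message
err <- tryCatch(
  fit_xy_demo(new_spec_demo("censored regression", "aorsf")),
  error = function(e) conditionMessage(e)
)
```

Because the guard sits in the `rand_forest` method only, other model types keep dispatching straight to the general method and are unaffected.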

R/rand_forest_aorsf.R

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+#' Oblique random survival forests via aorsf
+#'
+#' [aorsf::orsf()] fits a model that creates a large number of decision
+#' trees, each de-correlated from the others. The final prediction uses all
+#' predictions from the individual trees and combines them.
+#'
+#' @includeRmd man/rmd/rand_forest_aorsf.md details
+#'
+#' @name details_rand_forest_aorsf
+#' @keywords internal
+NULL
+
+# See inst/README-DOCS.md for a description of how these files are processed

R/tunable.R

Lines changed: 14 additions & 0 deletions
@@ -159,6 +159,18 @@ partykit_engine_args <-
     component_id = "engine"
   )
 
+aorsf_engine_args <-
+  tibble::tibble(
+    name = c(
+      "split_min_stat"
+    ),
+    call_info = list(
+      list(pkg = "dials", fun = "conditional_min_criterion")
+    ),
+    source = "model_spec",
+    component = "rand_forest",
+    component_id = "engine"
+  )
 
 earth_engine_args <-
   tibble::tibble(
@@ -284,6 +296,8 @@ tunable_rand_forest <- function(x, ...) {
     res <- add_engine_parameters(res, randomForest_engine_args)
   } else if (x$engine == "partykit") {
     res <- add_engine_parameters(res, partykit_engine_args)
+  } else if (x$engine == "aorsf") {
+    res <- add_engine_parameters(res, aorsf_engine_args)
   }
   res
 }
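
Review note: the two hunks above register one engine-specific parameter and append it to the main arguments when the engine matches. A rough base-R sketch of that append-by-engine idea follows. Every name ending in `_demo` is invented for illustration, the `call_info` and `source` columns are omitted for brevity, and the behaviour of parsnip's real `add_engine_parameters()` is only assumed to be a row-wise append.

```r
# Hypothetical table of main arguments (stand-in for parsnip's internal table)
main_args_demo <- data.frame(
  name = c("mtry", "trees", "min_n"),
  component_id = "main"
)

# Hypothetical table mirroring the new aorsf_engine_args entry
aorsf_engine_args_demo <- data.frame(
  name = "split_min_stat",
  component_id = "engine"
)

# Append engine-specific rows only when the engine matches
tunable_rand_forest_demo <- function(engine) {
  res <- main_args_demo
  if (engine == "aorsf") {
    res <- rbind(res, aorsf_engine_args_demo)
  }
  res
}

aorsf_params <- tunable_rand_forest_demo("aorsf")
other_params <- tunable_rand_forest_demo("some_other_engine")
```

In the real package, the practical effect is presumably that `split_min_stat = tune()` passed through `set_engine("aorsf", ...)` is recognised as a tunable parameter and paired with `dials::conditional_min_criterion()`.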

inst/models.tsv

Lines changed: 1 addition & 0 deletions
@@ -105,6 +105,7 @@
 "poisson_reg" "regression" "zeroinfl" "poissonreg"
 "proportional_hazards" "censored regression" "glmnet" "censored"
 "proportional_hazards" "censored regression" "survival" "censored"
+"rand_forest" "censored regression" "aorsf" "censored"
 "rand_forest" "censored regression" "partykit" "censored"
 "rand_forest" "classification" "h2o" "agua"
 "rand_forest" "classification" "partykit" "bonsai"

man/details_rand_forest_aorsf.Rd

Lines changed: 79 additions & 0 deletions
Some generated files are not rendered by default.

man/rmd/rand_forest_aorsf.Rmd

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+```{r, child = "aaa.Rmd", include = FALSE}
+```
+
+`r descr_models("rand_forest", "aorsf")`
+
+## Tuning Parameters
+
+```{r aorsf-param-info, echo = FALSE}
+defaults <-
+  tibble::tibble(parsnip = c("trees", "min_n", "mtry"),
+                 default = c("500L", "5L", "ceiling(sqrt(n_predictors))"))
+
+param <-
+  rand_forest() %>%
+  set_engine("aorsf") %>%
+  set_mode("censored regression") %>%
+  make_parameter_list(defaults) %>%
+  distinct()
+```
+
+This model has `r nrow(param)` tuning parameters:
+
+```{r aorsf-param-list, echo = FALSE, results = "asis"}
+param$item
+```
+
+Additionally, this model has one engine-specific tuning parameter:
+
+* `split_min_stat`: Minimum test statistic required to split a node. Default is `3.841459` for the log-rank test, which is roughly a p-value of 0.05.
+
+
+# Translation from parsnip to the original package (censored regression)
+
+`r uses_extension("rand_forest", "aorsf", "censored regression")`
+
+```{r aorsf-creg}
+library(censored)
+
+rand_forest() %>%
+  set_engine("aorsf") %>%
+  set_mode("censored regression") %>%
+  translate()
+```
+
+## Preprocessing requirements
+
+```{r child = "template-tree-split-factors.Rmd"}
+```
+
+## Case weights
+
+```{r child = "template-uses-case-weights.Rmd"}
+```
+
+## Other details
+
+Predictions of survival probability at a time exceeding the maximum observed event time are the predicted survival probability at the maximum observed time in the training data.
+
+## References
+
+- Jaeger BC, Long DL, Long DM, Sims M, Szychowski JM, Min YI, Mcclure LA, Howard G, Simon N. Oblique random survival forests. Annals of applied statistics 2019 Sep; 13(3):1847-83. DOI: 10.1214/19-AOAS1261
+
+- Jaeger BC, Welden S, Lenoir K, Pajewski NM. aorsf: An R package for supervised learning using the oblique random survival forest. Journal of Open Source Software 2022, 7(77), 1 4705. https://doi.org/10.21105/joss.04705.
+
+- Jaeger BC, Welden S, Lenoir K, Speiser JL, Segar MW, Pandey A, Pajewski NM. Accelerated and interpretable oblique random survival forests. arXiv e-prints 2022 Aug; arXiv-2208. URL: https://arxiv.org/abs/2208.01129
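
Review note: the `split_min_stat` default of `3.841459` documented above can be checked with base R, assuming the log-rank statistic is compared against a chi-squared distribution with 1 degree of freedom (the usual reference for a two-group log-rank test; the docs do not state this explicitly).

```r
# 95th percentile of a chi-squared distribution with 1 degree of freedom
crit <- qchisq(0.95, df = 1)

# Upper-tail p-value corresponding to the documented default
p_value <- pchisq(3.841459, df = 1, lower.tail = FALSE)

round(crit, 6)
round(p_value, 4)
```

Under that assumption, raising `split_min_stat` corresponds to demanding a smaller p-value before a node is split, which tends to give shallower trees.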

man/rmd/rand_forest_aorsf.md

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+
+
+
+For this engine, there is a single mode: censored regression
+
+## Tuning Parameters
+
+
+
+This model has 3 tuning parameters:
+
+- `trees`: # Trees (type: integer, default: 500L)
+
+- `min_n`: Minimal Node Size (type: integer, default: 5L)
+
+- `mtry`: # Randomly Selected Predictors (type: integer, default: ceiling(sqrt(n_predictors)))
+
+Additionally, this model has one engine-specific tuning parameter:
+
+* `split_min_stat`: Minimum test statistic required to split a node. Default is `3.841459` for the log-rank test, which is roughly a p-value of 0.05.
+
+
+# Translation from parsnip to the original package (censored regression)
+
+The **censored** extension package is required to fit this model.
+
+
+```r
+library(censored)
+
+rand_forest() %>%
+  set_engine("aorsf") %>%
+  set_mode("censored regression") %>%
+  translate()
+```
+
+```
+## Random Forest Model Specification (censored regression)
+##
+## Computational engine: aorsf
+##
+## Model fit template:
+## aorsf::orsf(formula = missing_arg(), data = missing_arg(), weights = missing_arg())
+```
+
+## Preprocessing requirements
+
+
+This engine does not require any special encoding of the predictors. Categorical predictors can be partitioned into groups of factor levels (e.g. `{a, c}` vs `{b, d}`) when splitting at a node. Dummy variables are not required for this model.
+
+## Case weights
+
+
+This model can utilize case weights during model fitting. To use them, see the documentation in [case_weights] and the examples on `tidymodels.org`.
+
+The `fit()` and `fit_xy()` arguments have arguments called `case_weights` that expect vectors of case weights.
+
+## Other details
+
+Predictions of survival probability at a time exceeding the maximum observed event time are the predicted survival probability at the maximum observed time in the training data.
+
+## References
+
+- Jaeger BC, Long DL, Long DM, Sims M, Szychowski JM, Min YI, Mcclure LA, Howard G, Simon N. Oblique random survival forests. Annals of applied statistics 2019 Sep; 13(3):1847-83. DOI: 10.1214/19-AOAS1261
+
+- Jaeger BC, Welden S, Lenoir K, Pajewski NM. aorsf: An R package for supervised learning using the oblique random survival forest. Journal of Open Source Software 2022, 7(77), 1 4705. https://doi.org/10.21105/joss.04705.
+
+- Jaeger BC, Welden S, Lenoir K, Speiser JL, Segar MW, Pandey A, Pajewski NM. Accelerated and interpretable oblique random survival forests. arXiv e-prints 2022 Aug; arXiv-2208. URL: https://arxiv.org/abs/2208.01129
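
Review note: the "Other details" paragraph in this file says that predictions past the last observed event time reuse the value at that time. The toy sketch below illustrates that rule with an invented step-function survival curve. The event times, probabilities, and the function `predict_surv_demo()` are made up for illustration, and this is not how aorsf computes predictions internally.

```r
# Invented training event times and survival probabilities just after each one
event_times <- c(1, 3, 5, 8)
surv_probs  <- c(0.9, 0.7, 0.5, 0.3)

# Right-continuous step function; survival is 1 before the first event
surv_curve <- stepfun(event_times, c(1, surv_probs))

predict_surv_demo <- function(eval_time) {
  # Clamp evaluation times to the largest observed event time
  clamped <- pmin(eval_time, max(event_times))
  surv_curve(clamped)
}

# Times 8 and 20 give the same prediction because 20 is clamped to 8
preds <- predict_surv_demo(c(0.5, 5, 8, 20))
```

The clamp is written out explicitly to mirror the documented rule, even though a step function would hold its last value anyway.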
