Skip to content

Commit

Permalink
Cox plot (#148)
Browse files Browse the repository at this point in the history
* add default plot for cox model

* update NEWS.md

* making sure mlr3proba doesn't break the CI

* trying again

---------

Co-authored-by: be-marc <marcbecker@posteo.de>
  • Loading branch information
bblodfon and be-marc authored Jul 2, 2024
1 parent 4e2d284 commit aa8a86a
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 2 deletions.
9 changes: 8 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,13 @@ Suggests:
stats,
testthat (>= 3.0.0),
vdiffr (>= 1.0.2),
xgboost
xgboost,
survminer,
mlr3proba (>= 0.6.3)
Remotes:
mlr-org/mlr3proba
Additional_repositories:
https://mlr-org.r-universe.dev
Config/testthat/edition: 3
Config/testthat/parallel: true
Encoding: UTF-8
Expand All @@ -79,6 +85,7 @@ Collate:
'LearnerRegrCVGlmnet.R'
'LearnerRegrGlmnet.R'
'LearnerRegrRpart.R'
'LearnerSurvCoxPH.R'
'OptimInstanceBatchSingleCrit.R'
'Prediction.R'
'PredictionClassif.R'
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ S3method(autoplot,LearnerRegr)
S3method(autoplot,LearnerRegrCVGlmnet)
S3method(autoplot,LearnerRegrGlmnet)
S3method(autoplot,LearnerRegrRpart)
S3method(autoplot,LearnerSurvCoxPH)
S3method(autoplot,OptimInstanceBatchSingleCrit)
S3method(autoplot,PredictionClassif)
S3method(autoplot,PredictionClust)
Expand All @@ -41,6 +42,7 @@ S3method(plot,LearnerClassifRpart)
S3method(plot,LearnerRegrCVGlmnet)
S3method(plot,LearnerRegrGlmnet)
S3method(plot,LearnerRegrRpart)
S3method(plot,LearnerSurvCoxPH)
S3method(plot,PredictionClassif)
S3method(plot,PredictionRegr)
S3method(plot,ResampleResult)
Expand Down
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# mlr3viz (development version)

- Add plot for `LearnerSurvCoxPH`.

# mlr3viz 0.9.0

- Work with new bbotk 0.9.0 and mlr3tuning 0.21.0

- Add plots for `EnsembleFSResult` object.

# mlr3viz 0.8.0
Expand Down
49 changes: 49 additions & 0 deletions R/LearnerSurvCoxPH.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#' @title Plots for Cox Proportional Hazards Learner
#'
#' @description
#' Visualizations for [mlr3proba::LearnerSurvCoxPH].
#'
#' The argument `type` controls what kind of plot is drawn.
#' The only possible choice right now is `"ggforest"` (default) which is a
#' Forest Plot, using [ggforest][survminer::ggforest()].
#' This plot displays the estimated hazard ratios (HRs) and their confidence
#' intervals (CIs) for different variables included in the (trained) model.
#'
#' @param object ([mlr3proba::LearnerSurvCoxPH]).
#'
#' @template param_type
#' @param ... Additional parameters passed down to `ggforest`.
#'
#' @return [ggplot2::ggplot()].
#'
#' @export
#' @examples
#' \donttest{
#' if (requireNamespace("mlr3proba")) {
#' library(mlr3proba)
#' library(mlr3viz)
#'
#' task = tsk("lung")
#' learner = lrn("surv.coxph")
#' learner$train(task)
#' autoplot(learner)
#' }
#' }
autoplot.LearnerSurvCoxPH = function(object, type = "ggforest", ...) {
assert_class(object, classes = "LearnerSurvCoxPH", null.ok = FALSE)
assert_has_model(object)

switch(type,
"ggforest" = {
require_namespaces("survminer")
suppressWarnings(survminer::ggforest(object$model, ...))
},

stopf("Unknown plot type '%s'", type)
)
}

#' @export
plot.LearnerSurvCoxPH = function(x, ...) {
print(autoplot(x, ...))
}
41 changes: 41 additions & 0 deletions man/autoplot.LearnerSurvCoxPH.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit aa8a86a

Please sign in to comment.