Skip to content

Commit

Permalink
add constant argument to transform_fun as per #79
Browse files Browse the repository at this point in the history
  • Loading branch information
gavinsimpson committed Mar 11, 2024
1 parent 4eb66c4 commit 0cd917f
Showing 1 changed file with 96 additions and 12 deletions.
108 changes: 96 additions & 12 deletions R/utililties.R
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,7 @@ vars_from_label <- function(label) {
#'
#' @param object an object to apply the transform function to.
#' @param fun the function to apply.
#' @param constant numeric; a constant to apply before transformation.
#' @param ... additional arguments passed to methods.
#' @param column character; for the `"tbl_df"` method, which column to
#' transform.
Expand All @@ -1134,7 +1135,16 @@ vars_from_label <- function(label) {
#' @export
#' @importFrom dplyr mutate across
#' @importFrom tidyselect any_of
`transform_fun.smooth_estimates` <- function(object, fun = NULL, ...) {
`transform_fun.smooth_estimates` <- function(object, fun = NULL,
constant = NULL, ...) {
# If constant supplied, use it to transform est and the upper and lower
# interval
if (!is.null(constant)) {
object <- object |>
mutate(across(any_of(c("est", "lower_ci", "upper_ci",
".estimate", ".upper_ci", ".lower_ci")),
.fns = \(x) x + constant))
}
## If fun supplied, use it to transform est and the upper and lower interval
if (!is.null(fun)) {
fun <- match.fun(fun)
Expand All @@ -1156,14 +1166,21 @@ vars_from_label <- function(label) {
#' @rdname transform_fun
#' @export
#' @importFrom dplyr mutate across
#' @importFrom tidyselect all_of
`transform_fun.smooth_samples` <- function(object, fun = NULL, ...) {
## If fun supplied, use it to transform est and the upper and lower interval
#' @importFrom tidyselect all_of any_of
`transform_fun.smooth_samples` <- function(object, fun = NULL,
constant = NULL, ...) {
# If constant supplied, use it to transform value
if (!is.null(constant)) {
object <- object |>
mutate(across(all_of(c(".value")),
.fns = \(x) x + constant))
}
# If fun supplied, use it to transform value
if (!is.null(fun)) {
fun <- match.fun(fun)
object <- mutate(
object,
across(all_of("value"),
across(all_of(".value"),
.fns = fun
)
)
Expand All @@ -1176,7 +1193,15 @@ vars_from_label <- function(label) {
#' @export
#' @importFrom dplyr mutate across
#' @importFrom tidyselect all_of
`transform_fun.mgcv_smooth` <- function(object, fun = NULL, ...) {
`transform_fun.mgcv_smooth` <- function(object, fun = NULL,
constant = NULL, ...) {
# If constant supplied, use it to transform est and the upper and lower
# interval
if (!is.null(constant)) {
object <- object |>
mutate(across(all_of(c(".estimate", ".upper_ci", ".lower_ci")),
.fns = \(x) x + constant))
}
if (!is.null(fun)) {
fun <- match.fun(fun)
object <- mutate(
Expand All @@ -1194,7 +1219,15 @@ vars_from_label <- function(label) {
#' @export
#' @importFrom dplyr mutate across
#' @importFrom tidyselect all_of
`transform_fun.evaluated_parametric_term` <- function(object, fun = NULL, ...) {
`transform_fun.evaluated_parametric_term` <- function(object, fun = NULL,
constant = NULL, ...) {
# If constant supplied, use it to transform est and the upper and lower
# interval
if (!is.null(constant)) {
object <- object |>
mutate(across(all_of(c("est", "lower", "upper")),
.fns = \(x) x + constant))
}
## If fun supplied, use it to transform est and the upper and lower interval
if (!is.null(fun)) {
fun <- match.fun(fun)
Expand All @@ -1212,9 +1245,17 @@ vars_from_label <- function(label) {
#' @rdname transform_fun
#' @export
#' @importFrom dplyr mutate across
#' @importFrom tidyselect all_of
`transform_fun.parametric_effects` <- function(object, fun = NULL, ...) {
## If fun supplied, use it to transform est and the upper and lower interval
#' @importFrom tidyselect all_of any_of
`transform_fun.parametric_effects` <- function(object, fun = NULL,
constant = NULL, ...) {
# If constant supplied, use it to transform est and the upper and lower
# interval
if (!is.null(constant)) {
object <- object |>
mutate(across(any_of(c(".partial", ".lower_ci", ".upper_ci")),
.fns = \(x) x + constant))
}
# If fun supplied, use it to transform est and the upper and lower interval
if (!is.null(fun)) {
fun <- match.fun(fun)
object <- mutate(
Expand All @@ -1232,11 +1273,19 @@ vars_from_label <- function(label) {
#' @export
#' @importFrom dplyr mutate across
#' @importFrom tidyselect all_of
`transform_fun.tbl_df` <- function(object, fun = NULL, column = NULL, ...) {
`transform_fun.tbl_df` <- function(object, fun = NULL, column = NULL,
constant = NULL, ...) {
if (is.null(column)) {
stop("'column' to modify must be supplied.")
}
## If fun supplied, use it to transform est and the upper and lower interval
# If constant supplied, use it to transform est and the upper and lower
# interval
if (!is.null(constant)) {
object <- object |>
mutate(across(all_of(column),
.fns = \(x) x + constant))
}
# If fun supplied, use it to transform est and the upper and lower interval
if (!is.null(fun)) {
fun <- match.fun(fun)
object <- mutate(
Expand Down Expand Up @@ -1765,3 +1814,38 @@ reclass_scam_smooth <- function(smooth) {

sm_vars
}

#' Extract the model constant term
#'
#' Extracts the model constant term, the model intercept, from a fitted model
#' object.
#'
#' @param model a fitted model for which a `coef()` method exists
#'
#' @export
#' @importFrom stats coef
#' @examples
#' \dontshow{
#' op <- options(digits = 4)
#' }
#' load_mgcv()
#'
#' # simulate a small example
#' df <- data_sim("eg1")
#'
#' # fit the GAM
#' m <- gam(y ~ s(x0) + s(x1) + s(x2) + s(x3), data = df, method = "REML")
#'
#' # extract the estimate of the constant term
#' model_constant(m)
#' # same as coef(m)[1L]
#' coef(m)[1L]
#'
#' \dontshow{
#' options(op)
#' }
`model_constant` <- function(model) {
b <- coef(model)
b[1L] |>
unname()
}

0 comments on commit 0cd917f

Please sign in to comment.