Description
Fitting a single model via fit() vs. resampling a model fit via fit_resamples() has a nice parallelism to it:
# single model:
fit(workflow, data)
# multiple models:
fit_resamples(workflow, resamples(data))
Ideally, the two-stage approach in causal inference could have the same ring to it.
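Loosely, that might look something like the following (placeholder names, with weight_propensity() standing in for the bridging helper sketched further down):

# single model, two stages:
fit(propensity_workflow, data) %>%
  weight_propensity(...) %>%
  fit(outcome_workflow, data = .)

# multiple models, two stages:
fit_resamples(propensity_workflow, resamples(data)) %>%
  weight_propensity(...) %>%
  fit_resamples(outcome_workflow, resamples = .)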
library(tidymodels)
library(causalworkshop)
library(propensity)

# put the treated level first so it's the event level for the propensity model
net_data <- net_data %>% mutate(net = factor(net, levels = c("TRUE", "FALSE")))
Defining workflows:
outcome_wf <-
  workflow(
    malaria_risk ~ net,
    linear_reg()
  ) %>%
  add_case_weights(wts)

propensity_wf <-
  workflow(
    net ~ income + health + temperature,
    logistic_reg()
  )
Note that there's an add_case_weights() step for outcome_wf whose weights column can't exist until propensity_wf has generated predictions.
With no changes made to tidymodels, the code for a fit to one dataset looks something like:
net_data_wts <-
  fit(propensity_wf, net_data) %>%
  augment(net_data) %>%
  mutate(
    wts = wt_ate(.pred_TRUE, net, .treated = "TRUE")
  )

results <-
  outcome_wf %>%
  fit(
    data = net_data_wts %>% mutate(wts = importance_weights(wts))
  ) %>%
  tidy()
For the resampled fit, we definitely need a helper that bridges the two calls to fit_resamples() by mutating propensity weights onto the analysis set underlying each rsplit. In this example, I call it weight_propensity():
results <-
  fit_resamples(
    propensity_wf,
    resamples = bootstraps(net_data),
    control = control_resamples(extract = identity)
  ) %>%
  weight_propensity(wt_ate) %>%
  fit_resamples(
    outcome_wf,
    resamples = .,
    control = control_resamples(extract = tidy)
  )
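As an aside, since the outcome fit extracts tidy() results, the bootstrapped coefficients could then be summarized with something like the following (a sketch, assuming a tune version with collect_extracts(); netFALSE is the treatment term given the releveling above):

results %>%
  collect_extracts() %>%
  unnest(.extracts) %>%
  filter(term == "netFALSE") %>%
  summarize(
    mean_estimate = mean(estimate),
    sd_estimate = sd(estimate)
  )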
The weight_propensity() helper (sans error checking) could look something like:
EDIT: Out of date, see linked PR
# a function that takes in a resample fit object and outputs a modified
# version of that object where the training data underlying each rsplit
# is augmented with propensity weights for each element of the analysis set.
# this serves as the "bridge" between two calls to
# `fit_resamples()` (or `tune_*()`) in a causal workflow, the
# first being for the propensity model and the second for the outcome model.
# `tune_results` must have been executed with option `extract = identity`.
weight_propensity <- function(tune_results, wt_fn) {
  for (resample in seq_along(tune_results$splits)) {
    tune_results$splits[[resample]] <-
      augment_split(
        tune_results$splits[[resample]],
        tune_results$.extracts[[resample]]$.extracts[[1]],
        wt_fn = wt_fn,
        outcome_name = outcome_names(tune_results)
      )
  }

  # rebuild an rset from the modified splits so the result can be passed
  # to a second `fit_resamples()`
  tibble::new_tibble(
    tune_results[, c("splits", "id")],
    !!!attr(tune_results, "rset_info")$att,
    class = c(attr(tune_results, "rset_info")$att$class, "rset")
  )
}
# `split` is an rsplit, `workflow` is the fitted propensity workflow extracted
# from that resample, and `outcome_name` is the outcome of the propensity
# model, i.e. the treatment variable.
augment_split <- function(split, workflow, wt_fn, outcome_name) {
  d <- analysis(split)
  d <- vctrs::vec_cbind(d, predict(workflow, d, type = "prob"))
  # bootstrap analysis sets can contain repeated rows; keep one copy of each
  d <- vctrs::vec_slice(d, !duplicated(d$id))

  model_fit <- extract_fit_parsnip(workflow)
  lvls <- model_fit$lvl
  event_lvl <- lvls[1]
  preds <- d[[paste0(".pred_", event_lvl)]]

  # write the weights back into the data underlying the split, indexed by row id
  split[["data"]][d$id, "wts"] <-
    importance_weights(wt_fn(preds, d[[outcome_name]], .treated = event_lvl))

  split
}
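One assumption baked into augment_split(): the weights are written back into split$data by indexing with d$id, so the data is expected to carry a row-identifier column (here id) holding original row numbers, e.g. something like:

net_data <- net_data %>% mutate(id = row_number())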
Questions:

- What's a good name for weight_propensity()? Is there something more compatible with the analogous(?) procedure in survival analysis?
- Do we want that function to be usable in both the single-fit and resampled-fit settings, to aid with that parallelism? We could make a method that takes in data, a weighting function, and a model fit so that the single-fit setting feels more like the resampled-fit setting, a la:
result <-
  fit(
    propensity_wf,
    net_data
  ) %>%
  weight_propensity(net_data, wt_ate, .) %>%
  fit(
    outcome_wf,
    data = .
  ) %>%
  tidy()
I'd propose we do include that parsnip/workflows counterpart. That helper (and probably the generic?) ought to live in parsnip/workflows, if so.
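For illustration, a minimal sketch of what that data-frame method might look like, reusing the pieces from augment_split() above (the generic, argument order, and names are placeholders rather than a concrete API proposal, and it assumes outcome_names() can pull the treatment variable off of a fitted workflow):

weight_propensity <- function(x, wt_fn, ...) {
  UseMethod("weight_propensity")
}

# single-fit method: `x` is a data frame, `object` the fitted propensity workflow
weight_propensity.data.frame <- function(x, wt_fn, object, ...) {
  model_fit <- extract_fit_parsnip(object)
  event_lvl <- model_fit$lvl[1]
  preds <- predict(object, x, type = "prob")[[paste0(".pred_", event_lvl)]]
  # the "outcome" of the propensity model is the treatment variable
  x$wts <- importance_weights(
    wt_fn(preds, x[[outcome_names(object)]], .treated = event_lvl)
  )
  x
}

The resampled-fit helper above would then become the tune_results method of the same generic.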