Skip to content

Commit 99e71a9

Browse files
PEM data transformation for competing risks
1 parent e4736da commit 99e71a9

File tree

1 file changed

+52
-42
lines changed

1 file changed

+52
-42
lines changed

R/PipeOpTaskSurvRegrPEM.R

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -138,16 +138,16 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM",
138138
assert(max_time > data[get(event_var) == 1, min(get(time_var))],
139139
"max_time must be greater than the minimum event time.")
140140
}
141-
141+
142142
# To-Do: Extend to a more general formulation for competing risks and multi-state models (msm)
143143
# form = formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".")
144144
# To-Do: provide the formula as a formula object, neither as a string nor via formula(...)
145145
long_data = pammtools::as_ped(data = data, formula = self$param_set$values$form, cut = cut, max_time = max_time)
146146
self$state$cut = attributes(long_data)$trafo_args$cut
147-
147+
148148
risk_scenario = attributes(long_data)$class
149-
150-
# To-Do: Does this save the information at the right location for correct prediction later on?
149+
150+
# To-Do: Does this save the information at the right location for correct prediction later on?
151151
# At which steps is this information required:
152152
# 1. prediction
153153
# 2. data transformation? Intuitively, the as_ped() function automatically detects the scenario and performs the appropriate transformations
@@ -158,9 +158,9 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM",
158158
} else {
159159
self$state$risk_scenario = 'ped'
160160
}
161-
162161

163-
162+
163+
164164
long_data = as.data.table(long_data)
165165
setnames(long_data, old = "ped_status", new = "PEM_status") #change to PEM
166166

@@ -185,9 +185,9 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM",
185185

186186
# extract `cut` from `state`
187187
cut = self$state$cut
188-
188+
189189
risk_scenario = self$state$risk_scenario
190-
190+
191191
time_var = task$target_names[1]
192192
event_var = task$target_names[2]
193193

@@ -196,57 +196,67 @@ PipeOpTaskSurvRegrPEM = R6Class("PipeOpTaskSurvRegrPEM",
196196
data[[time_var]] = max_time
197197

198198
status = data[[event_var]]
199-
# setting data[[event_var]] removes automatic detection of cr events of as_ped function
200-
201-
# if (risk_scenario == "ped_cr"){
202-
# long_data = as.data.table(pammtools::as_ped(data, formula = formula(self$param_set$values$form), cut = cut))
203-
# long_data = long_data |> pammtools::make_newdata(tend = unique(tend), cause = unique(cause))
204-
# }
205-
206-
207-
# requires generalization for test scenario
199+
200+
# requires generalization for test scenario
201+
# setting data[[event_var]] = 1 would disable the automatic detection of competing-risks events inside as_ped()
208202
# data[[event_var]] = 1
209203

210-
# update form
211-
# form = formulate(sprintf("Surv(%s, %s)", time_var, event_var), ".")
212-
213-
for (cause in unique)
214204
long_data = as.data.table(pammtools::as_ped(data, formula = formula(self$param_set$values$form), cut = cut))
215-
216205
setnames(long_data, old = "ped_status", new = "PEM_status")
217206

218207
PEM_status = id = tend = obs_times = NULL # fixing global binding notes of data.table
219208
long_data[, PEM_status := 0]
220-
# set correct id
221-
rows_per_id = nrow(long_data) / length(unique(long_data$id))
222-
long_data$obs_times = rep(time, each = rows_per_id)
223-
ids = rep(task$row_ids, each = rows_per_id)
224-
long_data[, id := ids]
225209

226-
# set correct PEM_status
227-
if (risk_scenario == 'ped_cr'){
228-
long_data$cause = rep(status, each = rows_per_id)
210+
211+
if (risk_scenario == "ped_cr"){
212+
rows_per_id = nrow(long_data) / length(unique(long_data$id))
213+
num_causes = length(unique(long_data$cause))
214+
rows_per_id_per_cause = rows_per_id / num_causes
215+
216+
# sequence of ids for every stack
217+
ids = rep(task$row_ids, each = rows_per_id_per_cause)
218+
ids = rep(ids, times = num_causes)
219+
long_data[, id := ids]
220+
221+
# To-Do: Reassign observation times for every df
222+
# long_data$obs_times = rep(rep(time, each = rows_per_id_per_cause), times = num_causes)
223+
long_data[, c("tstart", "interval") := NULL]
224+
} else {
225+
# set correct id
226+
rows_per_id = nrow(long_data) / length(unique(long_data$id))
227+
long_data$obs_times = rep(time, each = rows_per_id)
228+
ids = rep(task$row_ids, each = rows_per_id)
229+
long_data[, id := ids]
230+
231+
# starts diverging from competing risks
232+
233+
# set correct PEM status
234+
reps = long_data[, data.table(count = sum(tend >= obs_times)), by = id]$count
235+
status = rep(ifelse(status != 0, 1, 0), times = reps)
236+
237+
long_data[long_data[, .I[tend >= obs_times], by = id]$V1, PEM_status := status]
238+
239+
# remove some columns from 'long_data'
240+
long_data[, c("tstart", "interval", "obs_times") := NULL]
229241
}
230-
reps = long_data[, data.table(count = sum(tend >= obs_times)), by = id]$count
231-
# status = rep(status, times = reps)
232-
status = rep(ifelse(status != 0, 1, 0), times = reps)
233-
long_data[long_data[, .I[tend >= obs_times], by = id]$V1, PEM_status := status]
234242

235-
# remove some columns from `long_data`
236-
long_data[, c("tstart", "interval", "obs_times") := NULL]
237243
task_PEM = TaskRegr$new(paste0(task$id, "_PEM"), long_data,
238244
target = "PEM_status")
239245
task_PEM$set_col_roles("id", roles = "original_ids")
240246

241-
# map observed times back
242-
reps = table(long_data$id)
243-
long_data$obs_times = rep(time, each = rows_per_id)
247+
244248
# subset transformed data
245-
columns_to_keep = c("id", "obs_times", "tend", "PEM_status", "offset")
249+
if (risk_scenario == "ped_cr"){
250+
columns_to_keep = c("id", "tend", "PEM_status", "offset", "cause")
251+
} else {
252+
columns_to_keep = c("id", "obs_times", "tend", "PEM_status", "offset")
253+
# map observed times back
254+
long_data$obs_times = rep(time, each = rows_per_id)
255+
}
246256
long_data = long_data[, columns_to_keep, with = FALSE]
247-
257+
258+
# save risk_scenario in long_data to pass it on to the prediction pipeline
248259
long_data$risk_scenario = risk_scenario
249-
# To-Do: return information on the risk scenario, passed on to the prediction pipeline
250260
list(task_PEM, long_data)
251261
}
252262
)

0 commit comments

Comments
 (0)