Skip to content

Commit

Permalink
Added W argument and better explanations
Browse files Browse the repository at this point in the history
  • Loading branch information
dscolby committed May 4, 2024
1 parent 73b8494 commit b84bb08
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 81 deletions.
11 changes: 8 additions & 3 deletions docs/src/guide/doublemachinelearning.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ models can take on any functional form but the final stage model is linear.
## Step 1: Initialize a Model
The DoubleMachineLearning constructor takes at least three arguments, an array of
covariates, a treatment vector, and an outcome vector. This estimator supports binary, count,
or continuous treatments and binary, count, continuous, or time to event outcomes.
or continuous treatments and binary, count, continuous, or time to event outcomes. You can
also specify confounders that you do not want to estimate the CATE for by passing a parameter
to the W argument. Otherwise, the model assumes all possible confounders are contained in X.

!!! note
Internally, the outcome and treatment models are treated as a regression since extreme
Expand All @@ -39,14 +41,17 @@ or continuous treatments and binary, count, continuous, or time to event outcome
and approximator\_neurons.
```julia
# Create some data with a binary treatment
X, T, Y = rand(100, 5), [rand()<0.4 for i in 1:100], rand(100)
X, T, Y, W = rand(100, 5), [rand()<0.4 for i in 1:100], rand(100), rand(100, 4)

# We could also use DataFrames
# using DataFrames
# X = DataFrame(x1=rand(100), x2=rand(100), x3=rand(100), x4=rand(100), x5=rand(100))
# T, Y = DataFrame(t=[rand()<0.4 for i in 1:100]), DataFrame(y=rand(100))
# W = DataFrame(w1=rand(100), w2=rand(100), w3=rand(100), w4=rand(100))

dml = DoubleMachineLearning(X, T, Y)
# W is optional and means there are confounders that you are not interested in estimating
# the CATE for
dml = DoubleMachineLearning(X, T, Y, W=W)
```

## Step 2: Estimate the Causal Effect
Expand Down
14 changes: 9 additions & 5 deletions docs/src/guide/doublyrobust.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,20 @@ binary, treatments and binary, count, continuous, or time to event outcomes.

!!! tip
Additional options can be specified for each type of metalearner using its keyword arguments.

```julia
# Generate data to use
X, T, Y = rand(1000, 5), [rand()<0.4 for i in 1:1000], rand(1000)
# Create some data with a binary treatment
X, T, Y, W = rand(100, 5), [rand()<0.4 for i in 1:100], rand(100), rand(100, 4)

# We could also use DataFrames
# using DataFrames
# X = DataFrame(x1=rand(1000), x2=rand(1000), x3=rand(1000), x4=rand(1000), x5=rand(1000))
# T, Y = DataFrame(t=[rand()<0.4 for i in 1:1000]), DataFrame(y=rand(1000))
# X = DataFrame(x1=rand(100), x2=rand(100), x3=rand(100), x4=rand(100), x5=rand(100))
# T, Y = DataFrame(t=[rand()<0.4 for i in 1:100]), DataFrame(y=rand(100))
# W = DataFrame(w1=rand(100), w2=rand(100), w3=rand(100), w4=rand(100))

dr_learner = DoublyRobustLearner(X, T, Y)
# W is optional and means there are confounders that you are not interested in estimating
# the CATE for
dr_learner = DoublyRobustLearner(X, T, Y, W=W)
```

# Estimate the CATE
Expand Down
4 changes: 3 additions & 1 deletion docs/src/release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ These release notes adhere to the [keep a changelog](https://keepachangelog.com/
### Added
* Implemented doubly robust learner for CATE estimation [#31](https://github.com/dscolby/CausalELM.jl/issues/31)
* Provided better explanations of supported treatment and outcome variable types in the docs [#41](https://github.com/dscolby/CausalELM.jl/issues/41)
* Added support for specifying confounders, W, separate from covariates of interest, X, for double machine
learning and doubly robust estimation [39](https://github.com/dscolby/CausalELM.jl/issues/39)
### Changed
* Removed the estimate_causal_effect! call in the model constructor docstrings [#35](https://github.com/dscolby/CausalELM.jl/issues/35)
### Fixed
* Clipped probabilities between 0 and 1 for estimators that use predictions of binary variables [#36](https://github.com/dscolby/CausalELM.jl/issues/36)
* Clipped probabilities between 0 and 1 for estimators that use predictions of binary variables [#36](https://github.com/dscolby/CausalELM.jl/issues/36)

## Version [v0.5.1](https://github.com/dscolby/CausalELM.jl/releases/tag/v0.5.1) - 2024-01-15
### Added
Expand Down
78 changes: 59 additions & 19 deletions src/estimators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,10 @@ For more information see:
...
# Arguments
- `X::Any`: an array or DataFrame of covariates to model the outcome.
- `X::Any`: an array or DataFrame of covariates of interest.
- `T::Any`: an vector or DataFrame of treatment statuses.
- `Y::Any`: an array or DataFrame of outcomes.
- `W::Any`: an array or dataframe of all possible confounders.
- `task::String`: either regression or classification.
- `quantity_of_interest::String`: ATE for average treatment effect or CTE for cummulative
treatment effect.
Expand Down Expand Up @@ -235,6 +236,7 @@ mutable struct DoubleMachineLearning <: CausalEstimator
X::Array{Float64}
T::Array{Float64}
Y::Array{Float64}
W::Array{Float64}
regularized::Bool
activation::Function
validation_metric::Function
Expand All @@ -249,25 +251,25 @@ mutable struct DoubleMachineLearning <: CausalEstimator
causal_effect::Float64

function DoubleMachineLearning(X::Array{<:Real}, T::Array{<:Real}, Y::Array{<:Real};
regularized=true, activation=relu, validation_metric=mse,
min_neurons=1, max_neurons=100, folds=5,
iterations=round(size(X, 1)/10),
W=X, regularized=true, activation=relu,
validation_metric=mse, min_neurons=1, max_neurons=100,
folds=5, iterations=round(size(X, 1)/10),
approximator_neurons=round(size(X, 1)/10))

new(Float64.(X), Float64.(T), Float64.(Y), regularized, activation,
new(Float64.(X), Float64.(T), Float64.(Y), Float64.(W), regularized, activation,
validation_metric, min_neurons, max_neurons, folds, iterations,
approximator_neurons, "ATE", false, 0, NaN)
end
end

function DoubleMachineLearning(X, T, Y; regularized=true, activation=relu,
function DoubleMachineLearning(X, T, Y; W=X, regularized=true, activation=relu,
validation_metric=mse, min_neurons=1, max_neurons=100, folds=5,
iterations=round(size(X, 1)/10), approximator_neurons=round(size(X, 1)/10))

# Convert to arrays
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]
X, T, Y, W = Matrix{Float64}(X), T[:, 1], Y[:, 1], Matrix{Float64}(W)

DoubleMachineLearning(X, T, Y; regularized=regularized, activation=activation,
DoubleMachineLearning(X, T, Y; W=W, regularized=regularized, activation=activation,
validation_metric=validation_metric, min_neurons=min_neurons,
max_neurons=max_neurons, folds=folds, iterations=iterations,
approximator_neurons=approximator_neurons)
Expand Down Expand Up @@ -370,6 +372,10 @@ julia> X, T, Y = rand(100, 5), [rand()<0.4 for i in 1:100], rand(100)
julia> m1 = DoubleMachineLearning(X, T, Y)
julia> estimate_causal_effect!(m1)
0.31067439
julia> W = rand(100, 6)
julia> m2 = DoubleMachineLearning(X, T, Y, W=W)
julia> estimate_causal_effect!(m2)
0.7628583414839659
```
"""
function estimate_causal_effect!(DML::DoubleMachineLearning)
Expand Down Expand Up @@ -410,8 +416,7 @@ julia> estimate_effect!(m1)
```
"""
function estimate_effect!(DML::DoubleMachineLearning, cate=false)
X_T, Y = generate_folds(reduce(hcat, (DML.X, DML.T)), DML.Y, DML.folds)
X, T = [fl[:, 1:size(DML.X, 2)] for fl in X_T], [fl[:, size(DML.X, 2)+1] for fl in X_T]
X, T, W, Y = make_folds(DML)
predictors = cate ? Vector{RegularizedExtremeLearner}(undef, DML.folds) : Nothing
DML.causal_effect = 0

Expand All @@ -420,9 +425,11 @@ function estimate_effect!(DML::DoubleMachineLearning, cate=false)
X_train, X_test = reduce(vcat, X[1:end .!== fld]), X[fld]
Y_train, Y_test = reduce(vcat, Y[1:end .!== fld]), Y[fld]
T_train, T_test = reduce(vcat, T[1:end .!== fld]), T[fld]
W_train, W_test = reduce(vcat, W[1:end .!== fld]), W[fld]

Ỹ, T̃ = predict_residuals(DML, X_train, X_test, Y_train, Y_test, T_train, T_test)
DML.causal_effect += (reduce(hcat, (T̃, ones(length(T̃))))\Ỹ)[1]
Ỹ, T̃ = predict_residuals(DML, X_train, X_test, Y_train, Y_test, T_train, T_test,
W_train, W_test)
DML.causal_effect += (vec(sum(T̃ .* X_test, dims=2))\Ỹ)[1]

if cate # Using the weight trick to get the non-parametric CATE for an R-learner
X[fld], Y[fld] = (T̃.^2) .* X_test, (T̃.^2) .* (Ỹ./T̃)
Expand Down Expand Up @@ -461,23 +468,56 @@ julia> predict_residuals(m1, x_train, x_test, y_train, y_test, t_train, t_test)
```
"""
function predict_residuals(DML::DoubleMachineLearning, x_train, x_test, y_train, y_test,
t_train, t_test)
t_train, t_test, w_train, w_test)
V = x_train != w_train && x_test != w_test ? reduce(hcat, (x_train, w_train)) : x_train
V_test = V == x_train ? x_test : reduce(hcat, (x_test, w_test))

if DML.regularized
y = RegularizedExtremeLearner(x_train, y_train, DML.num_neurons, DML.activation)
t = RegularizedExtremeLearner(x_train, t_train, DML.num_neurons, DML.activation)
y = RegularizedExtremeLearner(V, y_train, DML.num_neurons, DML.activation)
t = RegularizedExtremeLearner(V, t_train, DML.num_neurons, DML.activation)
else
y = ExtremeLearner(x_train, y_train, DML.num_neurons, DML.activation)
t = ExtremeLearner(x_train, t_train, DML.num_neurons, DML.activation)
y = ExtremeLearner(V, y_train, DML.num_neurons, DML.activation)
t = ExtremeLearner(V, t_train, DML.num_neurons, DML.activation)
end

fit!(y); fit!(t)
y_pred = clip_if_binary(predict(y, x_test), var_type(DML.Y))
t_pred = clip_if_binary(predict(t, x_test), var_type(DML.T))
y_pred = clip_if_binary(predict(y, V_test), var_type(DML.Y))
t_pred = clip_if_binary(predict(t, V_test), var_type(DML.T))
ỹ, t̃ = y_test - y_pred, t_test - t_pred

return ỹ, t̃
end

"""
make_folds(DML)
Make folds for cross fitting for a double machine learning estimator.
This method should not be called directly.
...
# Arguments
- `DML::DoubleMachineLearning`: the DoubleMachineLearning struct to estimate the effect for.
...
Examples
```julia
julia> X, T, Y = rand(100, 5), [rand()<0.4 for i in 1:100], rand(100)
julia> m1 = DoubleMachineLearning(X, T, Y)
julia> make_folds(m1)
([[[0.8737507878554287 0.7398090999242162 … 0.45708199254415094 0.6850379444957528;
… 0.08313470408726942 0.365598632217206]]])
```
"""
function make_folds(D)
X_T_W, Y = generate_folds(reduce(hcat, (D.X, D.T, D.W)), D.Y, D.folds)
X = [fl[:, 1:size(D.X, 2)] for fl in X_T_W]
T = [fl[:, size(D.X, 2)+1] for fl in X_T_W]
W = [fl[:, size(D.X, 2)+2:end] for fl in X_T_W]

return X, T, W, Y
end

"""
moving_average(x)
Expand Down
Loading

0 comments on commit b84bb08

Please sign in to comment.