Skip to content

Commit

Permalink
Shuffled data in DML, DRE, and RLearner constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
dscolby committed Jul 2, 2024
1 parent c97ea4e commit c6139e1
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 33 deletions.
4 changes: 4 additions & 0 deletions src/estimators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,10 @@ function DoubleMachineLearning(
)
# Convert to arrays
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]

# Shuffle data with random indices
indices = shuffle(1:length(Y))
X, T, Y = X[indices, :], T[indices], Y[indices]

task = var_type(Y) isa Binary ? "classification" : "regression"

Expand Down
11 changes: 9 additions & 2 deletions src/metalearners.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,10 @@ function RLearner(
# Convert to arrays
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]

# Shuffle data with random indices
indices = shuffle(1:length(Y))
X, T, Y = X[indices, :], T[indices], Y[indices]

task = var_type(Y) isa Binary ? "classification" : "regression"

return RLearner(
Expand Down Expand Up @@ -371,11 +375,14 @@ function DoublyRobustLearner(
num_machines::Integer=100,
num_feats::Integer=Int(round(0.75 * size(X, 2))),
num_neurons::Integer=round(Int, log10(size(X, 1)) * size(X, 2)),
folds::Integer=5,
)
# Convert to arrays
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]

# Shuffle data with random indices
indices = shuffle(1:length(Y))
X, T, Y = X[indices, :], T[indices], Y[indices]

task = var_type(Y) isa Binary ? "classification" : "regression"

return DoublyRobustLearner(
Expand All @@ -391,7 +398,7 @@ function DoublyRobustLearner(
num_feats,
num_neurons,
fill(NaN, size(T, 1)),
folds,
2,
)
end

Expand Down
6 changes: 3 additions & 3 deletions src/utilities.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using Random: shuffle

"""
    Nonbinary

Abstract marker type that lets `risk_ratio` dispatch on treatments that are not binary.
"""
abstract type Nonbinary end

Expand Down Expand Up @@ -185,9 +187,7 @@ function generate_folds(X, T, Y, folds)
msg = """the number of folds must be less than the number of observations"""
n = length(Y)

if folds >= n
throw(ArgumentError(msg))
end
if folds >= n throw(ArgumentError(msg))end

x_folds = Array{Array{Float64, 2}}(undef, folds)
t_folds = Array{Array{Float64, 1}}(undef, folds)
Expand Down
114 changes: 86 additions & 28 deletions testing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -134,51 +134,70 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 59,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DoubleMachineLearning([0.15384615384615385 0.1258211589371507 … 0.0 1.0; 0.6923076923076923 0.1441562898323365 … 0.0 1.0; … ; 0.41025641025641024 0.24039121482498285 … 0.0 1.0; 0.07692307692307693 0.11789145994705363 … 0.0 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [-3300.0, 61010.0, 8849.0, -6013.0, -2375.0, -11000.0, -16901.0, 1000.0, 0.0, 6400.0 … -1436.0, 4500.0, 34739.0, -750.0, 40000.0, 172.0, 836.0, 6150.0, 14499.0, -5400.0], \"ATE\", false, \"regression\", CausalELM.relu, 9915, 100, 6, 24, NaN, 5)"
"DoubleMachineLearning([0.46153846153846156 0.21734974017060496 … 0.0 1.0; 0.5897435897435898 0.05495636827139916 … 0.0 0.0; … ; 0.02564102564102564 0.11648200804000393 … 0.0 1.0; 0.6410256410256411 0.22411510932444356 … 1.0 1.0], [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0 … 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0], [18800.0, 500.0, 5600.0, 62535.0, -5100.0, 9145.0, 25999.0, 0.0, 2150.0, 5000.0 … 189000.0, 14400.0, 240.0, 249.0, -928.0, 107750.0, 0.0, 114335.0, 10500.0, 8849.0], \"ATE\", false, \"regression\", CausalELM.swish, 9915, 100, 6, 24, NaN, 5)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dml = DoubleMachineLearning(covariates, treatment, outcome)"
"dml = DoubleMachineLearning(covariates, treatment, outcome, activation=swish)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 60,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"RLearner([0.15384615384615385 0.1258211589371507 … 0.0 1.0; 0.6923076923076923 0.1441562898323365 … 0.0 1.0; … ; 0.41025641025641024 0.24039121482498285 … 0.0 1.0; 0.07692307692307693 0.11789145994705363 … 0.0 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [-3300.0, 61010.0, 8849.0, -6013.0, -2375.0, -11000.0, -16901.0, 1000.0, 0.0, 6400.0 … -1436.0, 4500.0, 34739.0, -750.0, 40000.0, 172.0, 836.0, 6150.0, 14499.0, -5400.0], \"CATE\", false, \"regression\", CausalELM.relu, 9915, 100, 6, 32, [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN … NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN], 5)"
"RLearner([0.1282051282051282 0.11108932248259633 … 0.0 1.0; 0.6923076923076923 0.1186881066771252 … 0.0 1.0; … ; 0.20512820512820512 0.07500735366212374 … 0.0 1.0; 0.41025641025641024 0.11607755662319835 … 0.0 1.0], [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0 … 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [-8900.0, -4800.0, 27500.0, -1650.0, -2000.0, 30740.0, 2859.0, -2150.0, 0.0, 11599.0 … 43599.0, -7200.0, 23309.0, 8774.0, 6500.0, -400.0, 22700.0, 7399.0, -5400.0, 1499.0], \"CATE\", false, \"regression\", CausalELM.swish, 9915, 100, 6, 32, [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN … NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN], 5)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"r_learner = RLearner(covariates, treatment, outcome)"
"r_learner = RLearner(covariates, treatment, outcome, activation=swish)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"8823.500636214852"
"DoublyRobustLearner([0.6410256410256411 0.1558486126090793 … 0.0 1.0; 0.23076923076923078 0.06633003235611334 … 0.0 0.0; … ; 0.6153846153846154 0.06843808216491813 … 0.0 0.0; 0.7435897435897436 0.2292994411216786 … 0.0 1.0], [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 … 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0], [100.0, 0.0, 14350.0, 4600.0, 84248.0, -1800.0, 1020.0, 2280.0, 14699.0, 881.0 … 367.0, -5600.0, -5400.0, 5674.0, 12211.0, 32500.0, 1152.0, 2182.0, 0.0, 330.0], \"CATE\", false, \"regression\", CausalELM.swish, 9915, 100, 6, 32, [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN … NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN], 2)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dre = DoublyRobustLearner(covariates, treatment, outcome, activation=swish)"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"8804.269472283213"
]
},
"metadata": {},
Expand All @@ -191,33 +210,33 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 63,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"9915-element Vector{Float64}:\n",
" 4085.5839404080925\n",
" 15773.51315113084\n",
" 38901.80802040522\n",
" 3825.3848781869037\n",
" 11964.765726429632\n",
" 26765.991729444253\n",
" 16975.200557225948\n",
" 7452.263104809677\n",
" 1115.323329175054\n",
" 12363.569530065344\n",
" 1033.275328379404\n",
" 3897.6188907530145\n",
" 27094.516749605616\n",
" 8327.283149032586\n",
" 6781.702531736929\n",
" 50200.72898282418\n",
" 618.590315821573\n",
" 6647.26749174192\n",
" 4325.783318029439\n",
" 16617.629336705013\n",
"\n",
" 11433.6140084084\n",
" 4800.764220118784\n",
" 2925.4867379282705\n",
" 39714.813007228164\n",
" 1647.2470272172372\n",
" 10061.73821939839\n",
" 14687.816324667367\n",
" 17992.791169984106\n",
" 434.7500362608628"
" 25103.616301146572\n",
" 40417.24987461999\n",
" 6976.012498684692\n",
" 8869.662932387795\n",
" -1030.3323016387612\n",
" 4912.327776140574\n",
" 2840.9932292653525\n",
" 3323.126233753097\n",
" 21356.54170795394"
]
},
"metadata": {},
Expand All @@ -228,6 +247,45 @@
"estimate_causal_effect!(r_learner)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"9915-element Vector{Float64}:\n",
" 8651.259686332123\n",
" 2763.6426805062965\n",
" 4281.08620983512\n",
" 6996.106017505121\n",
" 37295.1224689869\n",
" 3425.2628336886887\n",
" 7259.653364085303\n",
" 3931.840707261489\n",
" 3390.6489181977217\n",
" 396.19186564028234\n",
"\n",
" 13778.740930336877\n",
" 13824.272936865971\n",
" 770.8718719469387\n",
" 5661.227928432385\n",
" 10218.778717409776\n",
" 3707.70741363045\n",
" 2089.690748271022\n",
" 3767.843767168565\n",
" 17841.535784697724"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"estimate_causal_effect!(dre)"
]
},
{
"cell_type": "code",
"execution_count": 9,
Expand Down

0 comments on commit c6139e1

Please sign in to comment.