From c6139e1cb825c8f8432291fcb612a5e50d1b0dbd Mon Sep 17 00:00:00 2001 From: Darren Colby Date: Tue, 2 Jul 2024 13:30:28 -0500 Subject: [PATCH] Shuffled data in DML, DRE, and RLearner constructors --- src/estimators.jl | 4 ++ src/metalearners.jl | 11 ++++- src/utilities.jl | 6 +-- testing.ipynb | 114 +++++++++++++++++++++++++++++++++----------- 4 files changed, 102 insertions(+), 33 deletions(-) diff --git a/src/estimators.jl b/src/estimators.jl index 76c205c..1f9b1bc 100644 --- a/src/estimators.jl +++ b/src/estimators.jl @@ -236,6 +236,10 @@ function DoubleMachineLearning( ) # Convert to arrays X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1] + + # Shuffle data with random indices + indices = shuffle(1:length(Y)) + X, T, Y = X[indices, :], T[indices], Y[indices] task = var_type(Y) isa Binary ? "classification" : "regression" diff --git a/src/metalearners.jl b/src/metalearners.jl index 2bed75a..f5eeeb5 100644 --- a/src/metalearners.jl +++ b/src/metalearners.jl @@ -296,6 +296,10 @@ function RLearner( # Convert to arrays X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1] + # Shuffle data with random indices + indices = shuffle(1:length(Y)) + X, T, Y = X[indices, :], T[indices], Y[indices] + task = var_type(Y) isa Binary ? "classification" : "regression" return RLearner( @@ -371,11 +375,14 @@ function DoublyRobustLearner( num_machines::Integer=100, num_feats::Integer=Int(round(0.75 * size(X, 2))), num_neurons::Integer=round(Int, log10(size(X, 1)) * size(X, 2)), - folds::Integer=5, ) # Convert to arrays X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1] + # Shuffle data with random indices + indices = shuffle(1:length(Y)) + X, T, Y = X[indices, :], T[indices], Y[indices] + task = var_type(Y) isa Binary ? "classification" : "regression" return DoublyRobustLearner( @@ -391,7 +398,7 @@ function DoublyRobustLearner( num_feats, num_neurons, fill(NaN, size(T, 1)), - folds, + 2, ) end diff --git a/src/utilities.jl b/src/utilities.jl index 84d34a2..3c44495 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -1,3 +1,5 @@ +using Random: shuffle + """Abstract type used to dispatch risk_ratio on nonbinary treatments""" abstract type Nonbinary end @@ -185,9 +187,7 @@ function generate_folds(X, T, Y, folds) msg = """the number of folds must be less than the number of observations""" n = length(Y) - if folds >= n - throw(ArgumentError(msg)) - end + if folds >= n throw(ArgumentError(msg))end x_folds = Array{Array{Float64, 2}}(undef, folds) t_folds = Array{Array{Float64, 1}}(undef, folds) diff --git a/testing.ipynb b/testing.ipynb index f1ab913..5da93ef 100644 --- a/testing.ipynb +++ b/testing.ipynb @@ -134,13 +134,13 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 59, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "DoubleMachineLearning([0.15384615384615385 0.1258211589371507 … 0.0 1.0; 0.6923076923076923 0.1441562898323365 … 0.0 1.0; … ; 0.41025641025641024 0.24039121482498285 … 0.0 1.0; 0.07692307692307693 0.11789145994705363 … 0.0 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [-3300.0, 61010.0, 8849.0, -6013.0, -2375.0, -11000.0, -16901.0, 1000.0, 0.0, 6400.0 … -1436.0, 4500.0, 34739.0, -750.0, 40000.0, 172.0, 836.0, 6150.0, 14499.0, -5400.0], \"ATE\", false, \"regression\", CausalELM.relu, 9915, 100, 6, 24, NaN, 5)" + "DoubleMachineLearning([0.46153846153846156 0.21734974017060496 … 0.0 1.0; 0.5897435897435898 0.05495636827139916 … 0.0 0.0; … ; 0.02564102564102564 0.11648200804000393 … 0.0 1.0; 0.6410256410256411 0.22411510932444356 … 1.0 1.0], [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0 … 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0], [18800.0, 500.0, 5600.0, 62535.0, -5100.0, 9145.0, 25999.0, 0.0, 2150.0, 5000.0 … 189000.0, 14400.0, 240.0, 249.0, -928.0, 107750.0, 0.0, 114335.0, 10500.0, 8849.0], \"ATE\", false, \"regression\", CausalELM.swish, 9915, 100, 6, 24, NaN, 5)" ] }, "metadata": {}, @@ -148,18 +148,18 @@ } ], "source": [ - "dml = DoubleMachineLearning(covariates, treatment, outcome)" + "dml = DoubleMachineLearning(covariates, treatment, outcome, activation=swish)" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 60, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "RLearner([0.15384615384615385 0.1258211589371507 … 0.0 1.0; 0.6923076923076923 0.1441562898323365 … 0.0 1.0; … ; 0.41025641025641024 0.24039121482498285 … 0.0 1.0; 0.07692307692307693 0.11789145994705363 … 0.0 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [-3300.0, 61010.0, 8849.0, -6013.0, -2375.0, -11000.0, -16901.0, 1000.0, 0.0, 6400.0 … -1436.0, 4500.0, 34739.0, -750.0, 40000.0, 172.0, 836.0, 6150.0, 14499.0, -5400.0], \"CATE\", false, \"regression\", CausalELM.relu, 9915, 100, 6, 32, [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN … NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN], 5)" + "RLearner([0.1282051282051282 0.11108932248259633 … 0.0 1.0; 0.6923076923076923 0.1186881066771252 … 0.0 1.0; … ; 0.20512820512820512 0.07500735366212374 … 0.0 1.0; 0.41025641025641024 0.11607755662319835 … 0.0 1.0], [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0 … 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [-8900.0, -4800.0, 27500.0, -1650.0, -2000.0, 30740.0, 2859.0, -2150.0, 0.0, 11599.0 … 43599.0, -7200.0, 23309.0, 8774.0, 6500.0, -400.0, 22700.0, 7399.0, -5400.0, 1499.0], \"CATE\", false, \"regression\", CausalELM.swish, 9915, 100, 6, 32, [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN … NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN], 5)" ] }, "metadata": {}, @@ -167,18 +167,37 @@ } ], "source": [ - "r_learner = RLearner(covariates, treatment, outcome)" + "r_learner = RLearner(covariates, treatment, outcome, activation=swish)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 61, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "8823.500636214852" + "DoublyRobustLearner([0.6410256410256411 0.1558486126090793 … 0.0 1.0; 0.23076923076923078 0.06633003235611334 … 0.0 0.0; … ; 0.6153846153846154 0.06843808216491813 … 0.0 0.0; 0.7435897435897436 0.2292994411216786 … 0.0 1.0], [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 … 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0], [100.0, 0.0, 14350.0, 4600.0, 84248.0, -1800.0, 1020.0, 2280.0, 14699.0, 881.0 … 367.0, -5600.0, -5400.0, 5674.0, 12211.0, 32500.0, 1152.0, 2182.0, 0.0, 330.0], \"CATE\", false, \"regression\", CausalELM.swish, 9915, 100, 6, 32, [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN … NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN], 2)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "dre = DoublyRobustLearner(covariates, treatment, outcome, activation=swish)" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "8804.269472283213" ] }, "metadata": {}, @@ -191,33 +210,33 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 63, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "9915-element Vector{Float64}:\n", - " 4085.5839404080925\n", - " 15773.51315113084\n", - " 38901.80802040522\n", - " 3825.3848781869037\n", - " 11964.765726429632\n", - " 26765.991729444253\n", - " 16975.200557225948\n", - " 7452.263104809677\n", - " 1115.323329175054\n", - " 12363.569530065344\n", + " 1033.275328379404\n", + " 3897.6188907530145\n", + " 27094.516749605616\n", + " 8327.283149032586\n", + " 6781.702531736929\n", + " 50200.72898282418\n", + " 618.590315821573\n", + " 6647.26749174192\n", + " 4325.783318029439\n", + " 16617.629336705013\n", " ⋮\n", - " 11433.6140084084\n", - " 4800.764220118784\n", - " 2925.4867379282705\n", - " 39714.813007228164\n", - " 1647.2470272172372\n", - " 10061.73821939839\n", - " 14687.816324667367\n", - " 17992.791169984106\n", - " 434.7500362608628" + " 25103.616301146572\n", + " 40417.24987461999\n", + " 6976.012498684692\n", + " 8869.662932387795\n", + " -1030.3323016387612\n", + " 4912.327776140574\n", + " 2840.9932292653525\n", + " 3323.126233753097\n", + " 21356.54170795394" ] }, "metadata": {}, @@ -228,6 +247,45 @@ "estimate_causal_effect!(r_learner)" ] }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "9915-element Vector{Float64}:\n", + " 8651.259686332123\n", + " 2763.6426805062965\n", + " 4281.08620983512\n", + " 6996.106017505121\n", + " 37295.1224689869\n", + " 3425.2628336886887\n", + " 7259.653364085303\n", + " 3931.840707261489\n", + " 3390.6489181977217\n", + " 396.19186564028234\n", + " ⋮\n", + " 13778.740930336877\n", + " 13824.272936865971\n", + " 770.8718719469387\n", + " 5661.227928432385\n", + " 10218.778717409776\n", + " 3707.70741363045\n", + " 2089.690748271022\n", + " 3767.843767168565\n", + " 17841.535784697724" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "estimate_causal_effect!(dre)" + ] + }, { "cell_type": "code", "execution_count": 9,