From 84008b54b0249674e0bb95c28a5218a5a5f3c759 Mon Sep 17 00:00:00 2001 From: Darren Colby Date: Tue, 2 Jul 2024 19:09:59 -0500 Subject: [PATCH] Made swish the default activation function --- docs/src/release_notes.md | 1 + src/estimators.jl | 14 ++-- src/metalearners.jl | 20 ++--- testing.ipynb | 171 ++++++++++++++++++++++++-------------- 4 files changed, 128 insertions(+), 78 deletions(-) diff --git a/docs/src/release_notes.md b/docs/src/release_notes.md index 3197ca9..bd8f9cb 100644 --- a/docs/src/release_notes.md +++ b/docs/src/release_notes.md @@ -9,6 +9,7 @@ These release notes adhere to the [keep a changelog](https://keepachangelog.com/ * Calculate probabilities as the average label predicted by the ensemble instead of clipping [#71](https://github.com/dscolby/CausalELM.jl/issues/71) * Made calculation of p-values and standard errors optional and not executed by default in summarize methods [#65](https://github.com/dscolby/CausalELM.jl/issues/65) * Removed redundant W argument for double machine learning, R-learning, and doubly robust estimation [#68](https://github.com/dscolby/CausalELM.jl/issues/68) +* Use swish as the default activation function [#72](https://github.com/dscolby/CausalELM.jl/issues/72) ### Fixed * Applying the weight trick for R-learning [#70](https://github.com/dscolby/CausalELM.jl/issues/70) diff --git a/src/estimators.jl b/src/estimators.jl index 1f9b1bc..22750a1 100644 --- a/src/estimators.jl +++ b/src/estimators.jl @@ -13,7 +13,7 @@ Initialize an interrupted time series estimator. - `Y₁::Any`: array or DataFrame of outcomes from the post-treatment period. # Keywords -- `activation::Function=relu`: activation function to use. +- `activation::Function=swish`: activation function to use. - `sample_size::Integer=size(X₀, 1)`: number of bootstrapped samples for the extreme learner. - `num_machines::Integer=100`: number of extreme learning machines for the ensemble. @@ -56,7 +56,7 @@ function InterruptedTimeSeries( Y₀, X₁, Y₁; - activation::Function=relu, + activation::Function=swish, sample_size::Integer=size(X₀, 1), num_machines::Integer=100, num_feats::Integer=Int(round(0.75 * size(X₀, 2))), @@ -102,7 +102,7 @@ Initialize a G-Computation estimator. # Keywords - `quantity_of_interest::String`: ATE for average treatment effect or ATT for average treatment effect on the treated. -- `activation::Function=relu`: activation function to use. +- `activation::Function=swish`: activation function to use. - `sample_size::Integer=size(X, 1)`: number of bootstrapped samples for the extreme learners. - `num_machines::Integer=100`: number of extreme learning machines for the ensemble. @@ -144,7 +144,7 @@ mutable struct GComputation <: CausalEstimator T, Y; quantity_of_interest::String="ATE", - activation::Function=relu, + activation::Function=swish, sample_size::Integer=size(X, 1), num_machines::Integer=100, num_feats::Integer=Int(round(0.75 * size(X, 2))), @@ -188,7 +188,7 @@ Initialize a double machine learning estimator with cross fitting. - `Y::Any`: array or DataFrame of outcomes. # Keywords -- `activation::Function=relu`: activation function to use. +- `activation::Function=swish`: activation function to use. - `sample_size::Integer=size(X, 1)`: number of bootstrapped samples for teh extreme learners. - `num_machines::Integer=100`: number of extreme learning machines for the ensemble. @@ -227,7 +227,7 @@ function DoubleMachineLearning( X, T, Y; - activation::Function=relu, + activation::Function=swish, sample_size::Integer=size(X, 1), num_machines::Integer=100, num_feats::Integer=Int(round(0.75 * size(X, 2))), @@ -236,7 +236,7 @@ function DoubleMachineLearning( ) # Convert to arrays X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1] - + # Shuffle data with random indices indices = shuffle(1:length(Y)) X, T, Y = X[indices, :], T[indices], Y[indices] diff --git a/src/metalearners.jl b/src/metalearners.jl index f5eeeb5..7de65bc 100644 --- a/src/metalearners.jl +++ b/src/metalearners.jl @@ -12,7 +12,7 @@ Initialize a S-Learner. - `Y::Any`: an array or DataFrame of outcomes. # Keywords -- `activation::Function=relu`: the activation function to use. +- `activation::Function=swish`: the activation function to use. - `sample_size::Integer=size(X, 1)`: number of bootstrapped samples for eth extreme learners. - `num_machines::Integer=100`: number of extreme learning machines for the ensemble. @@ -51,7 +51,7 @@ mutable struct SLearner <: Metalearner X, T, Y; - activation::Function=relu, + activation::Function=swish, sample_size::Integer=size(X, 1), num_machines::Integer=100, num_feats::Integer=Int(round(0.75 * size(X, 2))), @@ -91,7 +91,7 @@ Initialize a T-Learner. - `Y::Any`: an array or DataFrame of outcomes. # Keywords -- `activation::Function=relu`: the activation function to use. +- `activation::Function=swish`: the activation function to use. - `sample_size::Integer=size(X, 1)`: number of bootstrapped samples for eth extreme learners. - `num_machines::Integer=100`: number of extreme learning machines for the ensemble. @@ -130,7 +130,7 @@ mutable struct TLearner <: Metalearner X, T, Y; - activation::Function=relu, + activation::Function=swish, sample_size::Integer=size(X, 1), num_machines::Integer=100, num_feats::Integer=Int(round(0.75 * size(X, 2))), @@ -169,7 +169,7 @@ Initialize an X-Learner. - `Y::Any`: an array or DataFrame of outcomes. # Keywords -- `activation::Function=relu`: the activation function to use. +- `activation::Function=swish`: the activation function to use. - `sample_size::Integer=size(X, 1)`: number of bootstrapped samples for eth extreme learners. - `num_machines::Integer=100`: number of extreme learning machines for the ensemble. @@ -209,7 +209,7 @@ mutable struct XLearner <: Metalearner X, T, Y; - activation::Function=relu, + activation::Function=swish, sample_size::Integer=size(X, 1), num_machines::Integer=100, num_feats::Integer=Int(round(0.75 * size(X, 2))), @@ -248,7 +248,7 @@ Initialize an R-Learner. - `Y::Any`: an array or DataFrame of outcomes. # Keywords -- `activation::Function=relu`: the activation function to use. +- `activation::Function=swish`: the activation function to use. - `sample_size::Integer=size(X, 1)`: number of bootstrapped samples for eth extreme learners. - `num_machines::Integer=100`: number of extreme learning machines for the ensemble. @@ -285,7 +285,7 @@ function RLearner( X, T, Y; - activation::Function=relu, + activation::Function=swish, sample_size::Integer=size(X, 1), num_machines::Integer=100, num_feats::Integer=Int(round(0.75 * size(X, 2))), @@ -330,7 +330,7 @@ Initialize a doubly robust CATE estimator. - `Y::Any`: an array or DataFrame of outcomes. # Keywords -- `activation::Function=relu`: the activation function to use. +- `activation::Function=swish`: the activation function to use. - `sample_size::Integer=size(X, 1)`: number of bootstrapped samples for eth extreme learners. - `num_machines::Integer=100`: number of extreme learning machines for the ensemble. @@ -370,7 +370,7 @@ function DoublyRobustLearner( X, T, Y; - activation::Function=relu, + activation::Function=swish, sample_size::Integer=size(X, 1), num_machines::Integer=100, num_feats::Integer=Int(round(0.75 * size(X, 2))), diff --git a/testing.ipynb b/testing.ipynb index 5da93ef..4adbfbd 100644 --- a/testing.ipynb +++ b/testing.ipynb @@ -134,13 +134,13 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "DoubleMachineLearning([0.46153846153846156 0.21734974017060496 … 0.0 1.0; 0.5897435897435898 0.05495636827139916 … 0.0 0.0; … ; 0.02564102564102564 0.11648200804000393 … 0.0 1.0; 0.6410256410256411 0.22411510932444356 … 1.0 1.0], [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0 … 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0], [18800.0, 500.0, 5600.0, 62535.0, -5100.0, 9145.0, 25999.0, 0.0, 2150.0, 5000.0 … 189000.0, 14400.0, 240.0, 249.0, -928.0, 107750.0, 0.0, 114335.0, 10500.0, 8849.0], \"ATE\", false, \"regression\", CausalELM.swish, 9915, 100, 6, 24, NaN, 5)" + "DoubleMachineLearning([0.46153846153846156 0.33966565349544076 … 1.0 1.0; 0.10256410256410256 0.08505735856456516 … 0.0 0.0; … ; 0.6923076923076923 0.042308069418570446 … 0.0 0.0; 0.10256410256410256 0.17147514462202176 … 1.0 1.0], [1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0 … 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], [14248.0, 2300.0, 0.0, 33748.0, -800.0, 25398.0, -1200.0, 120000.0, 15300.0, 100.0 … 60201.0, 51987.0, 9249.0, 6420.0, 3200.0, 99300.0, 19599.0, 8030.0, 4190.0, 8400.0], \"ATE\", false, \"regression\", CausalELM.swish, 9915, 50, 6, 24, NaN, 5)" ] }, "metadata": {}, @@ -148,18 +148,18 @@ } ], "source": [ - "dml = DoubleMachineLearning(covariates, treatment, outcome, activation=swish)" + "dml = DoubleMachineLearning(covariates, treatment, outcome, num_machines=50)" ] }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "RLearner([0.1282051282051282 0.11108932248259633 … 0.0 1.0; 0.6923076923076923 0.1186881066771252 … 0.0 1.0; … ; 0.20512820512820512 0.07500735366212374 … 0.0 1.0; 0.41025641025641024 0.11607755662319835 … 0.0 1.0], [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0 … 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [-8900.0, -4800.0, 27500.0, -1650.0, -2000.0, 30740.0, 2859.0, -2150.0, 0.0, 11599.0 … 43599.0, -7200.0, 23309.0, 8774.0, 6500.0, -400.0, 22700.0, 7399.0, -5400.0, 1499.0], \"CATE\", false, \"regression\", CausalELM.swish, 9915, 100, 6, 32, [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN … NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN], 5)" + "RLearner([0.3333333333333333 0.09809785273065987 … 1.0 0.0; 0.1282051282051282 0.08584174919109716 … 0.0 1.0; … ; 0.48717948717948717 0.6506030002941465 … 1.0 1.0; 0.5128205128205128 0.07530150014707324 … 0.0 1.0], [1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0 … 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0], [11000.0, 50.0, 157973.0, 100.0, -3700.0, 26000.0, 44.0, 32000.0, -6705.0, 10500.0 … 999.0, -18000.0, 46099.0, 920.0, -19950.0, 300.0, 11750.0, 182500.0, 47000.0, 499.0], \"CATE\", false, \"regression\", CausalELM.swish, 9915, 100, 6, 32, [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN … NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN], 5)" ] }, "metadata": {}, @@ -167,18 +167,18 @@ } ], "source": [ - "r_learner = RLearner(covariates, treatment, outcome, activation=swish)" + "r_learner = RLearner(covariates, treatment, outcome)" ] }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "DoublyRobustLearner([0.6410256410256411 0.1558486126090793 … 0.0 1.0; 0.23076923076923078 0.06633003235611334 … 0.0 0.0; … ; 0.6153846153846154 0.06843808216491813 … 0.0 0.0; 0.7435897435897436 0.2292994411216786 … 0.0 1.0], [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 … 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0], [100.0, 0.0, 14350.0, 4600.0, 84248.0, -1800.0, 1020.0, 2280.0, 14699.0, 881.0 … 367.0, -5600.0, -5400.0, 5674.0, 12211.0, 32500.0, 1152.0, 2182.0, 0.0, 330.0], \"CATE\", false, \"regression\", CausalELM.swish, 9915, 100, 6, 32, [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN … NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN], 2)" + "DoublyRobustLearner([0.8974358974358975 0.2030100990293166 … 0.0 1.0; 0.5897435897435898 0.2634326894793607 … 0.0 1.0; … ; 0.0 0.19087655652514954 … 0.0 0.0; 0.3333333333333333 0.32516668300813806 … 1.0 0.0], [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], [41649.0, 9000.0, 0.0, 16350.0, 6000.0, 700.0, 13059.0, 5930.0, 23397.0, 1323.0 … 24500.0, 8050.0, -11000.0, 35499.0, -2854.0, 197590.0, -1400.0, 7700.0, 12000.0, 42050.0], \"CATE\", false, \"regression\", CausalELM.swish, 9915, 100, 6, 32, [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN … NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN], 2)" ] }, "metadata": {}, @@ -186,18 +186,18 @@ } ], "source": [ - "dre = DoublyRobustLearner(covariates, treatment, outcome, activation=swish)" + "dre = DoublyRobustLearner(covariates, treatment, outcome)" ] }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "8804.269472283213" + "8868.651114858334" ] }, "metadata": {}, @@ -210,33 +210,33 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "9915-element Vector{Float64}:\n", - " 1033.275328379404\n", - " 3897.6188907530145\n", - " 27094.516749605616\n", - " 8327.283149032586\n", - " 6781.702531736929\n", - " 50200.72898282418\n", - " 618.590315821573\n", - " 6647.26749174192\n", - " 4325.783318029439\n", - " 16617.629336705013\n", - " ⋮\n", - " 25103.616301146572\n", - " 40417.24987461999\n", - " 6976.012498684692\n", - " 8869.662932387795\n", - " -1030.3323016387612\n", - " 4912.327776140574\n", - " 2840.9932292653525\n", - " 3323.126233753097\n", - " 21356.54170795394" + " 7969.024541481493\n", + " 2551.07486621794\n", + " 48185.11603976369\n", + " 6562.417861484062\n", + " 12324.513387722585\n", + " 91413.60918565083\n", + " 103742.23330057286\n", + " 13234.161144429849\n", + " 16753.004994337723\n", + " 6429.458448880052\n", + " ⋮\n", + " 2331.601849423459\n", + " 50477.892771963685\n", + " 19942.337555990453\n", + " 12658.185171498155\n", + " -442.6517574940871\n", + " 72754.7346983037\n", + " 42410.30074258264\n", + " 64041.35045474993\n", + " 1374.0969545336325" ] }, "metadata": {}, @@ -249,33 +249,33 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "9915-element Vector{Float64}:\n", - " 8651.259686332123\n", - " 2763.6426805062965\n", - " 4281.08620983512\n", - " 6996.106017505121\n", - " 37295.1224689869\n", - " 3425.2628336886887\n", - " 7259.653364085303\n", - " 3931.840707261489\n", - " 3390.6489181977217\n", - " 396.19186564028234\n", + " 13549.633020274861\n", + " 20881.59369086071\n", + " 1879.2141524564345\n", + " 4752.192233979611\n", + " 9972.464441326127\n", + " 5368.174090907391\n", + " 8080.56176700674\n", + " 11685.092957657413\n", + " -1689.8961993687453\n", + " 4964.903827056494\n", " ⋮\n", - " 13778.740930336877\n", - " 13824.272936865971\n", - " 770.8718719469387\n", - " 5661.227928432385\n", - " 10218.778717409776\n", - " 3707.70741363045\n", - " 2089.690748271022\n", - " 3767.843767168565\n", - " 17841.535784697724" + " 12745.035572594325\n", + " 13779.898140138454\n", + " 15285.34382394138\n", + " 7686.997478984806\n", + " 10874.155814573602\n", + " 9104.438679085306\n", + " 5974.4691837941145\n", + " -39.615643944068324\n", + " -9482.093434774426" ] }, "metadata": {}, @@ -288,18 +288,18 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dict{Any, Any} with 11 entries:\n", - " \"Activation Function\" => relu\n", + " \"Activation Function\" => swish\n", " \"Quantity of Interest\" => \"ATE\"\n", " \"Sample Size\" => 9915\n", " \"Number of Machines\" => 100\n", - " \"Causal Effect\" => 8823.5\n", + " \"Causal Effect\" => 8701.76\n", " \"Number of Neurons\" => 24\n", " \"Task\" => \"regression\"\n", " \"Time Series/Panel Data\" => false\n", @@ -318,18 +318,18 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dict{Any, Any} with 11 entries:\n", - " \"Activation Function\" => relu\n", + " \"Activation Function\" => swish\n", " \"Quantity of Interest\" => \"CATE\"\n", " \"Sample Size\" => 9915\n", " \"Number of Machines\" => 100\n", - " \"Causal Effect\" => [4085.58, 15773.5, 38901.8, 3825.38, 11964.8, 267…\n", + " \"Causal Effect\" => [7969.02, 2551.07, 48185.1, 6562.42, 12324.5, 914…\n", " \"Number of Neurons\" => 32\n", " \"Task\" => \"regression\"\n", " \"Time Series/Panel Data\" => false\n", @@ -354,7 +354,18 @@ { "data": { "text/plain": [ - "(Dict(\"0.1 Standard Deviations from Observed Outcomes\" => -8079.331571957283, \"0.075 Standard Deviations from Observed Outcomes\" => -6089.203934396697, \"0.025 Standard Deviations from Observed Outcomes\" => -7522.457852582857, \"0.05 Standard Deviations from Observed Outcomes\" => -12933.100480526482), 2.6894381997142496, Matrix{Float64}(undef, 0, 9))" + "Dict{Any, Any} with 11 entries:\n", + " \"Activation Function\" => swish\n", + " \"Quantity of Interest\" => \"CATE\"\n", + " \"Sample Size\" => 9915\n", + " \"Number of Machines\" => 100\n", + " \"Causal Effect\" => [13549.6, 20881.6, 1879.21, 4752.19, 9972.46, 536…\n", + " \"Number of Neurons\" => 32\n", + " \"Task\" => \"regression\"\n", + " \"Time Series/Panel Data\" => false\n", + " \"Standard Error\" => NaN\n", + " \"p-value\" => NaN\n", + " \"Number of Features\" => 6" ] }, "metadata": {}, @@ -362,7 +373,7 @@ } ], "source": [ - "validate(dml)" + "summarise(dre)" ] }, { @@ -373,7 +384,26 @@ { "data": { "text/plain": [ - "(Dict(\"0.1 Standard Deviations from Observed Outcomes\" => 155340.94980401796, \"0.075 Standard Deviations from Observed Outcomes\" => 559571.3301919985, \"0.025 Standard Deviations from Observed Outcomes\" => 274961.5431470514, \"0.05 Standard Deviations from Observed Outcomes\" => 345062.1310616215), 2.8689322412325833, Matrix{Float64}(undef, 0, 9))" + "(Dict(\"0.1 Standard Deviations from Observed Outcomes\" => -1974.20426849962, \"0.075 Standard Deviations from Observed Outcomes\" => -549.8183509860896, \"0.025 Standard Deviations from Observed Outcomes\" => -4377.799707458391, \"0.05 Standard Deviations from Observed Outcomes\" => -2591.878868163885), 2.7460736072464016, Matrix{Float64}(undef, 0, 9))" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "validate(dml)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Dict(\"0.1 Standard Deviations from Observed Outcomes\" => 248206.3227597734, \"0.075 Standard Deviations from Observed Outcomes\" => 404160.7518203919, \"0.025 Standard Deviations from Observed Outcomes\" => 322479.1870944485, \"0.05 Standard Deviations from Observed Outcomes\" => 155068.1882045497), 2.5694922346983624, Matrix{Float64}(undef, 0, 9))" ] }, "metadata": {}, @@ -383,6 +413,25 @@ "source": [ "validate(r_learner)" ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Dict(\"0.1 Standard Deviations from Observed Outcomes\" => 430.7554873780964, \"0.075 Standard Deviations from Observed Outcomes\" => -4156.750735846773, \"0.025 Standard Deviations from Observed Outcomes\" => -5301.764975883297, \"0.05 Standard Deviations from Observed Outcomes\" => -6012.136217190272), 2.5976021674608534, Matrix{Float64}(undef, 0, 9))" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "validate(dre)" + ] } ], "metadata": {