Made swish the default activation function
dscolby committed Jul 3, 2024
1 parent c6139e1 commit 84008b5
Showing 4 changed files with 128 additions and 78 deletions.
1 change: 1 addition & 0 deletions docs/src/release_notes.md
@@ -9,6 +9,7 @@ These release notes adhere to the [keep a changelog](https://keepachangelog.com/
* Calculate probabilities as the average label predicted by the ensemble instead of clipping [#71](https://github.com/dscolby/CausalELM.jl/issues/71)
* Made calculation of p-values and standard errors optional and not executed by default in summarize methods [#65](https://github.com/dscolby/CausalELM.jl/issues/65)
* Removed redundant W argument for double machine learning, R-learning, and doubly robust estimation [#68](https://github.com/dscolby/CausalELM.jl/issues/68)
* Use swish as the default activation function [#72](https://github.com/dscolby/CausalELM.jl/issues/72)
### Fixed
* Applying the weight trick for R-learning [#70](https://github.com/dscolby/CausalELM.jl/issues/70)

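For reference, the swish activation named in the release note above is commonly defined as x·sigmoid(x) (also called SiLU). The snippet below is a minimal sketch of that standard definition next to relu for comparison; it is not the package's own implementation, which may differ in detail.

```julia
# Minimal sketch of the standard swish (SiLU) definition; CausalELM.jl ships
# its own activation functions, which may differ in detail.
sigmoid(x) = 1 / (1 + exp(-x))
swish(x) = x * sigmoid(x)
relu(x) = max(zero(x), x)   # the previous default, for comparison

swish(2.0)    # ≈ 1.76: close to the identity for large positive inputs
swish(-2.0)   # ≈ -0.24: small but nonzero, unlike relu(-2.0) == 0.0
```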
14 changes: 7 additions & 7 deletions src/estimators.jl
@@ -13,7 +13,7 @@ Initialize an interrupted time series estimator.
- `Y₁::Any`: array or DataFrame of outcomes from the post-treatment period.
# Keywords
- `activation::Function=relu`: activation function to use.
- `activation::Function=swish`: activation function to use.
- `sample_size::Integer=size(X₀, 1)`: number of bootstrapped samples for the extreme
learner.
- `num_machines::Integer=100`: number of extreme learning machines for the ensemble.
@@ -56,7 +56,7 @@ function InterruptedTimeSeries(
Y₀,
X₁,
Y₁;
activation::Function=relu,
activation::Function=swish,
sample_size::Integer=size(X₀, 1),
num_machines::Integer=100,
num_feats::Integer=Int(round(0.75 * size(X₀, 2))),
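As a quick illustration of the change for the interrupted time series estimator, the sketch below constructs the estimator with the new swish default and then overrides it. The data is synthetic and only meant to show the keyword, not a meaningful analysis.

```julia
using CausalELM

# Synthetic pre-treatment (X₀, Y₀) and post-treatment (X₁, Y₁) data, for illustration only.
X₀, Y₀ = rand(100, 5), rand(100)
X₁, Y₁ = rand(20, 5), rand(20)

# With this commit the ensemble uses swish unless told otherwise.
its = InterruptedTimeSeries(X₀, Y₀, X₁, Y₁)

# The previous behavior is still one keyword away.
its_relu = InterruptedTimeSeries(X₀, Y₀, X₁, Y₁; activation=CausalELM.relu)
```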
@@ -102,7 +102,7 @@ Initialize a G-Computation estimator.
# Keywords
- `quantity_of_interest::String`: ATE for average treatment effect or ATT for average
treatment effect on the treated.
- `activation::Function=relu`: activation function to use.
- `activation::Function=swish`: activation function to use.
- `sample_size::Integer=size(X, 1)`: number of bootstrapped samples for the extreme
learners.
- `num_machines::Integer=100`: number of extreme learning machines for the ensemble.
@@ -144,7 +144,7 @@ mutable struct GComputation <: CausalEstimator
T,
Y;
quantity_of_interest::String="ATE",
activation::Function=relu,
activation::Function=swish,
sample_size::Integer=size(X, 1),
num_machines::Integer=100,
num_feats::Integer=Int(round(0.75 * size(X, 2))),
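The same default applies to G-computation. A short sketch with illustrative data shows that the other keywords, such as quantity_of_interest, are untouched by this change.

```julia
using CausalELM

# Illustrative covariates, binary treatment, and outcome.
X, T, Y = rand(500, 6), rand(0:1, 500), rand(500)

# swish is now the default activation for the ensemble.
g_ate = GComputation(X, T, Y)

# Estimating the ATT and reverting to relu, purely via keywords.
g_att = GComputation(X, T, Y; quantity_of_interest="ATT", activation=CausalELM.relu)
```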
@@ -188,7 +188,7 @@ Initialize a double machine learning estimator with cross fitting.
- `Y::Any`: array or DataFrame of outcomes.
# Keywords
- `activation::Function=relu`: activation function to use.
- `activation::Function=swish`: activation function to use.
- `sample_size::Integer=size(X, 1)`: number of bootstrapped samples for the extreme
learners.
- `num_machines::Integer=100`: number of extreme learning machines for the ensemble.
@@ -227,7 +227,7 @@ function DoubleMachineLearning(
X,
T,
Y;
activation::Function=relu,
activation::Function=swish,
sample_size::Integer=size(X, 1),
num_machines::Integer=100,
num_feats::Integer=Int(round(0.75 * size(X, 2))),
@@ -236,7 +236,7 @@ function DoubleMachineLearning(
)
# Convert to arrays
X, T, Y = Matrix{Float64}(X), T[:, 1], Y[:, 1]

# Shuffle data with random indices
indices = shuffle(1:length(Y))
X, T, Y = X[indices, :], T[indices], Y[indices]
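For double machine learning the pattern is the same. This sketch, again with illustrative data, also exercises the keywords documented above (ensemble size alongside the activation).

```julia
using CausalELM

# Illustrative covariates, a binary treatment, and a continuous outcome.
X, T, Y = rand(500, 6), rand(0:1, 500), rand(500)

# Cross-fitting DML with the new swish default.
dml = DoubleMachineLearning(X, T, Y)

# A smaller ensemble with the old relu default restored explicitly.
dml_small = DoubleMachineLearning(X, T, Y; num_machines=50, activation=CausalELM.relu)
```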
20 changes: 10 additions & 10 deletions src/metalearners.jl
@@ -12,7 +12,7 @@ Initialize a S-Learner.
- `Y::Any`: an array or DataFrame of outcomes.
# Keywords
- `activation::Function=relu`: the activation function to use.
- `activation::Function=swish`: the activation function to use.
- `sample_size::Integer=size(X, 1)`: number of bootstrapped samples for the extreme
learners.
- `num_machines::Integer=100`: number of extreme learning machines for the ensemble.
@@ -51,7 +51,7 @@ mutable struct SLearner <: Metalearner
X,
T,
Y;
activation::Function=relu,
activation::Function=swish,
sample_size::Integer=size(X, 1),
num_machines::Integer=100,
num_feats::Integer=Int(round(0.75 * size(X, 2))),
@@ -91,7 +91,7 @@ Initialize a T-Learner.
- `Y::Any`: an array or DataFrame of outcomes.
# Keywords
- `activation::Function=relu`: the activation function to use.
- `activation::Function=swish`: the activation function to use.
- `sample_size::Integer=size(X, 1)`: number of bootstrapped samples for the extreme
learners.
- `num_machines::Integer=100`: number of extreme learning machines for the ensemble.
@@ -130,7 +130,7 @@ mutable struct TLearner <: Metalearner
X,
T,
Y;
activation::Function=relu,
activation::Function=swish,
sample_size::Integer=size(X, 1),
num_machines::Integer=100,
num_feats::Integer=Int(round(0.75 * size(X, 2))),
@@ -169,7 +169,7 @@ Initialize an X-Learner.
- `Y::Any`: an array or DataFrame of outcomes.
# Keywords
- `activation::Function=relu`: the activation function to use.
- `activation::Function=swish`: the activation function to use.
- `sample_size::Integer=size(X, 1)`: number of bootstrapped samples for the extreme
learners.
- `num_machines::Integer=100`: number of extreme learning machines for the ensemble.
@@ -209,7 +209,7 @@ mutable struct XLearner <: Metalearner
X,
T,
Y;
activation::Function=relu,
activation::Function=swish,
sample_size::Integer=size(X, 1),
num_machines::Integer=100,
num_feats::Integer=Int(round(0.75 * size(X, 2))),
@@ -248,7 +248,7 @@ Initialize an R-Learner.
- `Y::Any`: an array or DataFrame of outcomes.
# Keywords
- `activation::Function=relu`: the activation function to use.
- `activation::Function=swish`: the activation function to use.
- `sample_size::Integer=size(X, 1)`: number of bootstrapped samples for the extreme
learners.
- `num_machines::Integer=100`: number of extreme learning machines for the ensemble.
@@ -285,7 +285,7 @@ function RLearner(
X,
T,
Y;
activation::Function=relu,
activation::Function=swish,
sample_size::Integer=size(X, 1),
num_machines::Integer=100,
num_feats::Integer=Int(round(0.75 * size(X, 2))),
@@ -330,7 +330,7 @@ Initialize a doubly robust CATE estimator.
- `Y::Any`: an array or DataFrame of outcomes.
# Keywords
- `activation::Function=relu`: the activation function to use.
- `activation::Function=swish`: the activation function to use.
- `sample_size::Integer=size(X, 1)`: number of bootstrapped samples for the extreme
learners.
- `num_machines::Integer=100`: number of extreme learning machines for the ensemble.
@@ -370,7 +370,7 @@ function DoublyRobustLearner(
X,
T,
Y;
activation::Function=relu,
activation::Function=swish,
sample_size::Integer=size(X, 1),
num_machines::Integer=100,
num_feats::Integer=Int(round(0.75 * size(X, 2))),
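All of the metalearners above share the (X, T, Y) constructor and the same activation keyword, so the new default applies uniformly. A brief sketch with illustrative data:

```julia
using CausalELM

# Illustrative covariates, binary treatment, and outcome.
X, T, Y = rand(500, 6), rand(0:1, 500), rand(500)

# Every metalearner constructor now defaults to swish...
s  = SLearner(X, T, Y)
dr = DoublyRobustLearner(X, T, Y)

# ...and any of them can still be pointed back at relu (or another function).
x_relu = XLearner(X, T, Y; activation=CausalELM.relu)
```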