Skip to content

Commit

Permalink
Fixed R-learning again
Browse files Browse the repository at this point in the history
  • Loading branch information
dscolby committed Jul 2, 2024
1 parent 4c45278 commit c97ea4e
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 13 deletions.
7 changes: 2 additions & 5 deletions src/metalearners.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using LinearAlgebra: Diagonal

"""Abstract type for metalearners"""
abstract type Metalearner end

Expand Down Expand Up @@ -516,9 +514,8 @@ function estimate_causal_effect!(R::RLearner)
end

# Using target transformation and the weight trick to minimize the causal loss
T̃², target = reduce(vcat, T̃).^2, reduce(vcat, T̃) ./ reduce(vcat, Ỹ)
W⁻⁵ᵉ⁻¹ = Diagonal(sqrt.(T̃²))
Xʷ, Yʷ = W⁻⁵ᵉ⁻¹ * R.X, W⁻⁵ᵉ⁻¹ * target
T̃², target = reduce(vcat, T̃).^2, reduce(vcat, Ỹ) ./ reduce(vcat, T̃)
Xʷ, Yʷ = R.X .* T̃², target .* T̃²

# Fit a weighted residual-on-residual model
final_model = ELMEnsemble(
Expand Down
123 changes: 115 additions & 8 deletions testing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -148,26 +148,84 @@
}
],
"source": [
"dr_learner = DoubleMachineLearning(covariates, treatment, outcome, num_feats=6)"
"dml = DoubleMachineLearning(covariates, treatment, outcome)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"RLearner([0.15384615384615385 0.1258211589371507 … 0.0 1.0; 0.6923076923076923 0.1441562898323365 … 0.0 1.0; … ; 0.41025641025641024 0.24039121482498285 … 0.0 1.0; 0.07692307692307693 0.11789145994705363 … 0.0 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [-3300.0, 61010.0, 8849.0, -6013.0, -2375.0, -11000.0, -16901.0, 1000.0, 0.0, 6400.0 … -1436.0, 4500.0, 34739.0, -750.0, 40000.0, 172.0, 836.0, 6150.0, 14499.0, -5400.0], \"CATE\", false, \"regression\", CausalELM.relu, 9915, 100, 6, 32, [NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN … NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN, NaN], 5)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"r_learner = RLearner(covariates, treatment, outcome)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"8823.500636214852"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"estimate_causal_effect!(dml)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"8667.309064475481"
"9915-element Vector{Float64}:\n",
" 4085.5839404080925\n",
" 15773.51315113084\n",
" 38901.80802040522\n",
" 3825.3848781869037\n",
" 11964.765726429632\n",
" 26765.991729444253\n",
" 16975.200557225948\n",
" 7452.263104809677\n",
" 1115.323329175054\n",
" 12363.569530065344\n",
"\n",
" 11433.6140084084\n",
" 4800.764220118784\n",
" 2925.4867379282705\n",
" 39714.813007228164\n",
" 1647.2470272172372\n",
" 10061.73821939839\n",
" 14687.816324667367\n",
" 17992.791169984106\n",
" 434.7500362608628"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"estimate_causal_effect!(dr_learner)"
"estimate_causal_effect!(r_learner)"
]
},
{
Expand All @@ -183,7 +241,7 @@
" \"Quantity of Interest\" => \"ATE\"\n",
" \"Sample Size\" => 9915\n",
" \"Number of Machines\" => 100\n",
" \"Causal Effect\" => 8806.5\n",
" \"Causal Effect\" => 8823.5\n",
" \"Number of Neurons\" => 24\n",
" \"Task\" => \"regression\"\n",
" \"Time Series/Panel Data\" => false\n",
Expand All @@ -197,7 +255,7 @@
}
],
"source": [
"summarize(dr_learner)"
"summarize(dml)"
]
},
{
Expand All @@ -208,15 +266,64 @@
{
"data": {
"text/plain": [
"(Dict(0.025 => -12979.904119051262, 0.075 => -12217.068316708708, 0.1 => -6143.33640640303, 0.05 => -9062.747974951273), 2.8344920146887382, Matrix{Float64}(undef, 0, 9))"
"Dict{Any, Any} with 11 entries:\n",
" \"Activation Function\" => relu\n",
" \"Quantity of Interest\" => \"CATE\"\n",
" \"Sample Size\" => 9915\n",
" \"Number of Machines\" => 100\n",
" \"Causal Effect\" => [4085.58, 15773.5, 38901.8, 3825.38, 11964.8, 267…\n",
" \"Number of Neurons\" => 32\n",
" \"Task\" => \"regression\"\n",
" \"Time Series/Panel Data\" => false\n",
" \"Standard Error\" => NaN\n",
" \"p-value\" => NaN\n",
" \"Number of Features\" => 6"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"summarize(r_learner)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Dict(\"0.1 Standard Deviations from Observed Outcomes\" => -8079.331571957283, \"0.075 Standard Deviations from Observed Outcomes\" => -6089.203934396697, \"0.025 Standard Deviations from Observed Outcomes\" => -7522.457852582857, \"0.05 Standard Deviations from Observed Outcomes\" => -12933.100480526482), 2.6894381997142496, Matrix{Float64}(undef, 0, 9))"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"validate(dml)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Dict(\"0.1 Standard Deviations from Observed Outcomes\" => 155340.94980401796, \"0.075 Standard Deviations from Observed Outcomes\" => 559571.3301919985, \"0.025 Standard Deviations from Observed Outcomes\" => 274961.5431470514, \"0.05 Standard Deviations from Observed Outcomes\" => 345062.1310616215), 2.8689322412325833, Matrix{Float64}(undef, 0, 9))"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"validate(dr_learner)"
"validate(r_learner)"
]
}
],
Expand Down

0 comments on commit c97ea4e

Please sign in to comment.