This repository contains the implementation of a DAG-aware Transformer model for causal inference, as described in our paper "DAG-aware Transformer for Causal Effect Estimation". Our model incorporates causal structure into the attention mechanism, allowing for more accurate modeling of causal relationships in various estimation frameworks, including G-formula, Inverse Probability Weighting (IPW), and Augmented Inverse Probability Weighting (AIPW).
Key features:
- DAG-aware attention mechanism
- Support for multiple causal inference methods (G-formula, IPW, AIPW)
- Flexible architecture for both joint and separate training of propensity score and outcome models
- Extension to proximal causal inference
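The attention mechanism itself lives in src/models/. Purely to illustrate the general idea of DAG-aware attention, here is a minimal NumPy sketch. It assumes the DAG is given as a list of (parent, child) edges and that each variable may attend only to itself and its direct parents. The masking rule, the function names and the toy DAG are assumptions for illustration, not the model's actual design.

```python
import numpy as np


def dag_attention_mask(nodes, edges):
    """Boolean mask where mask[i, j] is True if node i may attend to node j.

    Assumption: a node attends to itself and to its direct parents in the DAG.
    """
    index = {name: k for k, name in enumerate(nodes)}
    mask = np.eye(len(nodes), dtype=bool)  # every node attends to itself
    for parent, child in edges:
        mask[index[child], index[parent]] = True  # child attends to its parent
    return mask


def masked_attention(q, k, v, mask):
    """Scaled dot-product attention; disallowed pairs get -inf before the softmax."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)  # exp(-inf) == 0, so masked pairs get zero weight
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights


# Toy DAG (hypothetical node names): X -> T, X -> Y, T -> Y
nodes = ["X", "T", "Y"]
edges = [("X", "T"), ("X", "Y"), ("T", "Y")]
mask = dag_attention_mask(nodes, edges)

rng = np.random.default_rng(0)
h = rng.normal(size=(3, 4))  # one 4-dimensional embedding per node
out, weights = masked_attention(h, h, h, mask)
print(mask.astype(int))
```

With this mask, the covariate X can attend only to itself, the treatment T attends to X and itself, and the outcome Y attends to all three nodes. Because the diagonal is always allowed, every row has at least one finite score and the softmax never produces NaN.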
Our project is organized as follows:
.
├── README.md
├── config
│ ├── dag
│ └── train
├── data
│ ├── acic
│ └── lalonde
├── experiments
│ ├── results
│ └── tuning
├── requirements.txt
├── scripts
│ ├── myjob.sh
│ └── myjob_proximal.sh
├── src
│ ├── data
│ ├── dataset.py
│ ├── evaluate
│ ├── experiment.py
│ ├── experiment_proximal.py
│ ├── models
│ ├── train
│ ├── utils.py
│ ├── utils_proximal
│ └── visualization
└── tests
- config/: Configuration files for DAG structures and training parameters.
- data/: Data loading and preprocessing scripts.
- experiments/: Experimental results.
- scripts/: Scripts for running the experiments.
- src/: The main source code directory.
  - data/: Data loading and preprocessing modules.
  - evaluate/: Evaluation metrics and functions.
  - models/: The DAG-aware Transformer model and baseline models, along with their loss functions.
  - train/: Programs to compute pseudo ATE/CATE (see the Hyper-parameter tuning section of our paper), along with the computed values.
  - utils.py: Utility functions for data processing and model training.
  - utils_proximal/: Utility functions for proximal inference.
  - visualization/: Code for generating plots and visualizations.
  - experiment.py: Main script for running experiments.
  - experiment_proximal.py: Main script for running proximal inference experiments.
- tests/: Unit tests for the project.
To install the required dependencies, run:
pip install -r requirements.txt
We evaluate our model on four datasets:
- Lalonde-CPS
- Lalonde-PSID
- ACIC
- Demand dataset (for proximal inference)
Data preprocessing scripts and instructions can be found in the data/ directory.
To reproduce the experiments for Lalonde-CPS, Lalonde-PSID and ACIC, run:
python3 src/experiment.py \
--config config/train/<DATA_NAME>/<DATA_NAME>_sample<SAMPLE_ID>.json \
--dag <DAG_TYPE> \
--estimator <ESTIMATOR_TYPE> \
--data_name <DATA_NAME>
where:
- CONFIG_FILE (the value passed to --config): The configuration file for the experiment
  - Location: config/train/<DATA_NAME>/
  - Naming convention: <DATA_NAME>_sample<SAMPLE_ID>.json
  - Examples: acic_sample1.json, lalonde_cps_sample2.json, lalonde_psid_sample3.json
- DAG_TYPE: The type of Directed Acyclic Graph (DAG) to use
  - Options: dag_g_formula, dag_ipw, dag_aipw
- ESTIMATOR_TYPE: The type of estimator to use
  - Options: g-formula, ipw, aipw
- DATA_NAME: The name of the dataset
  - Options: lalonde_cps, lalonde_psid, acic
- SAMPLE_ID: The sample ID for the experiment
  - A numeric value from 1 to 10 (e.g., 1, 2, 3, ...)
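When sweeping over samples, it can help to build these commands programmatically and reject invalid parameter values up front. The following is a minimal sketch of such a helper. The function `build_experiment_cmd` is hypothetical and not part of the repository. It assumes that `--data_name` takes the same DATA_NAME value used in the config path, as the options list above suggests; verify this against src/experiment.py.

```python
import shlex

# Allowed values, copied from the parameter list above
DAG_TYPES = {"dag_g_formula", "dag_ipw", "dag_aipw"}
ESTIMATORS = {"g-formula", "ipw", "aipw"}
DATA_NAMES = {"lalonde_cps", "lalonde_psid", "acic"}


def build_experiment_cmd(data_name, dag, estimator, sample_id):
    """Return the argv list for one run of src/experiment.py (hypothetical helper)."""
    if data_name not in DATA_NAMES:
        raise ValueError(f"unknown DATA_NAME: {data_name!r}")
    if dag not in DAG_TYPES:
        raise ValueError(f"unknown DAG_TYPE: {dag!r}")
    if estimator not in ESTIMATORS:
        raise ValueError(f"unknown ESTIMATOR_TYPE: {estimator!r}")
    if not 1 <= sample_id <= 10:
        raise ValueError("SAMPLE_ID must be between 1 and 10")
    # Config path follows config/train/<DATA_NAME>/<DATA_NAME>_sample<SAMPLE_ID>.json
    config = f"config/train/{data_name}/{data_name}_sample{sample_id}.json"
    return [
        "python3", "src/experiment.py",
        "--config", config,
        "--dag", dag,
        "--estimator", estimator,
        # Assumption: --data_name takes the same value as DATA_NAME in the path
        "--data_name", data_name,
    ]


# One command per ACIC sample; each can be executed with subprocess.run(cmd, check=True)
commands = [build_experiment_cmd("acic", "dag_aipw", "aipw", i) for i in range(1, 11)]
print(shlex.join(commands[0]))
```

Returning an argument list rather than a single string lets `subprocess.run` execute the command without going through a shell. `shlex.join` is used only to print a copy-pasteable version.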
For example, to run the IPW estimator on the first Lalonde-CPS sample:
python3 src/experiment.py \
--config config/train/lalonde_cps/lalonde_cps_sample1.json \
--dag dag_ipw \
--estimator ipw \
--data_name lalonde-cps
To obtain results where the outcome regression and propensity score models are trained separately, run the following steps:
- Get the outcome regression predictions (e.g., for ACIC):
python3 src/experiment.py \
--config config/train/acic/acic_sample1.json \
--dag dag_g_formula \
--estimator g-formula \
--data_name acic
- Get the propensity score predictions (e.g., for ACIC):
python3 src/experiment.py \
--config config/train/acic/acic_sample1.json \
--dag dag_ipw \
--estimator ipw \
--data_name acic
- Plug the predicted values into the AIPW estimator (e.g., for ACIC):
python3 src/evaluate/acic/evaluate_metrics.py \
--data_name acic \
--estimator aipw \
--sample_id 1
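For reference, the last step combines the two sets of predictions. Here is a minimal sketch of the standard G-formula, IPW and AIPW point estimates of the ATE, computed from outcome predictions `mu0`, `mu1` and propensity scores `e`. The function name, the argument names and the clipping threshold are assumptions. src/evaluate/acic/evaluate_metrics.py may differ in details such as weight normalization or trimming.

```python
import numpy as np


def ate_estimates(y, t, mu0, mu1, e, clip=0.01):
    """Point estimates of the ATE from outcome predictions (mu0, mu1) and propensity scores (e)."""
    y, t, mu0, mu1 = (np.asarray(a, dtype=float) for a in (y, t, mu0, mu1))
    # Clip propensities away from 0 and 1 to avoid exploding weights (assumed threshold)
    e = np.clip(np.asarray(e, dtype=float), clip, 1 - clip)
    g_formula = np.mean(mu1 - mu0)
    ipw = np.mean(t * y / e - (1 - t) * y / (1 - e))
    aipw = np.mean(mu1 - mu0 + t * (y - mu1) / e - (1 - t) * (y - mu0) / (1 - e))
    return {"g-formula": g_formula, "ipw": ipw, "aipw": aipw}


# Simulated example with a true ATE of 2
rng = np.random.default_rng(0)
n = 5000
x = rng.normal(size=n)
e_true = 1 / (1 + np.exp(-x))
t = rng.binomial(1, e_true)
y = 1 + 2 * t + x + rng.normal(size=n)
# The true nuisance functions stand in for the two models' predictions
estimates = ate_estimates(y, t, mu0=1 + x, mu1=3 + x, e=e_true)
print(estimates)
```

AIPW adds an inverse-probability-weighted residual correction to the G-formula estimate. As a result, it remains consistent if either the outcome model or the propensity model is correctly specified. This is the motivation for plugging both sets of predictions into it.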
To reproduce the proximal inference experiments on the Demand dataset, run:
python3 src/experiment_proximal.py \
--dag <DAG_CONFIG_FILE> \
--config config/train/proximal/nmmr_<STATISTICS>_z_transformer_n<SAMPLE_SIZE>.json \
--results_dir <RESULTS_DIRECTORY> \
--sample_index <SAMPLE_INDEX>
where:
- DAG_CONFIG_FILE: The configuration file for the Directed Acyclic Graph (DAG)
  - Location: config/dag/
  - Example: proximal_dag_z.json
- STATISTICS: The type of statistics used in proximal inference
  - Options: u (U-statistics) or v (V-statistics)
- SAMPLE_SIZE: The size of the sample used in the experiment
  - Example values: 50000, 100000, etc.
- RESULTS_DIRECTORY: The directory where results will be stored
  - Default: experiments/results/proximal
- SAMPLE_INDEX: The index of the sample to use for the experiment (from 0 to 19)
  - Example values: 0, 1, 2, etc.
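The `nmmr_` prefix of the config names suggests that the u/v option selects the U- or V-statistic variant of a neural maximum moment restriction (NMMR) loss. Assuming that reading, the loss is a kernel-weighted quadratic form over pairs of residuals r_i. The V-statistic averages over all n^2 pairs, while the U-statistic drops the i = j terms and divides by n(n-1). The following sketch uses an RBF kernel on generic conditioning variables. The kernel, the bandwidth and the function names are assumptions, not the code in src/utils_proximal/.

```python
import numpy as np


def rbf_kernel(c, bandwidth=1.0):
    """Gaussian (RBF) kernel matrix over the rows of c."""
    sq_dists = np.sum((c[:, None, :] - c[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2 * bandwidth ** 2))


def mmr_loss(residuals, cond, statistic="v", bandwidth=1.0):
    """Kernel-weighted quadratic form sum_ij r_i k(c_i, c_j) r_j, as a U- or V-statistic."""
    r = np.asarray(residuals, dtype=float)
    k = rbf_kernel(np.asarray(cond, dtype=float), bandwidth)
    n = r.shape[0]
    pair_terms = np.outer(r, r) * k  # entry (i, j) is r_i * k(c_i, c_j) * r_j
    if statistic == "v":
        # V-statistic: average over all n^2 pairs, including i == j
        return pair_terms.sum() / n ** 2
    if statistic == "u":
        # U-statistic: drop the diagonal (i == j) and average over n(n-1) pairs
        return (pair_terms.sum() - np.trace(pair_terms)) / (n * (n - 1))
    raise ValueError("statistic must be 'u' or 'v'")


rng = np.random.default_rng(1)
r = rng.normal(size=200)  # residuals y - h(...) from a hypothetical bridge model
c = rng.normal(size=(200, 2))  # conditioning variables, e.g. treatment and proxy z
v_loss = mmr_loss(r, c, "v")
u_loss = mmr_loss(r, c, "u")
print(v_loss, u_loss)
```

The V-statistic is always non-negative for a positive semi-definite kernel but is biased by the diagonal terms. The U-statistic removes that bias at the cost of possibly taking negative values on small batches.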
For example:
python3 src/experiment_proximal.py \
--dag config/dag/proximal_dag_z.json \
--config config/train/proximal/nmmr_v_z_transformer_n50000.json \
--results_dir experiments/results/proximal \
--sample_index 1
You can also run the experiments with the provided scripts after modifying their parameters: scripts/myjob.sh for Lalonde-CPS, Lalonde-PSID and ACIC, and scripts/myjob_proximal.sh for Demand.
If you use this code in your research, please cite our paper:
@misc{liu2024dagawaretransformercausaleffect,
title={DAG-aware Transformer for Causal Effect Estimation},
author={Manqing Liu and David R. Bellamy and Andrew L. Beam},
year={2024},
eprint={2410.10044},
archivePrefix={arXiv},
primaryClass={stat.ML},
url={https://arxiv.org/abs/2410.10044},
}
This project is licensed under the MIT License. For the complete terms and conditions, refer to the LICENSE file or visit: https://opensource.org/licenses/MIT.
For any questions or concerns, please open an issue or contact Manqing Liu at manqingliu@g.harvard.edu.