This package implements DDLK, a method for variable selection with explicit control of the false discovery rate. Install with:
pip install ddlk
Suppose you have a set of features and a response. DDLK identifies the features most predictive of the response at a pre-specified false discovery rate (FDR) threshold. For example, if you choose an FDR of 20%, DDLK controls the expected proportion of unimportant features among those it selects to be at most 20%. To learn more about how it works, check out our paper.
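To make the FDR concrete, here is a minimal sketch that computes the false discovery proportion (FDP) of a single selection; the FDR is the expectation of this quantity over repeated datasets. The function name and the toy indices are illustrative and are not part of the `ddlk` package.

```python
def false_discovery_proportion(selected, important):
    """Fraction of selected features that are not truly important."""
    selected = set(selected)
    if not selected:
        # By convention, an empty selection makes no false discoveries.
        return 0.0
    false_discoveries = selected - set(important)
    return len(false_discoveries) / len(selected)


# Toy example: features 0-4 are truly important; feature 7 is a false discovery.
fdp = false_discovery_proportion(selected=[0, 1, 2, 3, 7], important=[0, 1, 2, 3, 4])
print(fdp)
```

In practice the set of important features is unknown, so the FDP can only be computed on synthetic data where the ground truth is available.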
Variable selection with DDLK involves three stages:
- Fit a joint distribution to model features
- Fit a knockoff generator
- Sample knockoffs and apply knockoff filter to select variables at a pre-specified FDR
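The last stage can be sketched with the standard knockoff+ selection rule. This is a plain-Python illustration under assumptions, not the package's own code: it assumes each feature j already has a statistic W[j] that is large and positive when the real feature looks more important than its knockoff, and the function names and toy values are hypothetical.

```python
def knockoff_threshold(W, fdr=0.2, offset=1):
    """Smallest t > 0 with (offset + #{W_j <= -t}) / max(1, #{W_j >= t}) <= fdr."""
    candidates = sorted({abs(w) for w in W if w != 0})
    for t in candidates:
        negatives = sum(1 for w in W if w <= -t)
        positives = sum(1 for w in W if w >= t)
        if (offset + negatives) / max(1, positives) <= fdr:
            return t
    # No threshold meets the target FDR, so nothing is selected.
    return float("inf")


def knockoff_select(W, fdr=0.2, offset=1):
    """Indices of features whose statistic clears the knockoff threshold."""
    tau = knockoff_threshold(W, fdr, offset)
    return [j for j, w in enumerate(W) if w >= tau]


# Toy statistics: six clearly positive values, four near zero.
W = [5.0, 4.2, 3.9, 3.1, 2.8, 2.5, -0.4, 0.3, -0.2, 0.1]
selected = knockoff_select(W, fdr=0.2)
print(selected)
```

Setting `offset=1` gives the knockoff+ variant, which is the conservative version of the rule; with very few strong features it may select nothing at all.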
To see a complete working example, check our synthetic data example, used to generate the gif above. Below is an excerpt showing how to run DDLK.
This implementation of DDLK uses the fast and easy PyTorch Lightning framework to fit q_joint:
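import argparse
import pytorch_lightning as pl
# assumption: module layout inferred from the names used below
from ddlk import ddlk, hrt, mdn, utils

# initialize data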
x, y = ...
# put your data in standard PyTorch format
trainloader = ...
# initialize joint distribution model with mean and std of data
((X_mu, ), (X_sigma, )) = utils.get_two_moments(trainloader)
hparams = argparse.Namespace(X_mu=X_mu, X_sigma=X_sigma)
q_joint = mdn.MDNJoint(hparams)
# create and fit a PyTorch Lightning trainer
trainer = pl.Trainer()
trainer.fit(q_joint, train_dataloader=trainloader)
# initialize and fit a DDLK knockoff generator
q_knockoff = ddlk.DDLK(hparams, q_joint=q_joint)
trainer = pl.Trainer()
trainer.fit(q_knockoff, train_dataloader=trainloader)
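The `trainloader` placeholder above is left for you to fill in. As a minimal sketch, the following builds a loader from a feature matrix using standard PyTorch utilities. It assumes each batch should be a one-element tuple holding only the features, which the one-element unpacking of `utils.get_two_moments` suggests but which is not confirmed here; the random data, shapes, and batch size are placeholders.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Placeholder features: 200 samples, 10 features. Replace with your own data.
n, d = 200, 10
x = torch.randn(n, d)

# TensorDataset(x) yields 1-tuples, so each batch unpacks as (x_batch,).
trainloader = DataLoader(TensorDataset(x), batch_size=64, shuffle=True)

(first_batch,) = next(iter(trainloader))
print(first_batch.shape)
```

If your data lives in NumPy arrays, `torch.from_numpy(array).float()` converts it to a tensor before wrapping it in the dataset.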
Using the knockoff generator, we sample knockoffs for the training features xTr and run a holdout randomization test (HRT) with the training response yTr:
xTr_tilde = q_knockoff.sample(xTr)
knockoff_test = hrt.HRT_Knockoffs()
knockoff_test.fit(xTr, yTr, xTr_tilde)
If you use this code, please cite the following paper (available here):
Deep Direct Likelihood Knockoffs
M. Sudarshan, W. Tansey, R. Ranganath
arXiv preprint arXiv:2007.15835
BibTeX entry:
@misc{sudarshan2020deep,
title={Deep Direct Likelihood Knockoffs},
author={Mukund Sudarshan and Wesley Tansey and Rajesh Ranganath},
year={2020},
eprint={2007.15835},
archivePrefix={arXiv},
primaryClass={stat.ML}
}