Welcome to TALENT, a benchmark with a comprehensive machine learning toolbox designed to enhance model performance on tabular data. TALENT integrates advanced deep learning models, classical algorithms, and efficient hyperparameter tuning, offering robust preprocessing capabilities to optimize learning from tabular datasets. The toolbox is user-friendly and adaptable, catering to both novice and expert data scientists.
TALENT offers the following advantages:
- Diverse Methods: Includes various classical methods, tree-based methods, and the latest popular deep learning methods.
- Extensive Dataset Collection: Equipped with 300 datasets, covering a wide range of task types, size distributions, and dataset domains.
- Customizability: Easily allows the addition of datasets and methods.
- Versatile Support: Supports diverse normalization, encoding, and metrics.
If you use any content of this repo for your work, please cite the following bib entries:
@article{ye2024closerlookdeeplearning,
title={A Closer Look at Deep Learning on Tabular Data},
author={Han-Jia Ye and Si-Yang Liu and Hao-Run Cai and Qi-Le Zhou and De-Chuan Zhan},
journal={arXiv preprint arXiv:2407.00956},
year={2024}
}
@article{liu2024talenttabularanalyticslearning,
title={TALENT: A Tabular Analytics and Learning Toolbox},
author={Si-Yang Liu and Hao-Run Cai and Qi-Le Zhou and Han-Jia Ye},
journal={arXiv preprint arXiv:2407.04057},
year={2024}
}
- [2024-08]🌟 Add GRANDE (ICLR 2024).
- [2024-08]🌟 Add Excelformer (KDD 2024).
- [2024-08]🌟 Add MLP_PLR (NeurIPS 2022).
- [2024-07]🌟 Add RealMLP.
- [2024-07]🌟 Add ProtoGate (ICML 2024).
- [2024-07]🌟 Add BiSHop (ICML 2024).
- [2024-06]🌟 Check out our new baseline ModernNCA, inspired by traditional Neighbor Component Analysis, which outperforms both tree-based and other deep tabular models, while also reducing training time and model size!
- [2024-06]🌟 Check out our benchmark paper about tabular data, which provides comprehensive evaluations of classical and deep tabular methods based on our toolbox in a fair manner!
TALENT integrates an extensive array of 20+ deep learning architectures for tabular data, including but not limited to:
- MLP: A multi-layer neural network, which is implemented according to RTDL.
- ResNet: A DNN that uses skip connections across many layers, which is implemented according to RTDL.
- SNN: An MLP-like architecture utilizing the SELU activation, which facilitates the training of deeper neural networks.
- DANets: A neural network designed to enhance tabular data processing by grouping correlated features and reducing computational complexity.
- TabCaps: A capsule network that encapsulates all feature values of a record into vectorial features.
- DCNv2: Consists of an MLP-like module combined with a feature crossing module, which includes both linear layers and multiplications.
- NODE: A tree-mimic method that generalizes oblivious decision trees, combining gradient-based optimization with hierarchical representation learning.
- GrowNet: A gradient boosting framework that uses shallow neural networks as weak learners.
- TabNet: A tree-mimic method using sequential attention for feature selection, offering interpretability and self-supervised learning capabilities.
- TabR: A deep learning model that integrates a KNN component to enhance tabular data predictions through an efficient attention-like mechanism.
- ModernNCA: A deep tabular model inspired by traditional Neighbor Component Analysis, which makes predictions based on the relationships with neighbors in a learned embedding space.
- DNNR: Enhances KNN by using local gradients and Taylor approximations for more accurate and interpretable predictions.
- AutoInt: A token-based method that uses a multi-head self-attentive neural network to automatically learn high-order feature interactions.
- Saint: A token-based method that leverages row and column attention mechanisms for tabular data.
- TabTransformer: A token-based method that enhances tabular data modeling by transforming categorical features into contextual embeddings.
- FT-Transformer: A token-based method which transforms features to embeddings and applies a series of attention-based transformations to the embeddings.
- TANGOS: A regularization-based method for tabular data that uses gradient attributions to encourage neuron specialization and orthogonalization.
- SwitchTab: A self-supervised method tailored for tabular data that improves representation learning through an asymmetric encoder-decoder framework. Following the original paper, our toolkit uses a supervised learning form, optimizing both reconstruction and supervised loss in each epoch.
- PTaRL: A regularization-based framework that enhances prediction by constructing and projecting into a prototype-based space.
- TabPFN: A general model which involves the use of pre-trained deep neural networks that can be directly applied to any tabular task.
- HyperFast: A meta-trained hypernetwork that generates task-specific neural networks for instant classification of tabular data.
- TabPTM: A general method for tabular data that standardizes heterogeneous datasets using meta-representations, allowing a pre-trained model to generalize to unseen datasets without additional training.
- BiSHop: An end-to-end framework for deep tabular learning which leverages a sparse Hopfield model with adaptable sparsity, enhanced by column-wise and row-wise modules.
- ProtoGate: A prototype-based model for feature selection in HDLSS biomedical data that adapts global and local feature selection to enhance prediction accuracy and interpretability, addressing co-adaptation issues through a non-parametric prototype-based mechanism.
- RealMLP: An improved multilayer perceptron (MLP).
- MLP_PLR: An improved multilayer perceptron (MLP), which utilizes periodic activations.
- Excelformer: A deep learning model for tabular data prediction, featuring a semi-permeable attention module to address rotational invariance, tailored data augmentation, and an attentive feedforward network, making it a reliable solution across diverse datasets.
- GRANDE: A tree-mimic method for learning hard, axis-aligned decision tree ensembles using end-to-end gradient descent.
Clone this GitHub repository:
git clone https://github.com/qile2000/LAMDA-TALENT
cd LAMDA-TALENT/LAMDA-TALENT
- Edit configs/default/[MODEL_NAME].json and configs/opt_space/[MODEL_NAME].json for global settings and hyperparameters.
- Run:
python train_model_deep.py --model_type MODEL_NAME
for deep methods, or:
python train_model_classical.py --model_type MODEL_NAME
for classical methods.
For methods such as MLP, where only the model itself has to be designed, you only need to:
- Add the model class in model/models.
- Inherit from the base class in model/methods/base.py and override the construct_model() method in the new class.
- Add the method name in the get_method function in model/utils.py.
- Add the parameter settings for the new method in configs/default/[MODEL_NAME].json and configs/opt_space/[MODEL_NAME].json.
For other methods that require changing the training process, partially override the functions defined in model/methods/base.py. For details, refer to the implementation of other methods in model/methods/.
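For the model-only case described above, the four steps can be sketched roughly as follows. This is a self-contained illustration, not TALENT's actual code: the base class name Method, its constructor, the MyMLP model, the MyMLPMethod class, and the dictionary-based get_method are all hypothetical stand-ins; only the names construct_model() and get_method come from the instructions above.

```python
# Sketch only: Method, MyMLP, MyMLPMethod and the registry dict are
# hypothetical stand-ins, not TALENT's actual classes or signatures.

class Method:
    """Stand-in for the base class assumed to live in model/methods/base.py."""

    def __init__(self, config):
        self.config = config
        self.model = None
        self.construct_model()

    def construct_model(self):
        raise NotImplementedError


class MyMLP:
    """Step 1: a placeholder model class, as it might appear in model/models."""

    def __init__(self, d_in, d_hidden, d_out):
        self.layer_sizes = [d_in, d_hidden, d_out]


class MyMLPMethod(Method):
    """Step 2: inherit from the base class and override construct_model()."""

    def construct_model(self):
        self.model = MyMLP(**self.config["model"])


def get_method(name):
    """Step 3: map the method name to its class (assumed shape of get_method)."""
    registry = {"my_mlp": MyMLPMethod}
    if name not in registry:
        raise NotImplementedError(f"unknown method: {name}")
    return registry[name]


# Step 4: in the toolbox these values would come from the JSON files under
# configs/default/ and configs/opt_space/; a plain dict is used here.
config = {"model": {"d_in": 8, "d_hidden": 32, "d_out": 1}}
method = get_method("my_mlp")(config)
print(method.model.layer_sizes)
```

The real base class handles data loading, training, and evaluation as well, so check model/methods/base.py for the actual constructor arguments before copying this pattern.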
- If you want to use TabR, you have to manually install faiss, which is only available on conda:
conda install faiss-gpu -c pytorch
Datasets are available at Google Drive.
Place the datasets inside the project directory, in the folder named by args.dataset_path. For instance, if the project is LAMDA-TALENT, the data for a dataset should be placed in LAMDA-TALENT/args.dataset_path/args.dataset.
Each dataset folder args.dataset consists of:
- Numeric features: N_train/val/test.npy (can be omitted if there are no numeric features)
- Categorical features: C_train/val/test.npy (can be omitted if there are no categorical features)
- Labels: y_train/val/test.npy
- info.json, which must include the following three fields (task_type can be "regression", "multiclass" or "binclass"):
{
  "task_type": "regression",
  "n_num_features": 10,
  "n_cat_features": 10
}
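Before training on a custom dataset, it can help to verify that the folder matches the layout above. The following is a minimal sketch using only the standard library. The function name check_dataset_dir is hypothetical, and the rule that N or C files are required only when the corresponding feature count in info.json is greater than zero is an assumption about how "can be omitted" is meant. The check looks only at file presence, not at array contents.

```python
import json
import tempfile
from pathlib import Path

SPLITS = ("train", "val", "test")
TASK_TYPES = {"regression", "multiclass", "binclass"}
INFO_KEYS = ("task_type", "n_num_features", "n_cat_features")


def check_dataset_dir(folder):
    """Return a list of layout problems for a dataset folder (empty if none)."""
    folder = Path(folder)
    info_path = folder / "info.json"
    if not info_path.is_file():
        return ["missing info.json"]
    info = json.loads(info_path.read_text())
    problems = [f"info.json lacks {key}" for key in INFO_KEYS if key not in info]
    if problems:
        return problems
    if info["task_type"] not in TASK_TYPES:
        problems.append(f"unknown task_type: {info['task_type']}")
    # Labels are always required; feature files only when features exist
    # (assumption: a count of 0 means the files may be omitted).
    prefixes = ["y"]
    if info["n_num_features"] > 0:
        prefixes.append("N")
    if info["n_cat_features"] > 0:
        prefixes.append("C")
    for prefix in prefixes:
        for split in SPLITS:
            name = f"{prefix}_{split}.npy"
            if not (folder / name).is_file():
                problems.append(f"missing {name}")
    return problems


# Demo on a throwaway folder. Empty placeholder files stand in for real
# arrays, and N_test.npy is left out on purpose.
with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp) / "toy_dataset"
    folder.mkdir()
    info = {"task_type": "regression", "n_num_features": 10, "n_cat_features": 0}
    (folder / "info.json").write_text(json.dumps(info))
    for name in ("y_train", "y_val", "y_test", "N_train", "N_val"):
        (folder / f"{name}.npy").touch()
    problems = check_dataset_dir(folder)

print(problems)
```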
We provide comprehensive evaluations of classical and deep tabular methods based on our toolbox in a fair manner in the Figure. Three tabular prediction tasks, namely, binary classification, multi-class classification, and regression, are considered, and each subfigure represents a different task type.
We use Accuracy and RMSE as the metrics for classification tasks and regression tasks, respectively. To calibrate the metrics, we choose the average performance rank to compare all methods, where a lower rank indicates better performance, following Sheskin (2003). Efficiency is calculated by the average training time in seconds, with lower values denoting better time efficiency. The model size is visually indicated by the radius of the circles, offering a quick glance at the trade-off between model complexity and performance.
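To make the average performance rank concrete, here is a small sketch of how it can be computed from per-dataset scores. It is an illustration under stated assumptions, not the evaluation code used for the figure: the function names, the method names, and the scores are made up, and giving tied methods the average of their positions is one common convention that the text above does not specify.

```python
from statistics import mean


def rank_methods(scores, higher_is_better):
    """Rank methods on one dataset; rank 1 is best, ties share the average rank."""
    ordered = sorted(scores, key=scores.get, reverse=higher_is_better)
    ranks = {}
    i = 0
    while i < len(ordered):
        j = i
        # Extend j over the run of methods tied with ordered[i].
        while j + 1 < len(ordered) and scores[ordered[j + 1]] == scores[ordered[i]]:
            j += 1
        shared = (i + j) / 2 + 1  # average of the 1-based positions i+1 .. j+1
        for name in ordered[i:j + 1]:
            ranks[name] = shared
        i = j + 1
    return ranks


def average_ranks(per_dataset_scores, higher_is_better):
    """Average each method's rank over all datasets (lower is better)."""
    per_dataset_ranks = [rank_methods(s, higher_is_better) for s in per_dataset_scores]
    methods = per_dataset_ranks[0].keys()
    return {m: mean(r[m] for r in per_dataset_ranks) for m in methods}


# Made-up accuracies on two classification datasets.
accuracy = [
    {"method_a": 0.90, "method_b": 0.85, "method_c": 0.85},
    {"method_a": 0.70, "method_b": 0.80, "method_c": 0.60},
]
avg = average_ranks(accuracy, higher_is_better=True)
print(avg)  # {'method_a': 1.5, 'method_b': 1.75, 'method_c': 2.75}
```

For regression, the same function applies with RMSE values and higher_is_better=False, since a lower RMSE should receive the better rank.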
The classical method SVM provided in TALENT is a LinearSVM to ensure faster training. We also consider the Dummy baseline, which outputs the label of the majority class for classification tasks and the average label for regression tasks.
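The behavior of such a Dummy baseline can be sketched in a few lines. This is an illustrative stand-in, not TALENT's implementation: the function name dummy_predict is hypothetical, and computing the constant from the training labels is an assumption.

```python
from collections import Counter
from statistics import mean


def dummy_predict(y_train, n_test, task_type):
    """Predict one constant for every test row, ignoring all features."""
    if task_type == "regression":
        constant = mean(y_train)  # average label
    else:
        # Most frequent label; on ties, the label seen first wins.
        constant = Counter(y_train).most_common(1)[0][0]
    return [constant] * n_test


print(dummy_predict([0, 1, 1, 2], 3, "multiclass"))   # [1, 1, 1]
print(dummy_predict([1.0, 2.0, 3.0], 2, "regression"))  # [2.0, 2.0]
```

Any method that ranks below this baseline is not extracting useful signal from the features on that dataset.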
From the comparison, we observe that CatBoost achieves the best average rank in most classification and regression tasks. Among all deep tabular methods, ModernNCA performs the best in most cases while maintaining an acceptable training cost. These results highlight the effectiveness of CatBoost and ModernNCA in handling various tabular prediction tasks, making them suitable choices for practitioners seeking high performance and efficiency.
These visualizations serve as an effective tool for quickly and fairly assessing the strengths and weaknesses of various tabular methods across different task types, enabling researchers and practitioners to make informed decisions when selecting suitable modeling techniques for their specific needs.
We thank the following repos for providing helpful components/functions in our work:
- Rtdl-revisiting-models
- Rtdl-num-embeddings
- Tabular-dl-tabr
- DANet
- TabCaps
- DNNR
- PTaRL
- Saint
- SwitchTab
- TabNet
- TabPFN
- Tabtransformer-pytorch
- TANGOS
- GrowNet
- HyperFast
- BiSHop
- ProtoGate
- Pytabkit
- Excelformer
- GRANDE
If you have any questions or would like to propose new features, please open an issue or contact the authors: Siyang Liu (liusy@lamda.nju.edu.cn), Haorun Cai (caihr@smail.nju.edu.cn), Qile Zhou (zhouql@lamda.nju.edu.cn), and Han-Jia Ye (yehj@lamda.nju.edu.cn). Enjoy the code.
Thanks to LAMDA-PILOT and LAMDA-ZhiJian for the template.