This repository contains the code to reproduce the experiments in the paper JoLT: Joint Probabilistic Predictions on Tabular Data Using LLMs.
The code has been authored by: John Bronskill, Aliaksandra Shysheya, Shoaib Ahmed Siddiqui, and James Requeima.
This code requires the following:
- Python 3.9 or greater
- PyTorch 2.6.0 or greater
- transformers 4.48.3 or greater
- accelerate 1.3.0 or greater
- jsonargparse 4.36.0 or greater
- numpy 2.0.2 or greater
- scikit-learn 1.6.1 or greater
- scipy 1.13.1 or greater
- pandas 2.2.3 or greater
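The minimum versions above can be checked before running any experiment. The following is a minimal sketch, not part of the repository: the function names `parse_version` and `check_dependencies` are hypothetical, and it assumes PyTorch is installed under the distribution name `torch`. It uses only the standard library, so it runs even when some of the dependencies are missing.

```python
import sys
from importlib import metadata

# Minimum versions copied from the dependency list above.
# Keys are distribution names; "torch" is an assumption for PyTorch.
MIN_VERSIONS = {
    "torch": "2.6.0",
    "transformers": "4.48.3",
    "accelerate": "1.3.0",
    "jsonargparse": "4.36.0",
    "numpy": "2.0.2",
    "scikit-learn": "1.6.1",
    "scipy": "1.13.1",
    "pandas": "2.2.3",
}


def parse_version(text):
    """Turn a string such as '2.6.0+cu124' into a tuple such as (2, 6, 0)."""
    parts = []
    for piece in text.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if digits:
            parts.append(int(digits))
        if digits != piece:
            # Stop at the first non-numeric suffix (for example 'rc1').
            break
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def check_dependencies(min_versions=MIN_VERSIONS):
    """Return a list of problems; an empty list means every check passed."""
    problems = []
    for name, minimum in min_versions.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name} is not installed (need >= {minimum})")
            continue
        if parse_version(installed) < parse_version(minimum):
            problems.append(f"{name} {installed} is older than {minimum}")
    return problems


if __name__ == "__main__":
    problems = check_dependencies()
    if sys.version_info < (3, 9):
        problems.insert(0, "Python 3.9 or greater is required")
    for line in problems or ["All listed dependencies meet the minimum versions."]:
        print(line)
```

The version parser is deliberately simple and ignores pre-release ordering; a stricter comparison would need a dedicated version-parsing library.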
We support a variety of LLMs through the Hugging Face Transformers APIs. The code currently supports the following LLMs:
Adding a new LLM that supports the Hugging Face APIs is not difficult; just modify
- Clone or download this repository.
- Install the Python libraries listed under dependencies.
- Change directory to the root directory of this repo.
- On Linux run:
- On Windows run:
From the root directory of the repo, run any of the commands below.
python ./experiments/classification/ --llm_type <LLM Type> --batch_size <value>
python ./experiments/multi_target_prediction/ --llm_type <LLM Type> --batch_size <value>
python --experiment_name medals --data_path data/medals.csv --llm_type <LLM Type> --output_dir experiments/multi_target_prediction/output --num_samples 1 --batch_size 5 --mode sample_logpy --num_decimal_places_x 0 --num_decimal_places_y 0 --y_column_types numerical numerical --y_column_names 'Silver Medal Count' 'Gold Medal Count' --max_generated_length 25 --header_option headers_as_item_prefix --top_k 1 --prefix 'Each example contains five columns: Olympic Year, Country, Bronze Medal Count, Silver Medal Count, and Gold Medal Count that describe what type and how many medals a country won at the Olympic games that year. Predict the number of silver and gold medals won by that country in that year.\n' --columns_to_ignore Country_Label --csv_split_option fixed_indices --train_start_index 10 --train_end_index 80 --test_start_index 0 --test_end_index 10
python --experiment_name movies --data_path data/movies.csv --llm_type <LLM Type> --output_dir experiments/multi_target_prediction/output --num_samples 1 --batch_size 13 --mode sample_logpy --num_decimal_places_x 1 --num_decimal_places_y 1 --y_column_types numerical categorical categorical categorical categorical categorical categorical categorical categorical --y_column_names Rating Adventure Comedy Family Action Fantasy Thriller Drama Horror --max_generated_length 50 --header_option headers_as_item_prefix --top_k 1 --prefix 'Each example contains 11 columns: Movie Name, Revenue in Millions of Dollars, Rating, and 8 genre tags (Adventure, Comedy, Family, Action, Fantasy, Thriller, Drama, and Horror). Predict the movie rating and genre tags.' --columns_to_ignore 'Release Date' Animation 'Science Fiction' Crime Romance Music History Mystery --csv_split_option fixed_indices --train_start_index 0 --train_end_index 89 --test_start_index 89 --test_end_index 188
python ./experiments/missing_data/ --llm_type <LLM Type> --batch_size <value>
python ./experiments/imputation/ --experiment_name medals_imputation --data_path data/paris_2024_medals.csv --llm_type <LLM Type> --output_dir ./experiments/imputation/output --num_samples 10 --batch_size 5 --mode sample_logpy --num_decimal_places_x 0 --num_decimal_places_y 0 --max_generated_length 40 --header_option headers_as_item_prefix --seed 0 --missing_fraction 0.2 --impute_features True --prefix 'Each row of data contains the name of a country and how many gold, silver and bronze medals that country won at the Paris 2024 olympics.'
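To illustrate what `--missing_fraction 0.2` and `--seed 0` could mean in the imputation command above, here is a minimal sketch that hides a fixed fraction of table cells reproducibly. It is an assumption about the general idea, not the repository's implementation: the function name `mask_cells` is hypothetical, the table values are placeholders rather than real medal counts, and the repository may choose which cells to hide differently.

```python
import random


def mask_cells(table, missing_fraction=0.2, seed=0):
    """Replace a random fraction of cells with None, reproducibly for a given seed."""
    rng = random.Random(seed)
    # Enumerate every (row, column) position in the table.
    cells = [(r, c) for r, row in enumerate(table) for c in range(len(row))]
    n_missing = round(len(cells) * missing_fraction)
    hidden = set(rng.sample(cells, n_missing))
    masked = [
        [None if (r, c) in hidden else value for c, value in enumerate(row)]
        for r, row in enumerate(table)
    ]
    return masked, sorted(hidden)


# Placeholder rows: name, gold, silver, bronze (values are made up).
table = [
    ["Country A", 1, 2, 3],
    ["Country B", 4, 5, 6],
    ["Country C", 7, 8, 9],
    ["Country D", 0, 1, 2],
    ["Country E", 3, 4, 5],
]
masked, hidden = mask_cells(table, missing_fraction=0.2, seed=0)
print(f"{len(hidden)} of {sum(len(row) for row in table)} cells hidden")
```

Fixing the seed makes the set of hidden cells identical across runs, which is what allows imputation results to be compared between models.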
You can use your own datasets. The default train/test split is set to 80%/20%, but can be customized. See the CSV input options
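The two ways of splitting a dataset mentioned in this README, a default 80%/20% split and the `fixed_indices` option used in the commands above, can be sketched as follows. This is an illustration under stated assumptions, not the repository's code: the function names are hypothetical, the fractional split is shown without shuffling, and end indices are assumed to be exclusive, which is consistent with the movies command (train 0 to 89, test 89 to 188) but is not confirmed by the source.

```python
def split_by_fraction(rows, train_fraction=0.8):
    """Put the first train_fraction of rows in train and the rest in test."""
    cut = round(len(rows) * train_fraction)
    return rows[:cut], rows[cut:]


def split_by_fixed_indices(rows, train_start, train_end, test_start, test_end):
    """Select train and test rows by half-open index ranges [start, end)."""
    return rows[train_start:train_end], rows[test_start:test_end]


rows = list(range(100))  # stand-in for the rows of a CSV file

train, test = split_by_fraction(rows)
print(len(train), len(test))

# Mirrors the index arguments of the medals command:
# --train_start_index 10 --train_end_index 80 --test_start_index 0 --test_end_index 10
train, test = split_by_fixed_indices(rows, 10, 80, 0, 10)
print(len(train), len(test))
```

Fixed indices are useful when the test rows must be the same for every model being compared.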
To ask questions or report issues, please open an issue on the issue tracker.
If you use this code, please cite our paper:
title={JoLT: Joint Probabilistic Predictions on Tabular Data Using LLMs},
author={Aliaksandra Shysheya and John Bronskill and James Requeima and Shoaib Ahmed Siddiqui and Javier González and David Duvenaud and Richard E. Turner},