This is the official project page for "Task-agnostic Indexes for Deep Learning-based Queries over Unstructured Data".
Please read the paper for full technical details.
Install the requirements with pip install -r requirements.txt. You will also need several companion packages, each installed from source with pip install -e . (see the installation steps below).
Do not use numba==0.50.1, since that release has a bug.
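To make the numba warning above easy to act on, the following is a minimal sketch of a version check you could run after installation. The helper names (numba_version_ok, installed_numba_version) are hypothetical and not part of this repository; the only thing taken from the note above is the version string to avoid.

```python
from importlib import metadata

# The release the note above warns against.
BAD_NUMBA_VERSION = "0.50.1"


def numba_version_ok(installed_version):
    """Return True unless the given version string is the known-bad release."""
    return installed_version != BAD_NUMBA_VERSION


def installed_numba_version():
    """Return the installed numba version string, or None if numba is absent."""
    try:
        return metadata.version("numba")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_numba_version()
    if version is None:
        print("numba is not installed")
    elif numba_version_ok(version):
        print(f"numba {version} is fine")
    else:
        print(f"numba {version} has a known bug; install a different release")
```

If the check fails, pinning a different numba release with pip is the simplest fix.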
To reproduce the experiments, your machine will need:
- 300+GB of memory
- 500+GB of disk space
- GPU (e.g., NVIDIA V100, TITAN V, or later)
On other datasets, hardware requirements will vary.
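As a rough preflight check for the memory and disk figures above, here is a small sketch using only the Python standard library. The function names and the choice of 1 GB = 10^9 bytes are assumptions made for illustration. The GPU requirement is not checked here, since that would depend on your deep learning framework.

```python
import os
import shutil

GB = 10 ** 9  # assumption: decimal gigabytes


def free_disk_gb(path="."):
    """Free disk space, in GB, on the filesystem that contains `path`."""
    return shutil.disk_usage(path).free / GB


def total_memory_gb():
    """Total physical memory in GB, or None where sysconf cannot report it."""
    try:
        return os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES") / GB
    except (AttributeError, ValueError, OSError):
        # os.sysconf is missing on some platforms (e.g., Windows).
        return None


def meets_requirements(memory_gb, disk_gb, min_memory_gb=300, min_disk_gb=500):
    """Compare measured resources against the thresholds listed above."""
    if memory_gb is None:
        return False
    return memory_gb >= min_memory_gb and disk_gb >= min_disk_gb


if __name__ == "__main__":
    print(meets_requirements(total_memory_gb(), free_disk_gb(".")))
```

Pass the directory where you plan to store the data to free_disk_gb, since the data may live on a different filesystem than the code.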
If you want to reproduce the SIGMOD experiments, use Python 3.8.13 and create a conda environment from tasti.yml. You'll also need to install blazeit, supg, and tasti as described below.
Otherwise, the following steps will install the necessary packages from scratch:
git clone https://github.com/stanford-futuredata/swag-python.git
cd swag-python/
conda install -c conda-forge opencv
pip install -e .
cd ..
git clone https://github.com/stanford-futuredata/blazeit.git
cd blazeit/
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
conda install -c conda-forge pyclipper
pip install -e .
cd ..
git clone https://github.com/stanford-futuredata/supg.git
cd supg/
pip install pandas feather-format
pip install -e .
cd ..
git clone https://github.com/stanford-futuredata/tasti.git
cd tasti/
pip install -r requirements.txt
pip install -e .
We provide code for creating a TASTI for the night-street video dataset, along with all the queries mentioned in the paper (aggregation, limit, SUPG, position, etc.). You can download the night-street video dataset here. Download the 2017-12-14.zip and 2017-12-17.zip files, unzip them, and place the video data in /lfs/1/jtguibas/data (feel free to change this path in night_street_offline.py). For speed, the target dnn is not run in real time; we have instead provided its outputs here. Place the csv files in /lfs/1/jtguibas/data. Then, you can reproduce the experiments by running:
python tasti/examples/night_street_offline.py
We also provide an online version of the code that runs the target dnn in real time. For efficiency, it uses Mask R-CNN ResNet-50 FPN as the target dnn, whereas the model used in the paper is the Mask R-CNN X 152 model available in detectron2. For more demanding workloads, we encourage you to replace the inference with TensorRT or another model serving system.
To run the WikiSQL example, download the data here and place train.jsonl in /lfs/1/jtguibas/data (again, feel free to change this path inside wikisql_offline.py).
To run the CommonVoice example, download the data here and install audioset_tagging_cnn from here. You will also need to download the Cnn10 and ResNet22 models from there. Afterwards, adjust the paths in commonvoice_offline.py accordingly and make sure to install any missing dependencies.
Our code allows you to create your own TASTI. You will have to inherit from the tasti.Index class and implement a few functions:
import tasti


class MyIndex(tasti.Index):
    def is_close(self, a, b):
        '''
        Return a Boolean of whether records a and b are 'close'.
        '''
        raise NotImplementedError

    def get_target_dnn_dataset(self, train_or_test='train'):
        '''
        Return a torch.utils.data.Dataset object.
        '''
        raise NotImplementedError

    def get_embedding_dnn_dataset(self, train_or_test='train'):
        '''
        Return a torch.utils.data.Dataset object.
        '''
        raise NotImplementedError

    def get_target_dnn(self):
        '''
        Return a torch.nn.Module object.
        '''
        raise NotImplementedError

    def get_embedding_dnn(self):
        '''
        Return a torch.nn.Module object.
        '''
        raise NotImplementedError

    def get_pretrained_embedding_dnn(self):
        '''
        Optional if do_mining is False.
        Return a torch.nn.Module object.
        '''
        raise NotImplementedError

    def target_dnn_callback(self, target_dnn_output):
        '''
        Optional if you don't want to process the target_dnn_output.
        '''
        return target_dnn_output

    def override_target_dnn_cache(self, target_dnn_cache, train_or_test='train'):
        '''
        Optional if you want to run the target dnn in realtime.
        Allows for you to override the target_dnn_cache when you have the
        target dnn outputs already cached.
        '''
        raise NotImplementedError


class MyQuery(tasti.AggregateQuery):
    def score(self, target_dnn_output):
        '''
        Maps a target_dnn_output into a feature/scalar you are interested in.
        Note that this is an aggregate query, so this query will try to
        estimate the total sum of these scores.
        '''
        return len(target_dnn_output)

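# Configure and build the index.
config = tasti.IndexConfig()
config.nb_buckets = 500
index = MyIndex(config)
index.init()

# Run the query against the index (assumption: the query constructor
# takes the index it should run against).
query = MyQuery(index)
result = query.execute()
print(result)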
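To make the is_close and score hooks above more concrete, here is a standalone sketch for an object-detection setting. It does not use the tasti package. The record format (a list of (x_min, y_min, x_max, y_max) boxes per frame), the max_center_dist threshold, and the closeness rule itself are all assumptions chosen for illustration, not the definitions used in the paper or in the bundled examples.

```python
import math


def box_center(box):
    """Center (x, y) of a box given as (x_min, y_min, x_max, y_max)."""
    x_min, y_min, x_max, y_max = box
    return ((x_min + x_max) / 2.0, (y_min + y_max) / 2.0)


def is_close(a, b, max_center_dist=100.0):
    """Hypothetical closeness rule for two frames' detections.

    Two records are 'close' if they contain the same number of boxes and
    every box in `a` has some box in `b` whose center lies within
    `max_center_dist` pixels.
    """
    if len(a) != len(b):
        return False
    centers_b = [box_center(box) for box in b]
    for box in a:
        ax, ay = box_center(box)
        if not any(math.hypot(ax - bx, ay - by) <= max_center_dist
                   for bx, by in centers_b):
            return False
    return True


def score(target_dnn_output):
    """Same scoring rule as MyQuery.score above: count the detections."""
    return len(target_dnn_output)


frames = [
    [(0, 0, 10, 10)],
    [(5, 5, 15, 15)],
    [(0, 0, 10, 10), (50, 50, 60, 60)],
]
# The exact quantity an aggregate query tries to estimate: the sum of scores.
exact_total = sum(score(frame) for frame in frames)
```

In a real index, a rule like this would go in the is_close method of your tasti.Index subclass, with a, b, and the threshold adapted to whatever your target dnn returns.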
These are the options available in tasti.IndexConfig, which get passed into the tasti.Index object.
- do_mining, Boolean that determines whether the mining step is skipped or not
- do_training, Boolean that determines whether the training/fine-tuning step of the embedding dnn is skipped or not
- do_infer, Boolean that allows you to either compute embeddings or load them from ./cache
- do_bucketting, Boolean that allows you to compute the buckets or load them from ./cache
- batch_size, general batch size for both the target and embedding dnn
- train_margin, controls the margin parameter of the triplet loss
- max_k, controls the k parameter described in the paper (for computing distance weighted means and votes)
- nb_train, controls how many datapoints are labeled to perform the triplet training
- nb_buckets, controls the number of buckets used to construct the index
- nb_training_its, controls the number of datapoints that are passed through the model during training
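To give some intuition for train_margin and max_k, the sketch below shows the standard triplet margin loss and one possible distance-weighted mean over the k nearest representatives. It is written in plain Python and does not use tasti. The inverse-distance weighting, the function names, and the toy 2-D embeddings are assumptions for illustration; see the paper for the weighting scheme TASTI actually uses.

```python
import math


def triplet_loss(anchor, positive, negative, margin=1.0):
    """Standard triplet margin loss: max(0, d(a, p) - d(a, n) + margin).

    A larger margin asks the embedding to push negatives further away
    from the anchor than positives before the loss reaches zero.
    """
    d_pos = math.dist(anchor, positive)
    d_neg = math.dist(anchor, negative)
    return max(0.0, d_pos - d_neg + margin)


def distance_weighted_mean(query, reps, rep_scores, k, eps=1e-12):
    """Estimate a score for `query` from its k nearest representatives.

    Assumption: weights are inverse distances, so nearer representatives
    count more. `eps` avoids division by zero for exact matches.
    """
    nearest = sorted(
        (math.dist(query, rep), s) for rep, s in zip(reps, rep_scores)
    )[:k]
    weights = [1.0 / (d + eps) for d, _ in nearest]
    weighted_sum = sum(w * s for w, (_, s) in zip(weights, nearest))
    return weighted_sum / sum(weights)


# Toy 2-D embeddings: the third representative is far away and is
# ignored when k=2.
reps = [(1.0, 0.0), (3.0, 0.0), (100.0, 0.0)]
rep_scores = [2.0, 4.0, 50.0]
estimate = distance_weighted_mean((0.0, 0.0), reps, rep_scores, k=2)
```

With k=1 the estimate reduces to the score of the single nearest representative; larger k smooths the estimate across more representatives.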