This repository contains the code for the experiments in the paper Loki: Low-Rank Keys for Efficient Sparse Attention. We provide code to compute the PCA of the keys for various models, implementations of the baseline methods and the Loki kernels used in the paper, and scripts to evaluate these methods on perplexity and downstream tasks.
Install the requirements as follows:

```bash
pip install -r requirements.txt
```
Note: The code requires the specific version of the Hugging Face `transformers` library pinned in `requirements.txt`; it will not work with other versions.
Say you want to compute the PCA transform for the keys of the Llama-2-7b model. You can do this by following the steps below:
- Run a perplexity evaluation of the model on a target dataset to save the key, query, and value tensors:

  ```bash
  # The --use-axonn flag is optional and shards the model over multiple GPUs using AxoNN
  python -u evaluate_tasks.py --sequence-length 4096 --model-id meta-llama/Llama-2-7b-hf --model-type llama --dataset wikitext-valid --save-tensors --tensor-dir <Directory to save tensors> --use-topk --top-k 1 [--use-axonn]
  ```
  Possible datasets: `wikitext-valid`, `bookcorpus`, `c4`
- Compute the PCA of the generated keys. In the `pca_analysis` directory, run:

  ```bash
  python pca.py key <NUM_LAYERS> <Path to saved key tensors> <Path to output the PCA transforms>
  ```
Verify that the PCA transforms are saved in the output directory. Do not modify the subdirectory structure of the output directory, as it is used by the downstream task evaluation code.
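For intuition, the PCA step above can be sketched as follows. This is a minimal illustration in NumPy, not the repository's `pca.py`: the function name, the `(num_tokens, head_dim)` key layout, and the sizes in the example are our own assumptions.

```python
import numpy as np

def compute_key_pca(keys: np.ndarray, top_r: int):
    """Compute a PCA transform for saved attention keys.

    keys: array of shape (num_tokens, head_dim) holding the key vectors
    saved for one layer/head (shape is an assumption for illustration).
    Returns the top-r principal directions and the key mean.
    """
    # Center the keys, then take the SVD of the centered matrix;
    # the rows of vt are the principal components, ordered by
    # decreasing explained variance.
    mean = keys.mean(axis=0, keepdims=True)
    _, _, vt = np.linalg.svd(keys - mean, full_matrices=False)
    # Keep the first top_r rows as the projection matrix (top_r, head_dim).
    return vt[:top_r], mean

# Example: project 128-dimensional keys down to r = 16 dimensions.
rng = np.random.default_rng(0)
keys = rng.standard_normal((4096, 128))
transform, mean = compute_key_pca(keys, top_r=16)
reduced = (keys - mean) @ transform.T  # shape (4096, 16)
```

Because keys across a layer tend to lie near a low-dimensional subspace, a small `top_r` preserves most of their variance, which is what Loki exploits.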
Once the PCA transforms are computed, you can run the ML evaluations using Loki. The following command evaluates the downstream tasks using the PCA transforms computed in the previous step:
```bash
python -u evaluate_tasks.py \
  --sequence-length 4096 \
  --model-id meta-llama/Llama-2-7b-hf \
  --model-type llama \
  --use-pca-topk \
  --top-r <16/32/64> \
  --top-k <0.125/0.25/0.5> \
  --rotary-type <prerotary/postrotary> \
  --dataset <Dataset to compute perplexity on, Default: wikitext-test> \
  --transform-dataset <Dataset used to compute PCA: wikitext/bookcorpus/c4, Default: wikitext> \
  [--lm-harness-eval] \  # Flag to evaluate the model on the LM Harness tasks
  [--use-wandb] \        # Optional flag to log the results to wandb
  [--use-axonn]          # Optional flag to shard the model over multiple GPUs using AxoNN
```
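The idea behind `--use-pca-topk` can be sketched as follows: score the keys cheaply in the reduced PCA space (`--top-r` dimensions), keep only the top-k fraction of keys (`--top-k`), and attend exactly to those. This is a simplified single-query, single-head NumPy illustration under our own assumptions; it is not the repository's implementation or kernels.

```python
import numpy as np

def pca_topk_attention(q, K, V, P, top_r, top_k):
    """Approximate attention for one query vector.

    q: (d,) query; K, V: (n, d) keys/values; P: (d, d) PCA transform
    of the keys (computed offline), rows ordered by explained variance.
    top_k is a fraction of keys to retain, as in the command above.
    """
    n, d = K.shape
    # Cheap approximate scores using only the first top_r PCA dimensions.
    q_r = P[:top_r] @ q              # (top_r,)
    K_r = K @ P[:top_r].T            # (n, top_r)
    approx_scores = K_r @ q_r        # (n,)
    # Select the top-k fraction of keys by approximate score.
    k = max(1, int(top_k * n))
    idx = np.argpartition(approx_scores, -k)[-k:]
    # Exact softmax attention over the selected keys only.
    scores = (K[idx] @ q) / np.sqrt(d)
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ V[idx]

rng = np.random.default_rng(0)
d, n = 64, 1024
q = rng.standard_normal(d)
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d))
# PCA transform of the keys (rows of Vt from the SVD of centered keys).
P = np.linalg.svd(K - K.mean(axis=0), full_matrices=False)[2]
out = pca_topk_attention(q, K, V, P, top_r=16, top_k=0.25)  # shape (d,)
```

The scoring cost drops from O(n·d) to O(n·top_r) per query, while the exact softmax is computed only over the retained keys.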
To run the compute evaluation, use:

```bash
python evaluate_compute.py
```

This runs the attention benchmark with Loki and vanilla attention, assuming a Llama2-13B type model, and saves the results in a `compute_files` directory.
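As a rough illustration of where the compute savings come from (our own NumPy timing sketch, unrelated to `evaluate_compute.py`): computing attention scores in `r` PCA dimensions instead of the full `d` reduces the score matmul cost by roughly a factor of `d / r`.

```python
import time
import numpy as np

rng = np.random.default_rng(0)
n, d, r = 8192, 128, 16  # sequence length, head dim, reduced dim (illustrative)
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))

def bench(fn, reps=10):
    """Average wall-clock time of fn over reps runs, after a warm-up."""
    fn()
    t0 = time.perf_counter()
    for _ in range(reps):
        fn()
    return (time.perf_counter() - t0) / reps

full_t = bench(lambda: Q @ K.T)                    # full d-dim score matmul
reduced_t = bench(lambda: Q[:, :r] @ K[:, :r].T)   # r-dim approximate scores
print(f"full: {full_t * 1e3:.1f} ms, reduced: {reduced_t * 1e3:.1f} ms")
```

Exact wall-clock numbers depend on the BLAS backend and hardware; the paper's benchmark additionally accounts for the cost of the sparse attention over the selected keys.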