This repository implements the Adaptive Sampling with Approximate expected futures (ASAp) algorithm, introduced in the paper Grammar-Aligned Decoding. ASAp addresses a key issue with grammar-constrained decoding (GCD) techniques, and with constrained decoding methods in general: they can distort the LLM's probability distribution.
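To see where the distortion comes from, here is a toy illustration with made-up numbers (our own, not taken from the paper or this repository). Plain GCD masks ungrammatical tokens and renormalizes at each step, which ignores how much probability mass each prefix can still place on grammatical completions:

# Toy illustration (hypothetical numbers): a two-token alphabet {a, b},
# an LLM factorized as P(x1) * P(x2 | x1), and a grammar that accepts
# only the strings "ab" and "ba".
p_first = {"a": 0.5, "b": 0.5}                # P(x1)
p_second = {"a": {"a": 0.9, "b": 0.1},        # P(x2 | x1 = "a")
            "b": {"a": 0.9, "b": 0.1}}        # P(x2 | x1 = "b")
grammatical = {"ab", "ba"}

# The distribution we actually want: the LLM conditioned on grammaticality.
joint = {x1 + x2: p_first[x1] * p_second[x1][x2] for x1 in "ab" for x2 in "ab"}
mass = sum(p for s, p in joint.items() if s in grammatical)
target = {s: p / mass for s, p in joint.items() if s in grammatical}
print(target)  # roughly {'ab': 0.1, 'ba': 0.9}

# Plain GCD: at step 1 both "a" and "b" can still be completed grammatically,
# so nothing is masked; at step 2 the grammar forces the remaining token.
# GCD therefore samples "ab" and "ba" with probability 0.5 each, far from the
# target distribution above.  ASAp instead learns the expected grammatical
# mass of each prefix from repeated samples and reweights toward the target.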
Clone the repository:
git clone https://github.com/ebmoon/transformers-GAD.git
Create a new Conda environment using the provided requirements file. Replace <env>
with the actual name of your environment:
conda create --name <env> python=3.11
conda activate <env>
pip install -r requirements.txt
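Optionally, sanity-check the environment from the repository root (assuming transformers_gad resolves from there; the requirements file may not install the package itself into site-packages):
python -c "import torch, transformers, transformers_gad; print(transformers.__version__)"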
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.logits_process import LogitsProcessorList, InfNanRemoveLogitsProcessor
from transformers_gad.grammar_utils import IncrementalGrammarConstraint
from transformers_gad.generation.logits_process import GrammarAlignedOracleLogitsProcessor
MODEL_ID = "TinyLlama/TinyLlama_v1.1"
GRAMMAR_PATH = "examples/test/binary_len_5_0.ebnf"
MAX_NEW_TOKENS = 512
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
# Load EBNF grammar
with open(GRAMMAR_PATH, "r") as f:
    grammar_str = f.read()
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
# Initialize logits processor for the grammar
gad_oracle_processor = GrammarAlignedOracleLogitsProcessor(grammar)
inf_nan_remove_processor = InfNanRemoveLogitsProcessor()
logits_processors = LogitsProcessorList([
    inf_nan_remove_processor,
    gad_oracle_processor,
])
# Tokenize prompt into ids
prompt = "Generate a binary string of length 5"
input_ids = tokenizer([prompt], add_special_tokens=False, return_tensors="pt")["input_ids"]
# Inference loop
outputs = []
for _ in tqdm(range(10), desc="Running Inference"):
    # Generate sequences; return_dict_in_generate is required so that
    # output.sequences is available below
    output = model.generate(
        input_ids,
        do_sample=True,
        max_new_tokens=MAX_NEW_TOKENS,
        logits_processor=logits_processors,
        return_dict_in_generate=True,
    )

    # Incremental parser state must be reset after each generation
    gad_oracle_processor.reset()

    # Detokenize generated output
    input_length = 1 if model.config.is_encoder_decoder else input_ids.shape[1]
    generated_tokens = output.sequences[:, input_length:]
    generations = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    outputs.append(generations[0])

print(outputs)
The ASAp algorithm is implemented as a logits processor. Users can initialize a new GrammarAlignedOracleLogitsProcessor for an EBNF grammar and pass it as an argument during generation. Since the logits processor uses an incremental parser internally, users must manually reset the parser state before the next generation.
You can try running scripts/test_gad.py.
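For reference, grammars use the GBNF-style EBNF syntax of the underlying transformers-CFG project, and a grammar string can also be passed to IncrementalGrammarConstraint directly instead of being read from a file. The grammar below only illustrates the format; the actual examples/test/binary_len_5_0.ebnf may differ:

from transformers import AutoTokenizer
from transformers_gad.grammar_utils import IncrementalGrammarConstraint

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama_v1.1")

# Illustrative grammar for binary strings of length 5 (not necessarily the
# contents of examples/test/binary_len_5_0.ebnf)
grammar_str = """root ::= bit bit bit bit bit
bit ::= "0" | "1"
"""
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)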
Trained ASAp tries can be saved as a JSON file:

with open(TRIE_PATH, "w") as f:
    f.write(gad_oracle_processor.oracle_trie.json())

The full example can be checked in scripts/test_gad.py.

Saved ASAp tries can be reloaded from a previously saved JSON file and passed during the initialization of the GrammarAlignedOracleLogitsProcessor:
from transformers_gad.oracle.oracle_trie import Trie
with open(TRIE_PATH, "r") as f:
    trie = Trie.loads(f.read())
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
gad_oracle_processor = GrammarAlignedOracleLogitsProcessor(grammar, trie)
The full example can be checked in scripts/test_gad_load_trie.py.
Running the scripts in scripts/eval collects the data required for plotting.
- eval_binary_gad.py and eval_binary_gcd.py collect data for the skewed binary grammar example.
- eval_gad.py and eval_gcd.py collect data for the SLIA, BV, and CP datasets. To specify which dataset to use, you must manually set the SPLIT variable to either "SLIA", "BV", or "CP" (see the example after this list).
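For example, assuming the scripts are run directly with Python from the repository root (the exact invocation is not specified here):
python scripts/eval/eval_binary_gad.py   # skewed binary grammar example
python scripts/eval/eval_binary_gcd.py
python scripts/eval/eval_gad.py          # set SPLIT inside the script first
python scripts/eval/eval_gcd.py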
Running the scripts in scripts/plot will generate plots from the collected data.
- Again, to specify which dataset to use, you must manually set the SPLIT variable to either "binary", "SLIA", "BV", or "CP".
- plots/{SPLIT}/prob contains plots for expectations.
- plots/{SPLIT}/kl contains plots for the KL divergence.
- plots/{SPLIT} contains all the scatter plots.
@misc{gad2024,
title={Grammar-Aligned Decoding},
author={Kanghee Park and Jiayu Wang and Taylor Berg-Kirkpatrick and Nadia Polikarpova and Loris D'Antoni},
year={2024},
eprint={2405.21047},
archivePrefix={arXiv},
primaryClass={cs.AI},
url={https://arxiv.org/abs/2405.21047},
}
This project is built upon the transformers-CFG project.