# test_gad_load_trie.py
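"""
Re-run grammar-aligned decoding with a previously learned oracle trie.

Loads the trie produced by test_gad.py, attaches it to a
GrammarAlignedOracleLogitsProcessor, samples NUM_ITER completions from
TinyLlama under the binary_len_5_0 grammar, and writes the updated trie
back to disk.
"""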
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.logits_process import LogitsProcessorList, InfNanRemoveLogitsProcessor
from transformers_gad.grammar_utils import IncrementalGrammarConstraint
from transformers_gad.generation.logits_process import GrammarAlignedOracleLogitsProcessor
from transformers_gad.oracle.oracle_trie import Trie
NUM_ITER = 10
MODEL_ID = "TinyLlama/TinyLlama_v1.1"
GRAMMAR_PATH = "examples/test/binary_len_5_0.ebnf"
TRIE_PATH = "tries/binary_len_5_0_trie.json"
DEVICE = "cpu"
DTYPE = torch.bfloat16
MAX_NEW_TOKENS = 512
TEMPERATURE = 1.0
REPETITION_PENALTY = 1.0
TOP_P = 1.0
TOP_K = 0
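# Note: HF `generate` treats TOP_K = 0 as "no top-k filtering", and with
# TEMPERATURE, TOP_P, and REPETITION_PENALTY all at 1.0, the sampling
# distribution is shaped only by the logits processors below.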
device = torch.device(DEVICE)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
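# TinyLlama's (LLaMA-style) tokenizer ships without a pad token, so EOS is
# reused for the padded batch encoding below.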
# Load model
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
model.to(device)
model.to(dtype=DTYPE)
model.resize_token_embeddings(len(tokenizer))
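# Resizing is a no-op unless tokens were added, but it keeps the embedding
# matrix aligned with the tokenizer vocabulary.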
# Load EBNF grammar
with open(GRAMMAR_PATH, "r") as file:
    grammar_str = file.read()
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
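# "root" names the grammar's start symbol; the constraint advances an
# incremental parser state token by token during generation.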
# The trie JSON must already exist; it is produced by running test_gad.py
with open(TRIE_PATH, "r") as f:
    trie = Trie.loads(f.read())
# Initialize logits processor for the grammar, with the pretrained trie
gad_oracle_processor = GrammarAlignedOracleLogitsProcessor(grammar, trie)
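# As I understand the GAD approach, the pretrained trie caches per-prefix
# estimates of expected future grammaticality, which the oracle uses to
# reweight logits toward the LM's grammar-conditioned distribution.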
inf_nan_remove_processor = InfNanRemoveLogitsProcessor()
logits_processors = LogitsProcessorList([
    inf_nan_remove_processor,
    gad_oracle_processor,
])
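# Processors run in list order: Inf/NaN values are scrubbed from the raw
# logits before the GAD oracle adjusts them.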
# Tokenize prompt into ids
prompt = "Generate a binary string of length 5"
input_ids = tokenizer(
    [prompt], add_special_tokens=False, return_tensors="pt", padding=True
)["input_ids"]
input_ids = input_ids.to(model.device)
# Inference Loop
outputs = []
for _ in tqdm(range(NUM_ITER), desc="Running Inference"):
    # Generate sequences
    output = model.generate(
        input_ids,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=MAX_NEW_TOKENS,
        top_p=TOP_P,
        top_k=TOP_K,
        temperature=TEMPERATURE,
        logits_processor=logits_processors,
        repetition_penalty=REPETITION_PENALTY,
        num_return_sequences=1,
        return_dict_in_generate=True,
        output_scores=True,
    )
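    # With return_dict_in_generate=True, `output` carries `sequences`
    # (prompt + completion ids) and per-step `scores`.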
    # The incremental parser state must be reset after each generation
    gad_oracle_processor.reset()
    # Detokenize generated output
    input_length = 1 if model.config.is_encoder_decoder else input_ids.shape[1]
    generated_tokens = output.sequences[:, input_length:]
    generations = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    outputs.append(generations[0])
print(outputs)
# Store the updated trie back to disk as JSON
with open(TRIE_PATH, "w") as f:
    f.write(gad_oracle_processor.oracle_trie.json())