Refactored TokenEnforcer + TokenizerPrefixTree to be library-agnostic, moved all of the pytorch / transformers related code to transformerenforcer.py
noamgat committed Oct 10, 2023
1 parent f51b50c commit e351e29
Showing 5 changed files with 57 additions and 69 deletions.
4 changes: 3 additions & 1 deletion lmformatenforcer/__init__.py
@@ -1,12 +1,14 @@
__all__ = ['CharacterLevelParser',
'StringParser',
'RegexParser',
'JsonSchemaParser',
'JsonSchemaParser',
'TokenEnforcer',
'generate_enforced']

from .characterlevelparser import CharacterLevelParser, StringParser
from .regexparser import RegexParser
from .jsonschemaparser import JsonSchemaParser
from .tokenenforcer import TokenEnforcer

try:
from .transformerenforcer import generate_enforced
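The try block is cut off here, but together with the new TokenEnforcer export it suggests that the transformers-dependent generate_enforced is now an optional import, so the core package no longer hard-requires torch. A sketch of that pattern, with the except branch assumed since the diff does not show it:

    try:
        from .transformerenforcer import generate_enforced
    except ImportError:
        # Assumed branch, not visible in this diff: torch / transformers are
        # optional, and the library-agnostic core still works without them.
        pass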
42 changes: 17 additions & 25 deletions lmformatenforcer/tokenenforcer.py
@@ -1,13 +1,7 @@
from dataclasses import dataclass, field
from typing import Dict, Hashable, List, Optional, Union
from typing import Callable, Dict, Hashable, List, Optional, Tuple
import logging

from numpy import rec
from .characterlevelparser import CharacterLevelParser, ForceStopParser
from .jsonschemaparser import JsonSchemaParser
from transformers.tokenization_utils import PreTrainedTokenizerBase
from .external.jsonschemaobject import JsonSchemaObject

from .tokenizerprefixtree import TokenizerPrefixTree, TokenizerPrefixTreeNode


@@ -18,24 +12,22 @@ class OutputTensorState:
parser: CharacterLevelParser
allowed_tokens: List[int] = field(default_factory=list)

def __init__(self, tokenizer: PreTrainedTokenizerBase, parser: CharacterLevelParser):
self.tokenizer = tokenizer
self.token_0 = tokenizer.encode("0")[-1]
def __init__(self, regular_tokens: List[Tuple[int, str]],
parser: CharacterLevelParser,
decoder: Callable[[List[int]], str],
eos_token_id: int):
self.prefix_states: Dict[Hashable, TokenEnforcer.OutputTensorState] = {}
self.root_parser = parser
self.tokenizer_tree = TokenizerPrefixTree(tokenizer)

def _decode_single_token(self, token: int) -> str:
# We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
decoded = self.tokenizer.decode([self.token_0, token])[1:]
return decoded
self.tokenizer_tree = TokenizerPrefixTree(regular_tokens)
self.decoder = decoder
self.eos_token_id = eos_token_id

def filter_allowed_tokens(self, batch_id: int, sent: 'torch.Tensor') -> List[int]:
def get_allowed_tokens(self, token_sequence: List[int]) -> List[int]:
# In order to elegantly support beam search and batching, we don't store per-batch information.
# Instead, we store a hash of all the states (unique token tensors) we encountered so far.
# When we encounter a new unique token tensor, we find the token tensor that led to it, and continue from there.

sent_tuple = tuple(sent.tolist())
sent_tuple = tuple(token_sequence)
prev_step_tuple = sent_tuple[:-1]

if sent_tuple in self.prefix_states:
@@ -44,15 +36,15 @@ def filter_allowed_tokens(self, batch_id: int, sent: 'torch.Tensor') -> List[int
elif prev_step_tuple not in self.prefix_states:
# We have not encountered the tensor up to the before-last entry. This means that this is the first call - the instruction / prompt tensor.
# Initialize the root node
state = TokenEnforcer.OutputTensorState(str_so_far=self.tokenizer.decode(sent),
state = TokenEnforcer.OutputTensorState(str_so_far=self.decoder(token_sequence),
parser=self.root_parser)
self.prefix_states[sent_tuple] = state
self._compute_allowed_tokens(state)
return state.allowed_tokens
else:
# Find the state that led to this node. We explicitly don't use the concept of "timestep" because of beam search
prev_step_state = self.prefix_states[prev_step_tuple]
new_state = self._apply_new_characters(prev_step_state, sent)
new_state = self._apply_new_characters(prev_step_state, token_sequence)
self.prefix_states[sent_tuple] = new_state
self._compute_allowed_tokens(new_state)
return new_state.allowed_tokens
@@ -62,7 +54,7 @@ def _compute_allowed_tokens(self, state: 'TokenEnforcer.OutputTensorState'):
shortcut_key = state.parser.shortcut_key()
self._collect_allowed_tokens(state.parser, self.tokenizer_tree.root, allowed_tokens, shortcut_key)
if state.parser.can_end():
allowed_tokens.append(self.tokenizer.eos_token_id)
allowed_tokens.append(self.eos_token_id)
if not allowed_tokens:
raise ValueError(f"Parser reached state with no allowed tokens")
# root_state = next(state for state in self.prefix_states.values() if state.parser == self.root_parser)
@@ -88,16 +80,16 @@ def _collect_allowed_tokens(self, parser: CharacterLevelParser, tree_node: Token
next_tree_node = tree_node.children[character]
self._collect_allowed_tokens(next_parser, next_tree_node, allowed_tokens, None)


def _apply_new_characters(self, state: 'TokenEnforcer.OutputTensorState', sent: 'torch.Tensor'):
characters = self.tokenizer.decode(sent)
def _apply_new_characters(self, state: 'TokenEnforcer.OutputTensorState', token_sequence: List[int]):
characters = self.decoder(token_sequence)
new_state = TokenEnforcer.OutputTensorState(str_so_far=characters, parser=state.parser)
new_characters = characters[len(state.str_so_far):]
for character in new_characters:
if character in new_state.parser.get_allowed_characters():
new_state.parser = new_state.parser.add_character(character)
else:
logging.warning(f"Received an invalid character '{character}', switching to ForceStopParser")
# This can happen in beam / batch scenarios, when some of the batches finished but others are continuing.
logging.debug(f"Received an invalid character '{character}', switching to ForceStopParser")
new_state.parser = ForceStopParser()
return new_state

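After this change, TokenEnforcer has no torch or transformers imports: it is constructed from (token_id, decoded_text) pairs, a CharacterLevelParser, a decode callable, and the EOS token id, so any tokenizer library can be adapted. A minimal sketch of wiring it up; my_tokenizer, its methods, prompt_ids, generated_ids, and schema are hypothetical stand-ins, not part of this commit:

    from lmformatenforcer import JsonSchemaParser, TokenEnforcer

    # Hypothetical adapter for an arbitrary tokenizer (names are placeholders).
    regular_tokens = [
        (token_id, my_tokenizer.decode_single(token_id))  # (id, decoded text)
        for token_id in range(my_tokenizer.vocab_size)
        if token_id not in my_tokenizer.special_ids
    ]

    enforcer = TokenEnforcer(
        regular_tokens,              # List[Tuple[int, str]]
        JsonSchemaParser(schema),    # any CharacterLevelParser works here
        my_tokenizer.decode,         # Callable[[List[int]], str]
        my_tokenizer.eos_token_id,   # int
    )

    # At each decoding step, pass the full token sequence generated so far;
    # states are cached by sequence hash, which also covers beams and batches.
    allowed_token_ids = enforcer.get_allowed_tokens(prompt_ids + generated_ids)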
19 changes: 5 additions & 14 deletions lmformatenforcer/tokenizerprefixtree.py
@@ -1,21 +1,17 @@
from typing import Dict, List
from transformers.tokenization_utils import PreTrainedTokenizerBase
from typing import Dict, List, Tuple


class TokenizerPrefixTreeNode:
def __init__(self):
self.tokens: List[int] = []
self.children: Dict[str, TokenizerPrefixTreeNode] = {}


class TokenizerPrefixTree:
def __init__(self, tokenizer: PreTrainedTokenizerBase):
self.tokenizer = tokenizer
self.token_0 = tokenizer.encode("0")[-1]
def __init__(self, regular_tokens: List[Tuple[int, str]]):
self.root = TokenizerPrefixTreeNode()
self.json_freetext_tokens: List[int] = []
for token_idx in range(self.tokenizer.vocab_size):
if token_idx in self.tokenizer.all_special_ids:
continue
decoded = self._decode_single_token(token_idx)
for token_idx, decoded in regular_tokens:
self._add_token_to_tree(decoded, token_idx, self.root)
# Performance optimization - cache the tokens of all the strings that don't contain a quote in the middle.
# When we are in a JSON freetext string field, they will all be permitted and this will save a lot of tree iterations.
@@ -28,8 +24,3 @@ def _add_token_to_tree(self, token_str: str, token_idx: int, node: TokenizerPref
node.children[character] = TokenizerPrefixTreeNode()
node = node.children[character]
node.tokens.append(token_idx)

def _decode_single_token(self, token: int) -> str:
# We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
decoded = self.tokenizer.decode([self.token_0, token])[1:]
return decoded
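The tree is now built from the same (token_id, decoded_text) pairs instead of a transformers tokenizer: each character of a token's decoded text descends one level, and the token id is appended to the node where the text ends. An illustrative lookup, mirroring _add_token_to_tree (not code from this commit):

    from typing import List

    def tokens_matching_text(root: TokenizerPrefixTreeNode, text: str) -> List[int]:
        # Walk one tree level per character of the decoded string.
        node = root
        for character in text:
            if character not in node.children:
                return []  # no token's decoded text starts this way
            node = node.children[character]
        return node.tokens  # ids of tokens whose decoded text is exactly `text`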
22 changes: 20 additions & 2 deletions lmformatenforcer/transformerenforcer.py
@@ -58,16 +58,34 @@ def get_leading_scores(self) -> Tuple[List[int], List[float]]:
return best_tokens.tolist(), token_probs_list


def _build_regular_tokens_list(tokenizer: PreTrainedTokenizerBase) -> List[Tuple[int, str]]:
token_0 = tokenizer.encode("0")[-1]
regular_tokens = []
for token_idx in range(tokenizer.vocab_size):
if token_idx in tokenizer.all_special_ids:
continue
# We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
decoded = tokenizer.decode([token_0, token_idx])[1:]
regular_tokens.append((token_idx, decoded))
return regular_tokens


def generate_enforced(model: AutoModelForCausalLM,
tokenizer: PreTrainedTokenizerBase,
character_level_parser: CharacterLevelParser,
**kwargs: dict) -> Union[str, dict]:
token_enforcer = TokenEnforcer(tokenizer, character_level_parser)

regular_tokens = _build_regular_tokens_list(tokenizer)
token_enforcer = TokenEnforcer(regular_tokens, character_level_parser, tokenizer.decode, tokenizer.eos_token_id)

def transformers_filter_allowed_tokens(batch_id: int, sent: torch.Tensor) -> List[int]:
token_sequence = sent.tolist()
return token_enforcer.get_allowed_tokens(token_sequence)

is_multi_inputs = kwargs['input_ids'].shape[0] > 1
is_multi_beams = kwargs.get('num_beams', 1) > 1
logits_saver = LogitsSaverManager(model)
logits_saver.replace_logits_warper(token_enforcer.filter_allowed_tokens)
logits_saver.replace_logits_warper(transformers_filter_allowed_tokens)
generate_kwargs = kwargs
return_dict_in_generate = kwargs.get('return_dict_in_generate', False)
output_scores = kwargs.get('output_scores', None)
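With the tokenizer-specific logic isolated in this module, generate_enforced builds the (id, text) list once via _build_regular_tokens_list and adapts the (batch_id, sent tensor) logits callback to the library-agnostic get_allowed_tokens. A usage sketch; the model id and answer_schema are placeholders for illustration:

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from lmformatenforcer import JsonSchemaParser, generate_enforced

    model_id = "meta-llama/Llama-2-7b-chat-hf"  # placeholder model
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    parser = JsonSchemaParser(answer_schema)  # answer_schema: a JSON-schema dict
    inputs = tokenizer("Please give me information about Michael Jordan.",
                       return_tensors="pt")

    # Remaining kwargs are forwarded to model.generate() by generate_enforced.
    result = generate_enforced(model, tokenizer, parser,
                               input_ids=inputs.input_ids,
                               max_new_tokens=100)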
39 changes: 12 additions & 27 deletions samples/colab_llama2_enforcer.ipynb
@@ -41,7 +41,7 @@
"# import sys\n",
"# import os\n",
"# sys.path.append(os.path.abspath('..'))\n",
"# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'"
"## os.environ['CUDA_LAUNCH_BLOCKING'] = '1'"
]
},
{
@@ -55,16 +55,9 @@
"text": [
"/home/noamgat/mambaforge/envs/commentranker/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Loading checkpoint shards: 100%|██████████| 2/2 [03:44<00:00, 112.07s/it]\n",
"Loading checkpoint shards: 100%|██████████| 2/2 [05:29<00:00, 164.58s/it]\n",
"Using pad_token, but it is not set yet.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"OK\n"
]
}
],
"source": [
@@ -136,7 +129,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -226,7 +219,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -297,22 +290,14 @@
"data": {
"text/markdown": [
"```\n",
" Of course, I'd be happy to provide information about Michael Jordan using the provided JSON schema. Here's the response:\n",
" Of course! I'd be happy to provide information about Michael Jordan using the provided JSON schema.\n",
"{\n",
"\"title\": \"AnswerFormat\",\n",
"\"type\": \"object\",\n",
"\"properties\": {\n",
"\"first_name\": {\n",
"\"title\": \"First Name\",\n",
"\"type\": \"string\",\n",
"\"required\": true\n",
"\n",
"},\n",
"\"last_name\": {\n",
"\n",
"\"title\": \"Last Name\",\n",
"\n",
"\"type\":\n",
"\"first_name\": {\"title\": \"First Name\", \"type\": \"string\"},\n",
"\"last_name\": {\"title\": \"Last Name\", \"type\": \"string\"},\n",
"\"year_of_birth\": {\"title\": \"Year Of Birth\", \"\n",
"```"
],
"text/plain": [
@@ -372,7 +357,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -961,7 +946,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -1006,7 +991,7 @@
"data": {
"text/markdown": [
"```\n",
" Thank you for your question! Michael Jordan was born in the year 1963.\n",
" Thank you for asking! Michael Jordan was born in the year 1963.\n",
"```"
],
"text/plain": [
@@ -1484,7 +1469,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [
{
