Support arbitrary json_object in OpenAI and Context Free Grammar #3211
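For context, this PR adds a context-free-grammar (CFG) logits processor and, with it, support for the OpenAI-style `response_format={"type": "json_object"}` field, i.e. arbitrary JSON output with no fixed schema. A hedged sketch of the kind of request this enables against vLLM's OpenAI-compatible server; the `base_url` and model name below are placeholders, not values from the PR:

```python
# Hedged sketch (assumed server URL and model name) of a request this PR
# enables: asking for arbitrary JSON output without supplying a schema.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",  # placeholder model
    messages=[{"role": "user", "content": "Give me three colors as JSON."}],
    response_format={"type": "json_object"},  # arbitrary JSON, no schema required
)
print(completion.choices[0].message.content)
```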
```diff
@@ -16,30 +16,59 @@
 import json
 import math
 from collections import defaultdict
-from typing import Union, DefaultDict, Dict, List, Optional
+from typing import Union, DefaultDict, Dict, List, Optional, Callable
 
 import torch
 from pydantic import BaseModel
-from outlines.fsm.fsm import RegexFSM
+from outlines.fsm.fsm import RegexFSM, CFGFSM
 from outlines.fsm.json_schema import build_regex_from_schema
 
 
-class RegexLogitsProcessor:
+class BaseLogitsProcessor:
 
-    def __init__(self, regex_string: str, tokenizer):
-        """Compile the FSM that drives the regex-structured generation.
+    def adapt_tokenizer(self, tokenizer):
+        """Adapt vLLM's tokenizer to use to compile the FSM.
 
-        Parameters
-        ----------
-        regex_string
-            A string that represents a regular expression
-        tokenizer
-            The model's tokenizer
+        The API of Outlines tokenizers is slightly different to that of
+        `transformers`. The Outlines decoder returns a list whereas
+        vLLM's decode returns a str, so the decoder must be adapted to
+        sync vLLM with the Outlines internal API. In addition, we need
+        to handle the missing spaces in Llama's tokenizer to be able
+        to compile FSMs for this model.
 
         """
-        tokenizer = self.adapt_tokenizer(tokenizer)
-        fsm = RegexFSM(regex_string, tokenizer)
-        self.fsm = fsm
+        if getattr(tokenizer, "_outlines_adapted", False):
+            return tokenizer
+
+        tokenizer.vocabulary = tokenizer.get_vocab()
+        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+        def convert_token_to_string(token: str) -> str:
+            from transformers.file_utils import SPIECE_UNDERLINE
+
+            string = tokenizer.convert_tokens_to_string([token])
+
+            # A hack to handle missing spaces in HF's Llama tokenizers
+            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+                return " " + string
+
+            return string
+
+        def change_decoder(
+            decoder: Callable[[List[int]], str]
+        ) -> Callable[[List[int]], List[str]]:
+            """Sync vLLM's decoder with the Outlines expectations by returning a list."""
+
+            def new_decoder(inp_tokens: List[int]) -> List[str]:
+                return [decoder(inp_tokens)]
+
+            return new_decoder
+
+        tokenizer.convert_token_to_string = convert_token_to_string
+        tokenizer.decode = change_decoder(tokenizer.decode)
+        setattr(tokenizer, "_outlines_adapted", True)  # noqa: B010
+
+        return tokenizer
 
     def init_state(self):
        """Initialize the FSM states."""
```

**Review comment** (on `return tokenizer`): Do we need to return this?
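The decoder adaptation is easiest to see in isolation. A minimal, self-contained sketch of the `change_decoder` wrapper from the hunk above, using a stub in place of `tokenizer.decode` (the stub is not part of the PR):

```python
from typing import Callable, List


def change_decoder(
    decoder: Callable[[List[int]], str]
) -> Callable[[List[int]], List[str]]:
    """Wrap a str-returning decode so it returns a one-element list."""

    def new_decoder(inp_tokens: List[int]) -> List[str]:
        return [decoder(inp_tokens)]

    return new_decoder


# Stub standing in for tokenizer.decode: maps token ids to letters.
def decode(ids: List[int]) -> str:
    return "".join(chr(97 + i) for i in ids)


assert decode([0, 1, 2]) == "abc"                    # vLLM style: str
assert change_decoder(decode)([0, 1, 2]) == ["abc"]  # Outlines style: List[str]
```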
```diff
@@ -69,31 +98,23 @@ def __call__(self, input_ids: List[int],
 
         return scores
 
-    def adapt_tokenizer(self, tokenizer):
-        """Adapt vLLM's tokenizer to use to compile the FSM.
 
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
+class RegexLogitsProcessor(BaseLogitsProcessor):
 
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+    def __init__(self, regex_string: str, tokenizer):
+        """Compile the FSM that drives the regex-structured generation.
 
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
+        Parameters
+        ----------
+        regex_string
+            A string that represents a regular expression
+        tokenizer
+            The model's tokenizer
 
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-
-        return tokenizer
+        """
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        fsm = RegexFSM(regex_string, tokenizer)
+        self.fsm = fsm
 
 
 class JSONLogitsProcessor(RegexLogitsProcessor):
```
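The body of `__call__` is elided by this diff, but the FSM-driven step it performs is worth spelling out: tokens the FSM allows from the current state keep their logits, and all others are pushed to negative infinity. A rough conceptual sketch, not the PR's exact code:

```python
import math
from typing import List

import torch


def mask_logits(scores: torch.Tensor, allowed_token_ids: List[int]) -> torch.Tensor:
    """Bias scores so only FSM-allowed tokens can be sampled."""
    mask = torch.full_like(scores, -math.inf)
    mask[allowed_token_ids] = 0
    return scores + mask


scores = torch.zeros(8)
print(mask_logits(scores, [1, 4]))  # -inf everywhere except indices 1 and 4
```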
```diff
@@ -127,3 +148,21 @@ def __init__(self,
                 + "Schema specification")
         regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
         super().__init__(regex_string, tokenizer)
+
+
+class CFGLogitsProcessor(BaseLogitsProcessor):
+
+    def __init__(self, cfg: str, tokenizer):
+        """Compile the FSM that drives the CFG-structured generation.
+
+        Parameters
+        ----------
+        cfg
+            A string that represents a context-free grammar
+        tokenizer
+            The model's tokenizer
+
+        """
+        tokenizer = self.adapt_tokenizer(tokenizer)
+        fsm = CFGFSM(cfg, tokenizer)
+        self.fsm = fsm
```

**Review comment** (on `def __init__(self, cfg: str, tokenizer):`): Adding type definition for …
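To make the new class concrete, a hypothetical usage sketch. The grammar below is a made-up arithmetic example in the Lark-style EBNF dialect that Outlines' `CFGFSM` consumes, and the model name is a placeholder:

```python
from transformers import AutoTokenizer

# Made-up illustrative grammar, not part of the PR.
arithmetic_grammar = """
?start: expr
?expr: term (("+" | "-") term)*
?term: factor (("*" | "/") factor)*
?factor: NUMBER | "(" expr ")"
%import common.NUMBER
"""

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder
processor = CFGLogitsProcessor(arithmetic_grammar, tokenizer)
processor.init_state()
# The sampler then invokes processor(input_ids, scores) at every decoding
# step, so the generated text is forced to parse under the grammar.
```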
**Review comment:** A dumb question: does `CFG` mean `config`? Only for grammar config?

**Reply:** Context-free grammar. Good point that I should probably change it to `grammar` to disambiguate.