Merge pull request #33 from remichu-ai/format_enforcement_enhancement
Format enforcement enhancement
remichu-ai authored Sep 28, 2024
2 parents d164503 + a8cdc09 commit 82e1708
Showing 7 changed files with 213 additions and 104 deletions.
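In short, this commit introduces a user-selectable guided-decoding engine: ChatMLQuery gains a guided_decoding_backend preference ("auto", "formatron", or "lm-format-enforcer"), the selection logic moves into a new src/gallama/backend/format_enforcer.py module, and streaming now filters out the EOS token that formatron emits as text. A minimal, self-contained sketch of the selection rule (an illustrative reimplementation; only the names and values are taken from the diff below):

from typing import Literal

Preference = Literal["auto", "formatron", "lm-format-enforcer"]

def pick_engine(backend: str, preference: Preference, formatron_available: bool) -> str:
    # mirrors FormatEnforcer.get_default_engine introduced by this commit (illustrative)
    if backend == "llama_cpp":
        return "lm_enforcer"  # formatron does not support llama.cpp yet
    if backend != "exllama":
        raise ValueError(f"Invalid backend: {backend}")
    if preference == "auto":
        return "formatron" if formatron_available else "lm_enforcer"
    if preference == "formatron" and formatron_available:
        return "formatron"
    if preference == "lm-format-enforcer":
        return "lm_enforcer"
    raise ValueError("formatron was requested but is not installed")

assert pick_engine("exllama", "auto", True) == "formatron"
assert pick_engine("llama_cpp", "formatron", True) == "lm_enforcer"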
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "gallama"
version = "0.0.5"
version = "0.0.6"
description = "An oppinionated Llama Server engine with focus on agentic task"
authors = [{name = "David", email = "trantrungduc91@example.com"}]
license = {text = "MIT"}
2 changes: 1 addition & 1 deletion requirements.txt
@@ -23,4 +23,4 @@ pyzmq
pygments
httpx
psutil
-formatron
+formatron>=0.4.4
129 changes: 34 additions & 95 deletions src/gallama/backend/chatgenerator.py
@@ -9,7 +9,7 @@
from fastapi import HTTPException, Request
from .model import Model
from gallama.data_classes.data_class import GenerationStats, GenEnd, GenText, GenQueue, ChatMLQuery, GenStart
-from .tools import Tools, create_function_models_v2
+from .tools import Tools, create_function_models_v2, create_function_models_formatron
from dataclasses import dataclass
from gallama.utils.utils import get_token_length
from gallama.logger.logger import logger
@@ -19,6 +19,17 @@
from lmformatenforcer.tokenenforcer import TokenEnforcerTokenizerData
from concurrent.futures import ThreadPoolExecutor
from importlib.metadata import version
+from .format_enforcer import FormatEnforcer
+from formatron.schemas.pydantic import ClassSchema
+
+try:
+    from formatron.formatter import FormatterBuilder
+    from formatron.integrations.exllamav2 import create_formatter_filter
+
+except ImportError:
+    FormatterBuilder = None
+    create_formatter_filter = None
+

try:
from exllamav2 import (
@@ -64,14 +75,6 @@

assert ExLlamaV2Cache or LogitsProcessorList, "Please install ExllamaV2 or LLama CPP Python as backend"

-# experimental support for formatron
-try:
-    from formatron.formatter import FormatterBuilder
-    from formatron.integrations.exllamav2 import create_formatter_filter
-
-except:
-    FormatterBuilder = None
-    create_formatter_filter = None

TOOL_THINKING = THINKING_TEMPLATE["tool_necessity_evaluation"]
TOOL_FORCE_THINKING = THINKING_TEMPLATE["tool_forced_evaluation"]
@@ -94,71 +97,6 @@ def get_queue(self) -> GenQueue | None:
return self.gen_queue()


-class FormatEnforcer:
-    """ this class will help to create filters for generation enforcement"""
-
-    def __init__(self):
-        pass
-
-    @staticmethod
-    def get_default_engine(backend: str = "exllama") -> Literal["formatron", "lm_enforcer"]:
-        """ this function will select the format enforcer engine to use if not selected by user"""
-
-        # formatron doesn't support llama cpp at the moment
-        if backend == "llama_cpp":
-            return "lm_enforcer"
-        elif backend == "exllama":
-            # use formatron if it is available on exllama
-            if FormatterBuilder:
-                return "formatron"
-            else:
-                return "lm_enforcer"
-        else:
-            raise ValueError("Invalid backend")
-
-    def regex(self, regex_pattern: str, filter_engine: Literal["formatron", "lm_enforcer"] = None,
-              backend: str = "exllama") -> FormatterBuilder | TokenEnforcerTokenizerData:
-        logger.info(backend)
-        # pick a default engine if the caller did not specify one
-        if not filter_engine:
-            filter_engine = FormatEnforcer.get_default_engine(backend=backend)
-
-        # create filter if engine is lm_enforcer
-        if filter_engine == "lm_enforcer":
-            return RegexParser(regex_pattern)
-
-        # create filter if engine is formatron
-        if filter_engine == "formatron":
-            f = FormatterBuilder()
-            _regex = f.regex(regex_pattern, capture_name='regex')
-            f.append_line(f"{_regex}")
-            return f
-
-    def json(self, pydantic_model, filter_engine: Literal["formatron", "lm_enforcer"] = None,
-             backend: str = "exllama") -> FormatterBuilder | TokenEnforcerTokenizerData:
-        """ this function will return the filters for format enforcer to generate json output based on Pydantic model"""
-
-        # pick a default engine if the caller did not specify one
-        if not filter_engine:
-            filter_engine = FormatEnforcer.get_default_engine(backend=backend)
-
-        # create filter if engine is lm_enforcer
-        if filter_engine == "lm_enforcer" or filter_engine == "formatron":  # TODO formatron currently has issues with nested pydantic models
-            json_schema = Tools.replace_refs_with_definitions_v2(pydantic_model.model_json_schema())
-            return JsonSchemaParser(json_schema)

class ChatGenerator(Model):
def __init__(
self,
@@ -238,9 +176,9 @@ async def chat_no_tool(self, query: ChatMLQuery, prompt_eng, gen_queue, request:
)

        formatter_prefix_regex = self.formatter.regex(
-            query.regex_prefix_pattern, backend=self.backend) if query.regex_prefix_pattern else None
+            query.regex_prefix_pattern, backend=self.backend, preference=query.guided_decoding_backend) if query.regex_prefix_pattern else None

-        formatter_regex = self.formatter.regex(query.regex_pattern, backend=self.backend) if query.regex_pattern else None
+        formatter_regex = self.formatter.regex(query.regex_pattern, backend=self.backend, preference=query.guided_decoding_backend) if query.regex_pattern else None

token_length_prompt = get_token_length(self.tokenizer, prompt)
self.validate_token_length(token_length_prompt)
@@ -445,7 +383,7 @@ async def chat_with_tool(self, query: ChatMLQuery, prompt_eng, gen_queue, request
# perform generation with tool thinking to evaluate if it is necessity
tool_thinking_queue_fallback = GenQueue()

-        formatter_regex = self.formatter.regex('(needed|not needed)', backend=self.backend)
+        formatter_regex = self.formatter.regex('(needed|not needed)', backend=self.backend, preference=query.guided_decoding_backend)

await self.generate(
prompt,
@@ -470,25 +408,24 @@ async def chat_with_tool(self, query: ChatMLQuery, prompt_eng, gen_queue, request
# USE TOOL
if use_tool_bool:
# create the pydantic schema to enforce generation
-            tool_combined_pydantic = create_function_models_v2(tool_handler.tool_dict)
+            tool_combined_pydantic_lmfe = create_function_models_v2(tool_handler.tool_dict)

-            class ToolCalling(BaseModel):
+            class ToolCalling_LMFE(ClassSchema):
                 """ The format to call one or multiple tools """
-                functions_calling: List[Union[tuple(tool_combined_pydantic)]] = Field(
-                    description='the list of functions to call in chronological order',
-                    default=[]
-                )
-
-                # class ItemModel(BaseModel):
-                #     Use: Literal['Yes', 'No']
-                #     reason: str
+                functions_calling: List[Union[tuple(tool_combined_pydantic_lmfe)]] = []

-            # answer_format_schema = tool_handler.replace_refs_with_definitions_v2(ToolCalling.schema())
-            #
-            # # get format enforcer
-            # formatter = JsonSchemaParser(answer_format_schema)
+            # create the pydantic schema to enforce generation for formatron, which uses ClassSchema
+            tool_combined_pydantic_formatron = create_function_models_formatron(tool_handler.tool_dict_formatron)
+            class ToolCalling_formatron(ClassSchema):
+                """ The format to call one or multiple tools """
+                functions_calling: List[Union[tuple(tool_combined_pydantic_formatron)]] = []

-            formatter_json = self.formatter.json(pydantic_model=ToolCalling, backend=self.backend)
+            formatter_json = self.formatter.json(
+                pydantic_model_lmfe=ToolCalling_LMFE,
+                pydantic_model_formatron=ToolCalling_formatron,
+                backend=self.backend,
+                preference=query.guided_decoding_backend
+            )

# Experiment feature, formulate function calling as python programming. Which is more natural than a random Json output as part of conversation
tool_as_code_prompt = """
@@ -782,11 +719,13 @@ async def generate(

            # Depending on settings, the result dict can contain top-K probabilities, logits and more, but we'll just
            # grab the output text stream.
-            # generate_text += result.get("text", "")
            # logger.info(f'{datetime.now()} {result.get("text", "")}')
-            chunk = GenText(content=result.get("text", ""), text_type=gen_type_str)
+            chunk_text = result.get("text", "")
+            chunk = GenText(content=chunk_text, text_type=gen_type_str)
            for g_queue in gen_queue_list:
-                g_queue.get_queue().put_nowait(chunk)
+                if chunk_text not in self.eos_token_str_set:  # formatron returns the eos token as text
+                    # generate_text += result.get("text", "")
+                    g_queue.get_queue().put_nowait(chunk)

            # logger.info(result.get("text", ""))
            # logger.info(self.tokenizer.encode(result.get("text", "")))
109 changes: 109 additions & 0 deletions src/gallama/backend/format_enforcer.py
@@ -0,0 +1,109 @@
+from typing import Literal, Optional
+from formatron.schemas.pydantic import ClassSchema
+from lmformatenforcer import JsonSchemaParser, RegexParser
+from lmformatenforcer.tokenenforcer import TokenEnforcerTokenizerData
+from gallama.logger.logger import logger
+from pydantic import BaseModel
+from .tools import Tools
+
+
+# experimental support for formatron
+try:
+    from formatron.formatter import FormatterBuilder
+    from formatron.integrations.exllamav2 import create_formatter_filter
+
+except ImportError:
+    FormatterBuilder = None
+    create_formatter_filter = None
+
+
+class FormatEnforcer:
+    """ this class will help to create filters for generation enforcement"""
+
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def get_default_engine(
+        backend: str = "exllama",
+        preference: Literal["auto", "formatron", "lm-format-enforcer"] = "auto",
+    ) -> Literal["formatron", "lm_enforcer"]:
+        """ this function will select the format enforcer engine to use if not selected by user"""
+
+        if preference != "auto":
+            logger.info(f"guided decoding preference: {preference}")
+
+        # formatron doesn't support llama cpp at the moment
+        if backend == "llama_cpp":
+            return "lm_enforcer"
+        elif backend == "exllama":
+            # prefer formatron on exllama whenever it is available
+            if preference == "auto":
+                if FormatterBuilder:
+                    return "formatron"
+                else:
+                    return "lm_enforcer"
+            else:
+                if preference == "formatron" and FormatterBuilder:
+                    return "formatron"
+                elif preference == "lm-format-enforcer":
+                    return "lm_enforcer"
+                else:
+                    raise ValueError("formatron was requested but is not installed")
+        else:
+            raise ValueError(f"Invalid backend: {backend}")
+
+    def regex(
+        self,
+        regex_pattern: str,
+        filter_engine: Optional[Literal["formatron", "lm_enforcer"]] = None,
+        backend: str = "exllama",
+        preference: Literal["auto", "formatron", "lm-format-enforcer"] = "auto",
+    ) -> FormatterBuilder | TokenEnforcerTokenizerData:
+
+        logger.info(f"backend: {backend}")
+        # pick a default engine if the caller did not specify one
+        if not filter_engine:
+            filter_engine = FormatEnforcer.get_default_engine(backend=backend, preference=preference)
+
+        # create filter if engine is lm_enforcer
+        if filter_engine == "lm_enforcer":
+            return RegexParser(regex_pattern)
+
+        # create filter if engine is formatron
+        if filter_engine == "formatron":
+            f = FormatterBuilder()
+            _regex = f.regex(regex_pattern, capture_name='regex')
+            f.append_line(f"{_regex}")
+            return f
+
+    def json(
+        self,
+        pydantic_model_lmfe: BaseModel,
+        pydantic_model_formatron: ClassSchema,
+        filter_engine: Optional[Literal["formatron", "lm_enforcer"]] = None,
+        backend: Literal["llama_cpp", "exllama"] = "exllama",
+        preference: Literal["auto", "formatron", "lm-format-enforcer"] = "auto",
+    ) -> FormatterBuilder | TokenEnforcerTokenizerData:
+        """ this function will return the filters for format enforcer to generate json output based on a Pydantic model"""
+
+        # pick a default engine if the caller did not specify one
+        if not filter_engine:
+            filter_engine = FormatEnforcer.get_default_engine(backend=backend, preference=preference)
+
+        assert filter_engine == "lm_enforcer" or filter_engine == "formatron"
+
+        # create filter if engine is lm_enforcer
+        if filter_engine == "lm_enforcer":  # TODO formatron currently has issues with nested pydantic models
+            json_schema = Tools.replace_refs_with_definitions_v2(pydantic_model_lmfe.model_json_schema())
+            return JsonSchemaParser(json_schema)
+
+        # create filter if engine is formatron
+        elif filter_engine == "formatron":
+            f = FormatterBuilder()
+            f.append_line(f"{f.json(pydantic_model_formatron, capture_name='json')}")
+            return f
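For orientation, a hedged usage sketch of the new module (the signatures follow the file above; the patterns are examples, and running it requires lm-format-enforcer plus, for the second call, formatron):

from gallama.backend.format_enforcer import FormatEnforcer

formatter = FormatEnforcer()

# llama_cpp always resolves to lm-format-enforcer and returns a RegexParser
lmfe_filter = formatter.regex("(needed|not needed)", backend="llama_cpp")

# exllama with an explicit preference returns a formatron FormatterBuilder
formatron_filter = formatter.regex("(yes|no)", backend="exllama", preference="formatron")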
1 change: 1 addition & 0 deletions src/gallama/backend/model.py
@@ -82,6 +82,7 @@ def __init__(self,
# TODO, to auto detect
# get the eos_token_str by merging the default config with anything set by user
        self.eos_token_str = list(set(model_config.get("eos_token_list", []) + eos_token_list_from_prompt_template))
+        self.eos_token_str_set = set(self.eos_token_str)  # set for more efficient membership checks
        self.eos_token_ids = self.generate_eos_tokens_id()
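This set pairs with the streaming change in chatgenerator.py above: a chunk is dropped when its text equals an EOS string, since formatron emits the EOS token as text. A tiny illustration of the membership check (token strings are made-up examples):

eos_token_str_set = {"</s>", "<|eot_id|>"}  # illustrative values; the real list comes from model config
chunks = ["Hello", " world", "</s>"]
streamed = [c for c in chunks if c not in eos_token_str_set]
assert streamed == ["Hello", " world"]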


