Update README.md to comply with the latest version of TRT-LLM and outlines #132

40 changes: 22 additions & 18 deletions AI_Agents_Guide/Constrained_Decoding/README.md
@@ -264,15 +264,17 @@ to serve your TensorRT-LLM model.), a custom logits processor should be specified
during the model's initialization as part of
[Executor's](https://nvidia.github.io/TensorRT-LLM/executor.html#executor-api)
configuration
([`logits_post_processor_map`](https://github.com/NVIDIA/TensorRT-LLM/blob/32ed92e4491baf2d54682a21d247e1948cca996e/tensorrt_llm/hlapi/llm_utils.py#L205)).
([`logits_post_processor_map`](https://github.com/NVIDIA/TensorRT-LLM/blob/258c7540c03517def55d9a5aadfa9288af474e1b/tensorrt_llm/llmapi/llm_utils.py#L322)).
Below is a sample for reference.

```diff
...

+ executor_config.logits_post_processor_map = {
+ "<custom_logits_processor_name>": custom_logits_processor
+ }
+ logits_proc_config = trtllm.LogitsPostProcessorConfig()
+ logits_proc_config.processor_map = {
+ "<custom_logits_processor_name>": custom_logits_processor
+ }
+ executor_config.logits_post_processor_config = logits_proc_config
self.executor = trtllm.Executor(model_path=...,
model_type=...,
executor_config=executor_config)
@@ -331,17 +333,15 @@ def execute(self, requests):
...

for request in requests:
response_sender = request.get_response_sender()
if get_input_scalar_by_name(request, 'stop'):
self.handle_stop_request(request.request_id(), response_sender)
else:
...
try:
converted = convert_request(request,
self.exclude_input_from_output,
self.decoupled)
converted_reqs = convert_request(
request, self.exclude_input_from_output,
self.decoupled)
+ logits_post_processor_name = get_input_tensor_by_name(request, 'logits_post_processor_name')
+ if logits_post_processor_name is not None:
+ converted.logits_post_processor_name = logits_post_processor_name.item().decode('utf-8')
+ for converted in converted_reqs:
+ converted.logits_post_processor_name = logits_post_processor_name.item().decode('utf-8')
except Exception as e:
...
```
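For context, here is a hypothetical client-side sketch of how the new optional input could be supplied. The input shape and the tritonclient usage are assumptions for illustration, not part of this PR; only the input name `logits_post_processor_name` comes from the change above.

```python
# Hypothetical client-side sketch: send the processor name as an optional
# BYTES input so model.py can read it via
# get_input_tensor_by_name(request, 'logits_post_processor_name').
import numpy as np
import tritonclient.grpc as grpcclient

# Shape [1] is an assumption; match it to the model's config.pbtxt.
name_input = grpcclient.InferInput("logits_post_processor_name", [1], "BYTES")
name_input.set_data_from_numpy(
    np.array(["<custom_logits_processor_name>"], dtype=object)
)
# Append name_input to the request's other inputs before calling infer().
```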
@@ -470,6 +470,10 @@ class TritonPythonModel:
def get_executor_config(self, model_config):
+ tokenizer_dir = model_config['parameters']['tokenizer_dir']['string_value']
+ logits_processor = LMFELogitsProcessor(tokenizer_dir, AnswerFormat.model_json_schema())
+ logits_proc_config = trtllm.LogitsPostProcessorConfig()
+ logits_proc_config.processor_map = {
+ LMFELogitsProcessor.PROCESSOR_NAME: logits_processor
+ }
kwargs = {
"max_beam_width":
get_parameter(model_config, "max_beam_width", int),
@@ -490,9 +494,7 @@ class TritonPythonModel:
self.get_peft_cache_config(model_config),
"decoding_config":
self.get_decoding_config(model_config),
+ "logits_post_processor_map":{
+ LMFELogitsProcessor.PROCESSOR_NAME: logits_processor
+ }
+ "logits_post_processor_config": logits_proc_config
}
kwargs = {k: v for k, v in kwargs.items() if v is not None}
return trtllm.ExecutorConfig(**kwargs)
@@ -603,6 +605,10 @@ class TritonPythonModel:
def get_executor_config(self, model_config):
+ tokenizer_dir = model_config['parameters']['tokenizer_dir']['string_value']
+ logits_processor = OutlinesLogitsProcessor(tokenizer_dir, AnswerFormat.model_json_schema())
+ logits_proc_config = trtllm.LogitsPostProcessorConfig()
+ logits_proc_config.processor_map = {
+ OutlinesLogitsProcessor.PROCESSOR_NAME: logits_processor
+ }
kwargs = {
"max_beam_width":
get_parameter(model_config, "max_beam_width", int),
@@ -623,9 +629,7 @@ class TritonPythonModel:
self.get_peft_cache_config(model_config),
"decoding_config":
self.get_decoding_config(model_config),
+ "logits_post_processor_map":{
+ OutlinesLogitsProcessor.PROCESSOR_NAME: logits_processor
+ }
+ "logits_post_processor_config": logits_proc_config
}
kwargs = {k: v for k, v in kwargs.items() if v is not None}
return trtllm.ExecutorConfig(**kwargs)
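Taken together, the updated wiring looks roughly like the minimal sketch below. Assumptions: `trtllm` is the `tensorrt_llm.bindings.executor` module imported by the backend's model.py, `my_logits_processor` stands in for the LMFE/Outlines processors above, and all other `ExecutorConfig` arguments are omitted.

```python
# Minimal sketch of the new pattern: register the processor through
# LogitsPostProcessorConfig and pass it as logits_post_processor_config,
# instead of the old logits_post_processor_map keyword.
import tensorrt_llm.bindings.executor as trtllm  # assumed import path


def build_executor_config(my_logits_processor):
    logits_proc_config = trtllm.LogitsPostProcessorConfig()
    logits_proc_config.processor_map = {
        "<custom_logits_processor_name>": my_logits_processor
    }
    return trtllm.ExecutorConfig(
        logits_post_processor_config=logits_proc_config
    )
```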
29 changes: 15 additions & 14 deletions AI_Agents_Guide/Constrained_Decoding/artifacts/utils.py
@@ -25,15 +25,14 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
from collections import defaultdict
from typing import DefaultDict, Dict, List
from typing import Any, Dict, List, Optional

import torch
from lmformatenforcer import JsonSchemaParser, TokenEnforcer
from lmformatenforcer.integrations.trtllm import build_trtllm_tokenizer_data
from outlines.fsm.guide import RegexGuide
from outlines.fsm.json_schema import build_regex_from_schema
from outlines.integrations.utils import adapt_tokenizer
from outlines.fsm.guide import Guide, RegexGuide
from outlines.models.vllm import adapt_tokenizer
from outlines_core.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel
from transformers import AutoTokenizer

@@ -103,6 +102,7 @@ def __call__(
logits: torch.Tensor,
ids: List[List[int]],
stream_ptr: int,
client_id: Optional[int]
):
# Create a mask with negative infinity to block all tokens initially.
mask = torch.full_like(logits, fill_value=float("-inf"), device=logits.device)
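For readers unfamiliar with this pattern, here is a tiny standalone sketch of the additive mask idea; the vocabulary size and allowed token ids are made up for illustration.

```python
# Standalone sketch of the additive -inf mask used by both processors:
# everything is blocked by default, then the ids allowed by the grammar are
# unblocked by zeroing their mask entries before adding the mask to the logits.
import torch

logits = torch.randn(1, 8)                      # pretend vocabulary of 8 tokens
allowed = [2, 5]                                # ids permitted at this step
mask = torch.full_like(logits, float("-inf"))
mask[:, allowed] = 0.0
constrained = logits + mask                     # disallowed ids become -inf
```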
@@ -127,8 +127,8 @@ def __init__(self, tokenizer_dir, schema):
)
tokenizer = adapt_tokenizer(tokenizer)
regex_string = build_regex_from_schema(json.dumps(schema))
self.fsm = RegexGuide(regex_string, tokenizer)
self._fsm_state: DefaultDict[int, int] = defaultdict(int)
self.guide: Guide = RegexGuide.from_regex(regex_string, tokenizer)
self._guide_states: Dict[int, Any] = {}
self.mask_cache: Dict[int, torch.Tensor] = {}
# By default, TensorRT-LLM includes request query into the output.
# Outlines should only look at generated outputs, thus we'll keep
@@ -141,6 +141,7 @@ def __call__(
logits: torch.Tensor,
ids: List[List[int]],
stream_ptr: int,
client_id: Optional[int]
):
seq_id = None
# If the prefix token IDs have changed we assume that we are dealing
@@ -151,9 +152,9 @@
# processed
or len(ids[0][len(self._prefix) :]) == 0
):
self._fsm_state = defaultdict(int)
self._prefix = ids[0]
seq_id = hash(tuple([]))
self._guide_states = {seq_id: self.guide.initial_state}
self._prefix = ids[0]

else:
# Remove the prefix token IDs from the input token IDs,
@@ -162,14 +163,14 @@ def __call__(
last_token = ids[-1]
last_seq_id = hash(tuple(ids[:-1]))
seq_id = hash(tuple(ids))
self._fsm_state[seq_id] = self.fsm.get_next_state(
state=self._fsm_state[last_seq_id], token_id=last_token
self._guide_states[seq_id] = self.guide.get_next_state(
state=self._guide_states[last_seq_id], token_id=last_token
)

state_id = self._fsm_state[seq_id]
state_id = self._guide_states[seq_id]
if state_id not in self.mask_cache:
allowed_tokens = self.fsm.get_next_instruction(
state=self._fsm_state[seq_id]
allowed_tokens = self.guide.get_next_instruction(
state=self._guide_states[seq_id]
).tokens
# Create a mask with negative infinity to block all
# tokens initially.
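To close, a condensed, hypothetical sketch of the updated Outlines flow that utils.py now follows. The `AnswerFormat` model and tokenizer directory below are placeholders, and the real processor additionally caches per-request guide states and token masks.

```python
# Condensed sketch of the updated Outlines API usage, minus the per-request
# state and mask caching done by OutlinesLogitsProcessor above.
import json

from outlines.fsm.guide import RegexGuide
from outlines.models.vllm import adapt_tokenizer
from outlines_core.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel
from transformers import AutoTokenizer


class AnswerFormat(BaseModel):  # stand-in for the tutorial's schema
    answer: str


tokenizer_dir = "meta-llama/Llama-2-7b-chat-hf"  # assumption: any HF tokenizer dir
tokenizer = adapt_tokenizer(AutoTokenizer.from_pretrained(tokenizer_dir))

regex_string = build_regex_from_schema(json.dumps(AnswerFormat.model_json_schema()))
guide = RegexGuide.from_regex(regex_string, tokenizer)

state = guide.initial_state
allowed = guide.get_next_instruction(state).tokens   # ids allowed at this step
state = guide.get_next_state(state=state, token_id=allowed[0])
```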