81 changes: 79 additions & 2 deletions truss/base/trt_llm_config.py
@@ -2,9 +2,10 @@

import logging
import os
import re
import warnings
from enum import Enum
from typing import TYPE_CHECKING, Annotated, Dict, Literal, Optional, Union
from typing import TYPE_CHECKING, Annotated, Dict, List, Literal, Optional, Union

from huggingface_hub.errors import HFValidationError
from huggingface_hub.utils import validate_repo_id
@@ -105,6 +106,82 @@ def validate_cuda_friendly(self, key):
raise ValueError(f"{key} must be a multiple of 64, but got {value}")


V2_QUANT_DATASET_OPTIONS = ["cnn_dailymail", "fused"]


class TrussTRTQuantizationConfigurationV2(PydanticTrTBaseModel):
"""Configuration for quantization of TRT models in the V2 engine

Args:
calib_size (int, optional): Size of calibration dataset. Defaults to 1024.
recommended to increase for production runs (e.g. 1536), or decrease e.g. to 256 for quick testing.
calib_dataset (str, optional): Huggingface dataset to use for calibration. Defaults to 'cnn_dailymail'.
uses split='train' and quantized based on 'text' column.
text_field (Union[str, List[str]], optional): The field in the dataset to use for text data. Defaults to 'text'.
template (str, optional): String template that is a python f-string to format the text field.
e.g. "Summarize the following article: {text}"
or for multiple fields: "Summarize the following article: {title}\n\n{content}"
calib_max_seq_length (int, optional): Maximum sequence length for calibration. Defaults to 2048.
calib_dataset_split (str): The dataset split to use for calibration. Defaults to 'train'.
"""

    calib_size: int = 1024
    calib_max_seq_length: int = 2048
    calib_dataset: str = "cnn_dailymail"
    text_field: Union[str, List[str]] = "text"
    template: Optional[str] = None
    calib_dataset_split: str = "train"

    def __init__(self, **data):
        super().__init__(**data)
        self.validate_cuda_friendly("calib_size")
        self.validate_cuda_friendly("calib_max_seq_length")
        self.validate_dataset_name(
            self.calib_dataset, self.text_field, self.calib_dataset_split
        )

    def validate_cuda_friendly(self, key):
        value = getattr(self, key)
        if value < 64 or value > 16384:
            raise ValueError(f"{key} must be between 64 and 16384, but got {value}")
        elif value % 64 != 0:
            raise ValueError(f"{key} must be a multiple of 64, but got {value}")

    def validate_dataset_name(self, ds_name, text_field, dataset_split):
        if not isinstance(ds_name, str) or not ds_name.strip():
            raise ValueError("Dataset name must be a non-empty string.")
        ds_name = ds_name.strip()

        if ds_name in V2_QUANT_DATASET_OPTIONS:
            if text_field != "text":
                raise ValueError(
                    f"When using preset dataset {ds_name!r}, text_field must be 'text'."
                )
            if dataset_split != "train":
                raise ValueError(
                    f"When using preset dataset {ds_name!r}, dataset_split must be 'train'."
                )
            return

        # Hugging Face dataset ids of the form 'owner/repo' with an optional '@revision' suffix.
        hf_full = re.compile(
            r"^([A-Za-z0-9][A-Za-z0-9._-]{0,62})/([A-Za-z0-9][A-Za-z0-9._-]{0,62})(?:@([A-Za-z0-9._/-]+))?$"
        )
        # Bare dataset ids of the form 'repo' with an optional '@revision' suffix.
        hf_bare = re.compile(
            r"^([A-Za-z0-9][A-Za-z0-9._-]{0,62})(?:@([A-Za-z0-9._/-]+))?$"
        )

        if hf_full.match(ds_name):
            return

        if hf_bare.match(ds_name):
            return

        raise ValueError(
            f"Invalid dataset name: {ds_name!r}. Use 'owner/repo', 'repo', "
            "or one of the presets: " + ", ".join(V2_QUANT_DATASET_OPTIONS)
        )
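
# Illustrative usage sketch (not part of this PR's diff): two ways the config above
# could be constructed. The dataset id "my-org/my-dataset" and the field names
# "title"/"content" below are hypothetical placeholders.
#
# Preset calibration dataset with a smaller calib_size for a quick test run
# (sizes must be multiples of 64 between 64 and 16384):
#     quick_cfg = TrussTRTQuantizationConfigurationV2(calib_size=256)
#
# Custom Hugging Face dataset with multiple text fields joined by an f-string template:
#     custom_cfg = TrussTRTQuantizationConfigurationV2(
#         calib_dataset="my-org/my-dataset@main",
#         text_field=["title", "content"],
#         template="Summarize the following article: {title}\n\n{content}",
#         calib_dataset_split="train",
#     )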


class CheckpointSource(str, Enum):
HF = "HF"
GCS = "GCS"
@@ -587,7 +664,7 @@ def validate_inference_stack_v2(self: "TRTLLMConfigurationV2", context):
source=CheckpointSource.HF, repo="michael/any", revision=None
),
quantization_type=TrussTRTLLMQuantizationType.NO_QUANT,
quantization_config=TrussTRTQuantizationConfiguration(),
quantization_config=TrussTRTQuantizationConfigurationV2(),
).model_dump(exclude_unset=False)
for field in build_settings:
if (