Commit

Hotfixing auto length (warmup max_s was wrong). (#2716)
Narsil authored Nov 4, 2024
1 parent 08c4184 commit a5593ba
Showing 3 changed files with 2 additions and 11 deletions.
7 changes: 0 additions & 7 deletions launcher/src/main.rs
@@ -1687,13 +1687,6 @@ fn main() -> Result<(), LauncherError> {
     let max_position_embeddings = if let Some(config) = &config {
         if let Some(max_position_embeddings) = config.max_position_embeddings {
             if max_position_embeddings > max_default {
-                let max = max_position_embeddings;
-                if args.max_input_tokens.is_none()
-                    && args.max_total_tokens.is_none()
-                    && args.max_batch_prefill_tokens.is_none()
-                {
-                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
-                }
                 max_default
             } else {
                 max_position_embeddings
4 changes: 1 addition & 3 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -1532,8 +1532,6 @@ def warmup(
             self.kv_cache_dtype,
             self.device,
         )
-        max_bt = batch.max_blocks
-        max_s = max_bt * BLOCK_SIZE
         batch_num_blocks = batch.num_blocks
 
         if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
@@ -1651,7 +1649,7 @@ def warmup(
                 # Warmup cuda graphs
                 for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
-                        self.cuda_graph_warmup(bs, max_s, max_bt)
+                        self.cuda_graph_warmup(bs, max_total_tokens, max_total_tokens)
             except torch.cuda.OutOfMemoryError:
                 logger.exception("Decode cuda graph warmup failed")
         else:
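The substantive change is the warmup hunk above: the CUDA-graph warmup used to size its graphs from the warmup batch itself (max_bt = batch.max_blocks, max_s = max_bt * BLOCK_SIZE), so max_s reflected the length of the warmup prompt rather than the configured limit; after the fix, max_total_tokens is passed for both arguments. A minimal sketch of the difference, using a toy BLOCK_SIZE value and a stand-in batch object (neither is TGI's actual code):

    BLOCK_SIZE = 16  # assumed paged-KV block size, for illustration only

    class WarmupBatch:
        max_blocks = 4  # a small warmup batch only occupies a few KV blocks

    def old_warmup_sizes(batch):
        # Pre-fix behaviour: sizes derived from the warmup batch itself.
        max_bt = batch.max_blocks
        max_s = max_bt * BLOCK_SIZE
        return max_s, max_bt

    def new_warmup_sizes(max_total_tokens):
        # Post-fix behaviour: both sizes come from the configured limit.
        return max_total_tokens, max_total_tokens

    print(old_warmup_sizes(WarmupBatch()))  # (64, 4) -- far below the real limit
    print(new_warmup_sizes(4096))           # (4096, 4096)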
2 changes: 1 addition & 1 deletion server/text_generation_server/models/metadata_kernels.py
@@ -55,7 +55,7 @@ def block_tables_to_ragged(
     cache_lengths: List[int],
     input_lengths_tensor: torch.Tensor,
     cache_lengths_tensor: torch.Tensor,
-    max_current_length: int
+    max_current_length: int,
 ) -> torch.Tensor:
     """Convert block table to ragged format compatible with FlashInfer."""
     assert len(input_lengths) == len(cache_lengths)
