
Upstream Sync #80


Merged

21 commits
999328b
[Model] Add GraniteMoeHybrid 4.0 model (#17497)
s3woz May 6, 2025
edbf2d6
[easy] Fix logspam on PiecewiseBackend errors (#17138)
zou3519 May 6, 2025
dc47ba3
[Bugfix] Fixed prompt length for random dataset (#17408)
Xarbirus May 6, 2025
63ced7b
[Doc] Update notes for H2O-VL and Gemma3 (#17219)
DarkLight1337 May 6, 2025
6eae345
[Misc] Fix ScalarType float4 naming (#17690)
LucasWilkinson May 6, 2025
05e1f96
Fix `dockerfilegraph` pre-commit hook (#17698)
hmellor May 6, 2025
f9bc5a0
[Bugfix] Fix triton import with local TritonPlaceholder (#17446)
MengqingCao May 6, 2025
d419aa5
[V1] Enable TPU V1 backend by default (#17673)
mgoin May 6, 2025
a6fed02
[V1][PP] Support PP for MultiprocExecutor (#14219)
bigPYJ1151 May 6, 2025
cba31c4
[v1] AttentionMetadata for each layer (#17394)
heheda12345 May 6, 2025
175bda6
[Feat] Add deprecated=True to CLI args (#17426)
aarnphm May 6, 2025
0d11546
[Docs] Use gh-file to add links to tool_calling.md (#17709)
windsonsea May 6, 2025
aabcd2c
[v1] Introduce KVCacheBlocks as interface between Scheduler and KVCac…
heheda12345 May 6, 2025
7525d5f
[doc] Add RAG Integration example (#17692)
reidliu41 May 6, 2025
5b8c390
[Bugfix] Fix modality limits in vision language example (#17721)
DarkLight1337 May 6, 2025
6115b11
Make right sidebar more readable in "Supported Models" (#17723)
hmellor May 6, 2025
621ca2c
[TPU] Increase block size and reset block shapes (#16458)
bythew3i May 6, 2025
d456aea
[Misc] Add Next Edit Prediction (NEP) datasets support in `benchmark_…
dtransposed May 6, 2025
de906b9
[Bugfix] Fix for the condition to accept empty encoder inputs for mll…
gshtras May 6, 2025
2f925e5
[Kernel] Unified Triton kernel that doesn't distinguish between prefi…
tdoublep May 6, 2025
42b869e
updated
robertgshaw2-redhat May 6, 2025
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
@@ -125,8 +125,6 @@ repos:
name: Update Dockerfile dependency graph
entry: tools/update-dockerfile-graph.sh
language: script
files: ^docker/Dockerfile$
pass_filenames: false
# Keep `suggestion` last
- id: suggestion
name: Suggestion
105 changes: 103 additions & 2 deletions benchmarks/benchmark_dataset.py
@@ -315,13 +315,15 @@ def sample(
)

vocab_size = tokenizer.vocab_size
num_special_tokens = tokenizer.num_special_tokens_to_add()
real_input_len = input_len - num_special_tokens

prefix_token_ids = (np.random.randint(
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])

# New sampling logic: [X * (1 - b), X * (1 + b)]
input_low = int(input_len * (1 - range_ratio))
input_high = int(input_len * (1 + range_ratio))
input_low = int(real_input_len * (1 - range_ratio))
input_high = int(real_input_len * (1 + range_ratio))
output_low = int(output_len * (1 - range_ratio))
output_high = int(output_len * (1 + range_ratio))

@@ -344,6 +346,17 @@
vocab_size).tolist()
token_sequence = prefix_token_ids + inner_seq
prompt = tokenizer.decode(token_sequence)
# After decoding the prompt we have to encode and decode it again.
# This is done because in some cases N consecutive tokens
# give a string tokenized into != N number of tokens.
# For example for GPT2Tokenizer:
# [6880, 6881] -> ['Ġcalls', 'here'] ->
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
# To avoid uncontrolled change of the prompt length,
# the encoded sequence is truncated before being decoded again.
re_encoded_sequence = tokenizer.encode(
prompt, add_special_tokens=False)[:input_lens[i]]
prompt = tokenizer.decode(re_encoded_sequence)
total_input_len = prefix_len + int(input_lens[i])
requests.append(
SampleRequest(
@@ -874,6 +887,94 @@ def sample(self,
return sampled_requests


# -----------------------------------------------------------------------------
# Next Edit Prediction Dataset Implementation
# -----------------------------------------------------------------------------


zeta_prompt = """### Instruction:
You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.

### User Edits:

{}

### User Excerpt:

{}

### Response:

""" # noqa: E501


def _format_zeta_prompt(
sample: dict,
original_start_marker: str = "<|editable_region_start|>") -> dict:
"""Format the zeta prompt for the Next Edit Prediction (NEP) dataset.

This function formats examples from the NEP dataset
into prompts and expected outputs. It could be
further extended to support more NEP datasets.

Args:
sample: The dataset sample containing events,
inputs, and outputs.
original_start_marker: The marker indicating the
start of the editable region. Defaults to
"<|editable_region_start|>".

Returns:
A dictionary with the formatted prompts and expected outputs.
"""
events = sample["events"]
input = sample["input"]
output = sample["output"]
prompt = zeta_prompt.format(events, input)

# following the original implementation, extract the focused region
# from the raw output
output_start_index = output.find(original_start_marker)
output_focused_region = output[output_start_index:]
expected_output = output_focused_region

return {"prompt": prompt, "expected_output": expected_output}


class NextEditPredictionDataset(HuggingFaceDataset):
"""
Dataset class for processing a Next Edit Prediction dataset.
"""

SUPPORTED_DATASET_PATHS = {
"zed-industries/zeta",
}
MAPPING_PROMPT_FUNCS = {
"zed-industries/zeta": _format_zeta_prompt,
}

def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
**kwargs):
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(
self.dataset_path)
if formatting_prompt_func is None:
raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
samples = []
for sample in self.data:
sample = formatting_prompt_func(sample)
samples.append(
SampleRequest(
prompt=sample["prompt"],
prompt_len=len(tokenizer(sample["prompt"]).input_ids),
expected_output_len=len(
tokenizer(sample["expected_output"]).input_ids),
))
if len(samples) >= num_requests:
break
self.maybe_oversample_requests(samples, num_requests)
return samples


# -----------------------------------------------------------------------------
# ASR Dataset Implementation
# -----------------------------------------------------------------------------
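
The re-encode/decode step added to `RandomDataset.sample` above compensates for tokenizer round-trip drift: decoding N random token IDs and re-encoding the resulting string does not always give back N tokens. A minimal sketch of the effect, using the GPT-2 example from the code comment (illustrative only, not part of this diff):

```python
# Illustrative sketch of tokenizer round-trip drift.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")

token_ids = [6880, 6881]                    # ['Ġcalls', 'here']
prompt = tok.decode(token_ids)              # " callshere"
re_encoded = tok.encode(prompt, add_special_tokens=False)
print(re_encoded)                           # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']

# Truncating the re-encoded IDs to the requested length before decoding again,
# as the benchmark now does, keeps the prompt length under control.
prompt = tok.decode(re_encoded[:len(token_ids)])
```
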
8 changes: 6 additions & 2 deletions benchmarks/benchmark_serving.py
@@ -53,8 +53,9 @@
from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset,
ConversationDataset, HuggingFaceDataset,
InstructCoderDataset, MTBenchDataset,
RandomDataset, SampleRequest, ShareGPTDataset,
SonnetDataset, VisionArenaDataset)
NextEditPredictionDataset, RandomDataset,
SampleRequest, ShareGPTDataset, SonnetDataset,
VisionArenaDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json

MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@@ -603,6 +604,9 @@ def main(args: argparse.Namespace):
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
dataset_class = AIMODataset
args.hf_split = "train"
elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501
dataset_class = NextEditPredictionDataset
args.hf_split = "train"
elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
dataset_class = ASRDataset
args.hf_split = "train"
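
With this registration, samples from `zed-industries/zeta` are converted into benchmark requests by `_format_zeta_prompt` shown earlier. A rough sketch of the transformation on a made-up sample; the dictionary keys and the editable-region marker come from the code above, while the field values and the import path are assumptions for illustration:

```python
# Sketch only: assumes it is run from the benchmarks/ directory so that
# benchmark_dataset is importable; the sample contents are made up.
from benchmark_dataset import _format_zeta_prompt

sample = {
    "events": "User edited utils.py: renamed `add` to `add_one`",
    "input": "<|editable_region_start|>\ndef add_one(x):\n    return x\n",
    "output": "<|editable_region_start|>\ndef add_one(x):\n    return x + 1\n",
}

formatted = _format_zeta_prompt(sample)
# formatted["prompt"]: the zeta_prompt template filled with the edit events
# and the user excerpt.
# formatted["expected_output"]: the output text from the editable-region marker
# onward; the benchmark only uses its tokenized length as expected_output_len.
```
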
2 changes: 1 addition & 1 deletion benchmarks/kernels/benchmark_moe.py
@@ -10,12 +10,12 @@

import ray
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig

from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser

FP8_DTYPE = current_platform.fp8_dtype()
2 changes: 1 addition & 1 deletion benchmarks/kernels/benchmark_rmsnorm.py
@@ -4,11 +4,11 @@
from typing import Optional, Union

import torch
import triton
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from torch import nn

from vllm import _custom_ops as vllm_ops
from vllm.triton_utils import triton


class HuggingFaceRMSNorm(nn.Module):
@@ -6,13 +6,13 @@
# Import DeepGEMM functions
import deep_gemm
import torch
import triton
from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor

# Import vLLM functions
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.triton_utils import triton


# Copied from
Binary file modified docs/source/assets/contributing/dockerfile-stages-dependency.png
1 change: 1 addition & 0 deletions docs/source/deployment/frameworks/index.md
@@ -11,6 +11,7 @@ helm
lws
modal
open-webui
retrieval_augmented_generation
skypilot
streamlit
triton
@@ -0,0 +1,84 @@
(deployment-retrieval-augmented-generation)=

# Retrieval-Augmented Generation

[Retrieval-augmented generation (RAG)](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) is a technique that enables generative artificial intelligence (Gen AI) models to retrieve and incorporate new information. It modifies interactions with a large language model (LLM) so that the model responds to user queries with reference to a specified set of documents, using them to supplement its pre-existing training data. This allows LLMs to use domain-specific and/or up-to-date information. Use cases include giving chatbots access to internal company data or generating responses based on authoritative sources.

Here are the integrations:
- vLLM + [langchain](https://github.com/langchain-ai/langchain) + [milvus](https://github.com/milvus-io/milvus)
- vLLM + [llamaindex](https://github.com/run-llama/llama_index) + [milvus](https://github.com/milvus-io/milvus)

## vLLM + langchain

### Prerequisites

- Set up the vLLM and langchain environment

```console
pip install -U vllm \
langchain_milvus langchain_openai \
langchain_community beautifulsoup4 \
langchain-text-splitters
```

### Deploy

- Start the vLLM server with a supported embedding model, e.g.

```console
# Start embedding service (port 8000)
vllm serve ssmits/Qwen2-7B-Instruct-embed-base
```

- Start the vLLM server with a supported chat completion model, e.g.

```console
# Start chat service (port 8001)
vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001
```

- Use the script: <gh-file:examples/online_serving/retrieval_augmented_generation_with_langchain.py>

- Run the script

```console
python retrieval_augmented_generation_with_langchain.py
```
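
For orientation, here is a minimal sketch of how a LangChain pipeline can be pointed at the two servers started above. The linked example script is the complete reference; apart from the model names and ports used in the commands above, everything in this sketch is illustrative:

```python
# Minimal sketch (not the full example script): connect LangChain's
# OpenAI-compatible clients to the two vLLM servers started above.
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

embeddings = OpenAIEmbeddings(
    base_url="http://localhost:8000/v1",  # embedding service
    api_key="EMPTY",
    model="ssmits/Qwen2-7B-Instruct-embed-base",
)
llm = ChatOpenAI(
    base_url="http://localhost:8001/v1",  # chat service
    api_key="EMPTY",
    model="qwen/Qwen1.5-0.5B-Chat",
)
# The example script additionally loads documents, splits them, indexes the
# chunks in Milvus via langchain_milvus, and retrieves context before calling
# the chat model.
```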

## vLLM + llamaindex

### Prerequisites

- Set up the vLLM and llamaindex environment

```console
pip install vllm \
llama-index llama-index-readers-web \
llama-index-llms-openai-like \
llama-index-embeddings-openai-like \
llama-index-vector-stores-milvus \
```

### Deploy

- Start the vLLM server with a supported embedding model, e.g.

```console
# Start embedding service (port 8000)
vllm serve ssmits/Qwen2-7B-Instruct-embed-base
```

- Start the vLLM server with a supported chat completion model, e.g.

```console
# Start chat service (port 8001)
vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001
```

- Use the script: <gh-file:examples/online_serving/retrieval_augmented_generation_with_llamaindex.py>

- Run the script

```console
python retrieval_augmented_generation_with_llamaindex.py
```
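
Similarly, a minimal sketch of pointing a LlamaIndex pipeline at the two servers. The `OpenAILike` and `OpenAILikeEmbedding` class names are assumptions based on the `llama-index-llms-openai-like` and `llama-index-embeddings-openai-like` packages installed above; the linked example script is the complete reference:

```python
# Minimal sketch (not the full example script); the adapter class names are
# assumptions based on the openai-like packages installed above.
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
from llama_index.llms.openai_like import OpenAILike

embed_model = OpenAILikeEmbedding(
    model_name="ssmits/Qwen2-7B-Instruct-embed-base",
    api_base="http://localhost:8000/v1",  # embedding service
    api_key="EMPTY",
)
llm = OpenAILike(
    model="qwen/Qwen1.5-0.5B-Chat",
    api_base="http://localhost:8001/v1",  # chat service
    api_key="EMPTY",
    is_chat_model=True,
)
# The example script additionally loads web documents, builds a Milvus-backed
# vector index, and queries it with the LLM above.
```
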
28 changes: 14 additions & 14 deletions docs/source/features/tool_calling.md
@@ -141,9 +141,9 @@ Known issues:
much shorter than what vLLM generates. Since an exception is thrown when this condition
is not met, the following additional chat templates are provided:

* `examples/tool_chat_template_mistral.jinja` - this is the "official" Mistral chat template, but tweaked so that
* <gh-file:examples/tool_chat_template_mistral.jinja> - this is the "official" Mistral chat template, but tweaked so that
it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits)
* `examples/tool_chat_template_mistral_parallel.jinja` - this is a "better" version that adds a tool-use system prompt
* <gh-file:examples/tool_chat_template_mistral_parallel.jinja> - this is a "better" version that adds a tool-use system prompt
when tools are provided, which results in much better reliability when working with parallel tool calling.

Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
@@ -170,15 +170,15 @@

vLLM provides two JSON-based chat templates for Llama 3.1 and 3.2:

* `examples/tool_chat_template_llama3.1_json.jinja` - this is the "official" chat template for the Llama 3.1
* <gh-file:examples/tool_chat_template_llama3.1_json.jinja> - this is the "official" chat template for the Llama 3.1
models, but tweaked so that it works better with vLLM.
* `examples/tool_chat_template_llama3.2_json.jinja` - this extends upon the Llama 3.1 chat template by adding support for
* <gh-file:examples/tool_chat_template_llama3.2_json.jinja> - this extends upon the Llama 3.1 chat template by adding support for
images.

Recommended flags: `--tool-call-parser llama3_json --chat-template {see_above}`

vLLM also provides a JSON-based chat template for Llama 4:
* `examples/tool_chat_template_llama4_json.jinja` - this is based on the "official" chat template for the Llama 4
* <gh-file:examples/tool_chat_template_llama4_json.jinja> - this is based on the "official" chat template for the Llama 4
models, but tweaked so that it works better with vLLM.

For Llama 4, use `--tool-call-parser llama4_json --chat-template examples/tool_chat_template_llama4_json.jinja`.
@@ -191,7 +191,7 @@ Supported models:

Recommended flags: `--tool-call-parser granite --chat-template examples/tool_chat_template_granite.jinja`

`examples/tool_chat_template_granite.jinja`: this is a modified chat template from the original on Huggingface. Parallel function calls are supported.
<gh-file:examples/tool_chat_template_granite.jinja>: this is a modified chat template from the original on Huggingface. Parallel function calls are supported.

* `ibm-granite/granite-3.1-8b-instruct`

@@ -203,7 +203,7 @@ The chat template from Huggingface can be used directly. Parallel function calls

Recommended flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja`

`examples/tool_chat_template_granite_20b_fc.jinja`: this is a modified chat template from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported.
<gh-file:examples/tool_chat_template_granite_20b_fc.jinja>: this is a modified chat template from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported.

### InternLM Models (`internlm`)

@@ -253,12 +253,12 @@ Limitations:

Example supported models:

* `meta-llama/Llama-3.2-1B-Instruct`\* (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`)
* `meta-llama/Llama-3.2-3B-Instruct`\* (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`)
* `Team-ACE/ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`)
* `fixie-ai/ultravox-v0_4-ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`)
* `meta-llama/Llama-4-Scout-17B-16E-Instruct`\* (use with `examples/tool_chat_template_llama4_pythonic.jinja`)
* `meta-llama/Llama-4-Maverick-17B-128E-Instruct`\* (use with `examples/tool_chat_template_llama4_pythonic.jinja`)
* `meta-llama/Llama-3.2-1B-Instruct`\* (use with <gh-file:examples/tool_chat_template_llama3.2_pythonic.jinja>)
* `meta-llama/Llama-3.2-3B-Instruct`\* (use with <gh-file:examples/tool_chat_template_llama3.2_pythonic.jinja>)
* `Team-ACE/ToolACE-8B` (use with <gh-file:examples/tool_chat_template_toolace.jinja>)
* `fixie-ai/ultravox-v0_4-ToolACE-8B` (use with <gh-file:examples/tool_chat_template_toolace.jinja>)
* `meta-llama/Llama-4-Scout-17B-16E-Instruct`\* (use with <gh-file:examples/tool_chat_template_llama4_pythonic.jinja>)
* `meta-llama/Llama-4-Maverick-17B-128E-Instruct`\* (use with <gh-file:examples/tool_chat_template_llama4_pythonic.jinja>)

Flags: `--tool-call-parser pythonic --chat-template {see_above}`

@@ -270,7 +270,7 @@ Llama's smaller models frequently fail to emit tool calls in the correct format.

## How to write a tool parser plugin

A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py.
A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in <gh-file:vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py>.

Here is a summary of a plugin file:
