Token-in vLLM endpoint#1422

Merged
mikasenghaas merged 46 commits into main from tok-in-out on Dec 17, 2025
Conversation

@mikasenghaas (Member) commented Dec 12, 2025

This PR adds a custom token-in chat completions endpoint to our vLLM inference server. The server now additionally exposes /v1/chat/completions/tokens. It is essentially a copy of the regular /v1/chat/completions endpoint, with the difference that it requires the request to contain a field tokens: the list of token IDs used to build the engine prompt. This required two overrides:

  • ChatCompletionRequestWithTokens extends ChatCompletionRequest with a tokens field
  • OpenAIServingChatWithTokens extends OpenAIServingChat by a method create_chat_completion_with_tokens
  • The endpoint is registered at /v1/chat/completions/tokens
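To make the shape of these overrides concrete, here is a minimal, hypothetical sketch using stdlib dataclasses in place of the actual vLLM pydantic models (the class names mirror the PR, but none of this is the real vLLM code):

```python
from dataclasses import dataclass, field

# Toy stand-in for vLLM's ChatCompletionRequest (not the real class).
@dataclass
class ChatCompletionRequest:
    model: str
    messages: list
    max_tokens: int = 1024

# The token-in request is the regular request plus a `tokens` field:
# the server uses these IDs as the engine prompt instead of
# re-tokenizing `messages`.
@dataclass
class ChatCompletionRequestWithTokens(ChatCompletionRequest):
    tokens: list = field(default_factory=list)

req = ChatCompletionRequestWithTokens(
    model="Qwen/Qwen3-4B-Instruct-2507",
    messages=[{"role": "user", "content": "hi"}],
    tokens=[151644, 872, 198],
)
print(req.tokens)
```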

Note that the inference server currently tokenizes the inputs redundantly. A commit in this PR skips this tokenization, but I did not see it speed anything up. I decided to remove it again because the current override is very lightweight/unintrusive and will be easy to maintain in the future.

It also bumps verifiers to a commit that includes #626, which integrates the token-in endpoint into the multi-turn rollout flow.

Training Example

uv run rl @ examples/alphabet_sort/rl.toml --max-steps 50

Before

Screenshot 2025-12-12 at 5 27 40 PM

After

Screenshot 2025-12-12 at 9 11 42 PM

In the wordle example, we observe a significant reduction in KL mismatch

Screenshot 2025-12-15 at 4 22 21 PM

Minimal Example

Screenshot 2025-12-15 at 4 33 49 PM
uv run inference --model.name Qwen/Qwen3-4B-Instruct-2507
from typing import cast

from httpx import Client
from openai import OpenAI
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from transformers import AutoTokenizer, PreTrainedTokenizer

model_name = "Qwen/Qwen3-4B-Instruct-2507"
add_generation_prompt = True
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model_name)
base_url = "http://localhost:8000"
oai_client = OpenAI(base_url=f"{base_url}/v1")
client = Client(base_url=base_url)

prompt: list[ChatCompletionMessageParam] = [{"role": "user", "content": "Hello, how are you?"}]

# Get prompt tokens from local tokenizer
local_prompt_tokens = tokenizer.apply_chat_template(cast(list[dict[str, str]], prompt), add_generation_prompt=add_generation_prompt)
print(f"✅ Got prompt tokens from local tokenizer:\n{local_prompt_tokens}")

# Ensure server is healthy
assert client.get(f"{base_url}/health").status_code == 200
print("✅ Checked server health")

# Get prompt tokens from remote vLLM server via client
response = client.post(
    f"{base_url}/tokenize",
    json={"model": model_name, "messages": prompt, "add_generation_prompt": add_generation_prompt},
)
remote_prompt_tokens = response.json()["tokens"]
print(f"✅ Got prompt tokens from remote vLLM server:\n{remote_prompt_tokens}")

# Get prompt/completion tokens via chat completions API
extra_body = dict(return_token_ids=True, prompt_logprobs=True)
chat_completion_args = dict(
    model=model_name, messages=prompt, max_tokens=1024, temperature=0.0, logprobs=True, extra_body=extra_body
)
response = oai_client.chat.completions.create(**chat_completion_args)  # type: ignore
response_dict = response.model_dump()
chat_completion_prompt_tokens = response_dict["prompt_token_ids"]
chat_completion_tokens = response_dict["choices"][0]["token_ids"]
print(
    f"✅ Got response from /v1/chat/completions\nPrompt: {chat_completion_prompt_tokens}\nCompletion: {chat_completion_tokens}"
)


# Get prompt/completion tokens via the chat completions w/ tokens API
extra_body = dict(return_token_ids=True, prompt_logprobs=True, tokens=local_prompt_tokens)
# IMPORTANT: We need to merge the extra_body into the request args (normally this happens inside the OAI client) for vLLM to have access to it
chat_completion_args.pop("extra_body")
chat_completion_with_tokens_args = {**chat_completion_args, **extra_body}
response = oai_client.post("/chat/completions/tokens", body=chat_completion_with_tokens_args, cast_to=ChatCompletion)
response_dict = response.model_dump()
token_endpoint_prompt_tokens = response_dict["prompt_token_ids"]
token_endpoint_completion_tokens = response_dict["choices"][0]["token_ids"]
print(
    f"✅ Got response from /v1/chat/completions/tokens\nPrompt: {token_endpoint_prompt_tokens}\nCompletion: {token_endpoint_completion_tokens}"
)

assert local_prompt_tokens == remote_prompt_tokens == chat_completion_prompt_tokens == token_endpoint_prompt_tokens, (
    "Prompt tokens do not match"
)
print(
    "✅ Prompt tokens match between local tokenizer, vLLM /tokenize, /v1/chat/completions, and /v1/chat/completions/tokens"
)

assert chat_completion_tokens == token_endpoint_completion_tokens, "Completion tokens do not match"
print("✅ Completion tokens match between /v1/chat/completions and /v1/chat/completions/tokens")

# New request
new_prompt: list[ChatCompletionMessageParam] = [{"role": "user", "content": "What is the capital of France?"}]
tokens = tokenizer.apply_chat_template(cast(list[dict[str, str]], new_prompt), add_generation_prompt=True)
response = oai_client.post(
    "/chat/completions/tokens", body={**chat_completion_with_tokens_args, "tokens": tokens}, cast_to=ChatCompletion
)
print(response.choices[0].message.content)

GitHub Issue: #1421
Linear Issue: Resolves PRIMERL-243


Note

Adds a vLLM /v1/chat/completions/tokens endpoint for token-in requests and auto-sets api_server_count=1 when LoRA is enabled; enables interleaved rollouts to use token prompts and bumps verifiers.

  • Inference Server (vLLM):
    • New endpoint: Serve token-in chat completions at /v1/chat/completions/tokens in src/prime_rl/inference/vllm/server.py using OpenAIServingChatWithTokens from src/prime_rl/inference/vllm/serving_chat_with_tokens.py.
    • Registers route with validation/streaming, loads chat template, and integrates with app state.
  • Config/Runtime:
    • src/prime_rl/inference/config.py: Auto-set api_server_count=1 when enable_lora is true; otherwise ensure >= parallel.dp.
    • CHANGELOG.md: Document LoRA API server limitation.
  • Orchestrator:
    • src/prime_rl/orchestrator/orchestrator.py: For trajectory_strategy="interleaved", enable token prompts via env.set_interleaved_rollouts(True).
  • Docs:
    • docs/trajectories.md: Update trajectory guidance; remove retokenization section, clarify chat template behavior.
  • Examples/Configs:
    • configs/alphabet_sort/rl.toml: Add W&B block; update env id; minor structure tweaks.
    • examples/wordle/rl.toml: Set inference.parallel.dp = 6.
  • Dependencies:
    • pyproject.toml/uv.lock: Bump verifiers to ca75d04 (v0.1.8.post2).
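The api_server_count rule in the Config/Runtime bullet above could be sketched roughly as follows (the function name and signature are assumptions for illustration; the real logic lives in src/prime_rl/inference/config.py):

```python
# Hypothetical sketch of the config rule, not the actual implementation.
def resolve_api_server_count(enable_lora: bool, requested: int, dp: int) -> int:
    if enable_lora:
        # LoRA is limited to a single API server process (see CHANGELOG.md).
        return 1
    # Otherwise ensure at least one API server per data-parallel replica.
    return max(requested, dp)

print(resolve_api_server_count(True, 4, 2))   # 1
print(resolve_api_server_count(False, 1, 4))  # 4
```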

Written by Cursor Bugbot for commit beb2c29.

@mikasenghaas mikasenghaas force-pushed the tok-in-out branch 4 times, most recently from b5c33f2 to c56da9d on December 15, 2025 07:54
@mikasenghaas mikasenghaas requested a review from samsja December 15, 2025 16:37
@mikasenghaas mikasenghaas marked this pull request as ready for review December 16, 2025 21:35
Comment on lines +124 to +131
if engine_prompts[0]["prompt_token_ids"] != request.tokens:
    logger.warning(
        "Prompt tokens provided in request do not match the engine prompt tokens. This may happen due to retokenization discrepancies in multi-turn conversations. Since you are using the /v1/chat/completions/tokens endpoint, we assume you want this behavior and use the provided prompt tokens. If this is undesired, use the standard /v1/chat/completions endpoint instead."
    )
    logger.debug(f"engine_prompt_tokens:\n{engine_prompts[0]['prompt_token_ids']}")
    logger.debug(f"request_tokens:\n{request.tokens}")

engine_prompts[0]["prompt_token_ids"] = request.tokens
Member
If I'm understanding correctly, we process the chat completion request just to throw it away at the end and replace it with the passed-in token IDs?

What do we think about just directly using the passed token IDs and ignoring the processing entirely? It would make things simpler, and we wouldn't have to copy over the chat completion code every time we upgrade vLLM.

@mikasenghaas (Member, Author) Dec 17, 2025

Yeah, I actually had this implemented at commit 43c6a2b but decided against it because:

  • it wasn't much faster
  • I cannot print the warning logs (which are actually quite a nice sanity check; e.g. if this log shows on every request, it probably means something went wrong in the pre-tokenization; we should see this log sometimes, but not all the time)
  • I don't think we can ignore the processing entirely; e.g. the handler depends on having access to conversation, which requires applying the chat template anyway, so intercepting the engine prompt after all processing has been done seemed like the easiest fix that would handle all cases (e.g. we don't have to handle partial processing, worry about the harmony code path, etc.)

If you have an idea on how to not override the whole method but instead get more of a "monkey-patch" type behavior, that would be ideal. But I also feel like this might change in the new version anyway, because I think they quite heavily refactored the API server, so things will look quite different either way and we will likely have to figure it out in that new world anyway.
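For illustration, the "monkey-patch" idea could look roughly like this on a toy class (this is purely a sketch; ToyServingChat and build_engine_prompt are made-up stand-ins, not the real vLLM OpenAIServingChat API):

```python
# Toy stand-in for the serving class; pretend tokenization only.
class ToyServingChat:
    def build_engine_prompt(self, messages):
        # Pretend this applies the chat template and tokenizes.
        return {"prompt_token_ids": [1, 2, 3]}

def patch_with_tokens(cls):
    """Wrap the original method; swap in caller-provided tokens afterwards."""
    original = cls.build_engine_prompt

    def patched(self, messages, tokens=None):
        engine_prompt = original(self, messages)
        if tokens is not None and engine_prompt["prompt_token_ids"] != tokens:
            # Retokenization mismatch: trust the caller-provided tokens.
            engine_prompt["prompt_token_ids"] = tokens
        return engine_prompt

    cls.build_engine_prompt = patched

patch_with_tokens(ToyServingChat)
print(ToyServingChat().build_engine_prompt([], tokens=[9, 8, 7]))
# → {'prompt_token_ids': [9, 8, 7]}
```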

@samsja (Member) left a comment:

lgtm

* add math python example

* use ac

* update rl instructions

* use 8 gpus and no ac

* use offline filtering and val split because inference a bit too quick

* 12k seq len, 2k max tokens, 300 steps

* use hendrycks math with 512 tokens/turn

* zero completion on error

* log error rate

* update math python to use vf version

* fix arg

* fix completion mask type

* fix tests

* handle empty trajectories case

* add changelog

* bump vf

* log err rate and err distribution

* fix instance

* bump vf

* fix dropna only on err col

* also mask out first turn in interleaved mode

* handle skipped rollouts
@mikasenghaas mikasenghaas merged commit 8dccae6 into main Dec 17, 2025
6 checks passed
if handler is None:
    return base(raw_request).create_error_response(
        message="The model does not support Chat Completions API"
    )

Bug: Error response not wrapped in JSONResponse with status code

When handler is None, the code returns base(raw_request).create_error_response(...) directly, which returns an ErrorResponse Pydantic model. This is inconsistent with lines 191-192 which properly wrap ErrorResponse in JSONResponse(content=generator.model_dump(), status_code=generator.error.code). The direct return of ErrorResponse causes FastAPI to serialize it with a 200 OK status code instead of an appropriate error status code, making the error response appear successful to clients.
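The suggested fix is to return the error payload with an explicit status code instead of the bare pydantic model. A sketch of the pattern, using a minimal stand-in for JSONResponse (the real one is fastapi.responses.JSONResponse; attribute names here follow the description above):

```python
# Minimal stand-in for fastapi.responses.JSONResponse, purely to
# illustrate the pattern: attach an explicit status code rather than
# returning the bare error model (which FastAPI serializes as 200 OK).
class JSONResponse:
    def __init__(self, content, status_code=200):
        self.content = content
        self.status_code = status_code

def make_error_response(error_dict):
    # Mirrors: JSONResponse(content=generator.model_dump(),
    #                       status_code=generator.error.code)
    return JSONResponse(content=error_dict, status_code=error_dict["error"]["code"])

resp = make_error_response(
    {"error": {"code": 400, "message": "The model does not support Chat Completions API"}}
)
print(resp.status_code)  # → 400
```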

