Merged
46 commits
2ca61f8
duplicate chat completinos endpoint into /generate
mikasenghaas Dec 12, 2025
6eeb4e1
serve chat with token in functionality
mikasenghaas Dec 12, 2025
74821ea
use field to avoid misleading warning
mikasenghaas Dec 12, 2025
02485a1
nicer error msg
mikasenghaas Dec 12, 2025
2d0bcab
lock feature branch
mikasenghaas Dec 12, 2025
429a462
make use tokens prompt configurable
mikasenghaas Dec 12, 2025
31f111b
use setter and print info
mikasenghaas Dec 12, 2025
2666f41
bump
mikasenghaas Dec 12, 2025
dc45983
include inference
mikasenghaas Dec 12, 2025
08b13d8
do not print warning log (logs all the time)
mikasenghaas Dec 12, 2025
eda2348
bump
mikasenghaas Dec 12, 2025
d2f530d
bump + bring back warning log
mikasenghaas Dec 13, 2025
5a41f94
bump vf
mikasenghaas Dec 14, 2025
162f05f
bump vf
mikasenghaas Dec 14, 2025
84c00cb
use dp=6 in wordle example
mikasenghaas Dec 14, 2025
9a0fc7d
no deepcopy and no warning
mikasenghaas Dec 14, 2025
8fdd5e7
do not tokenize on the server
mikasenghaas Dec 14, 2025
2b4052b
add field names so that tokens is cached and no warning of unrecogniz…
mikasenghaas Dec 14, 2025
7725066
bump vf
mikasenghaas Dec 14, 2025
e815f44
auto install
mikasenghaas Dec 14, 2025
eb99034
bump vf
mikasenghaas Dec 14, 2025
37035e8
bump vf + set vllm tokenize method
mikasenghaas Dec 14, 2025
a339566
skip applying chat template
mikasenghaas Dec 15, 2025
bcd1ac4
Revert "skip applying chat template"
mikasenghaas Dec 15, 2025
035bcda
Revert "do not tokenize on the server"
mikasenghaas Dec 15, 2025
2c0bd59
bring back log
mikasenghaas Dec 15, 2025
675a772
use route /v1/chat/completions/tokens
mikasenghaas Dec 15, 2025
dbebae9
fix log
mikasenghaas Dec 15, 2025
44c493d
bump vf and make everything configurable
mikasenghaas Dec 15, 2025
8ef375b
bump and more informative log
mikasenghaas Dec 15, 2025
ddfb878
bump and make non-exact tokenization default
mikasenghaas Dec 15, 2025
09355e4
use token prompts by default
mikasenghaas Dec 15, 2025
21b85e6
remove retokenization issue from docs
mikasenghaas Dec 15, 2025
8f7a090
rename class
mikasenghaas Dec 15, 2025
8e0c984
bump vf
mikasenghaas Dec 15, 2025
4ce7321
fix auto asc setup for lora
mikasenghaas Dec 15, 2025
17dd9ad
bump vf
mikasenghaas Dec 15, 2025
6104f53
bump vf
mikasenghaas Dec 15, 2025
3254f53
bump vf
mikasenghaas Dec 16, 2025
6a50b19
bring back setter
mikasenghaas Dec 16, 2025
4961e4d
bump vf
mikasenghaas Dec 16, 2025
e627788
bump vf to latest prime-rl
mikasenghaas Dec 16, 2025
c1aa492
Explicitly catch env errors (#1416)
mikasenghaas Dec 17, 2025
8ccf327
Revert "Explicitly catch env errors (#1416)"
mikasenghaas Dec 17, 2025
14c77d3
bump vf
mikasenghaas Dec 17, 2025
beb2c29
add asc comment
mikasenghaas Dec 17, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -3,3 +3,4 @@
 Documenting changes which affect configuration usage patterns (added/moved/removed/renamed fields, notable logic changes).
 
 - **`model.lora`**: Moved from `model.experimental.lora` to `model.lora` (no longer experimental) (#1440, 2025-12-16)
+- Auto-set `api_server_count=1` on inference when LoRA is enabled, because vLLM doesn't support hotloading for multiple API servers (#1422, 2025-12-17)
15 changes: 10 additions & 5 deletions configs/alphabet_sort/rl.toml
@@ -1,8 +1,10 @@
inference_gpu_ids = [0]
trainer_gpu_ids = [1]

max_steps = 100

[wandb]
project = "alphabet-sort-debug"
name = "alphabet-sort"


[model]
name = "Qwen/Qwen3-4B-Instruct-2507"

@@ -15,7 +17,10 @@ seq_len = 2048
 max_tokens = 512
 
 [[orchestrator.env]]
-id = "alphabet-sort"
+id = "primeintellect/alphabet-sort"
+name = "alphabet-sort"
+args = { min_turns = 2, max_turns = 2 }
 
-[trainer]
+[trainer]
 
+[inference]
16 changes: 0 additions & 16 deletions docs/trajectories.md
@@ -22,22 +22,6 @@ We call this the "exact prefix" invariant. For example, at turn 2, the LLM shoul
- If we add A1', the logprobs from turn 1 might be off because the inference LLM produced A1 but the trainer LLM is computing logprobs for A1'
- If we add A1, the logprobs from turn 2 might be off because the inference LLM is attending to A1' but the trainer LLM is attending to A1.

While it would seem that this invariant is easy to enforce, there are two surprisingly common violations that can occur:
- **Retokenization**: Modern tokenizers are not symmetric, and so retokenizing A1 to prepare the prompt at turn 2 may produce A1' != A1
- **Arbitrary Chat Templates**: The chat template may add or remove tokens across turns

### Retokenization

There are two main reasons why retokenization is a problem:
- `verifiers` uses an OAI client for collecting a trajectory, which only supports text-in prompts; this implies that at any turn $t>1$, we have to retokenize the assistant messages from turns $1,\dots,t-1$.
- Modern tokenizers are not symmetric, and so retokenizing the assistant messages from turns $1,\dots,t-1$ may produce different tokens than the original decoded tokens by the LLM.

For example, the LLM may have produced the number `11` during decoding as two separate tokens, `1` and `1`, but when retokenized, `11` becomes a single token. We have violated the exact prefix invariant: it is unclear whether to add `1` and `1` or `11` to the interleaved rollout:
- If we choose to add `1` and `1`, the logprobs from turn 2 might be off because the attention of the trainer and inference LLM is different.
- If we choose to add `11`, the logprobs from turn 1 might be off because `11` has a different likelihood than `1` and `1`.

In moderately complex multi-turn environments (e.g. `wordle` or `wiki-search`) we have found that our Icepop-style double-sided masking is able to mitigate such discrepancies, but it is unclear whether it remains robust for large-scale agentic training in environments that are many hundreds of turns long. A simple solution to this problem is to allow token-in requests, but this is part of neither the standard OAI nor the vLLM spec, which is why it is not easily supported yet.

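The retokenization asymmetry described above can be reproduced with a toy greedy tokenizer. This is a hypothetical sketch (the vocabulary and greedy longest-match rule are illustrative stand-ins; real BPE tokenizers are more complex but show the same decode/encode asymmetry):

```python
# Toy greedy tokenizer (hypothetical): decoding then re-encoding changes ids.
VOCAB = {"1": 0, "11": 1, " ": 2}
PIECES = sorted(VOCAB, key=len, reverse=True)  # longest match first

def decode(ids):
    inv = {v: k for k, v in VOCAB.items()}
    return "".join(inv[i] for i in ids)

def encode(text):
    ids, i = [], 0
    while i < len(text):
        for piece in PIECES:
            if text.startswith(piece, i):
                ids.append(VOCAB[piece])
                i += len(piece)
                break
        else:
            raise ValueError(f"cannot tokenize {text[i:]!r}")
    return ids

# Suppose the LLM sampled "1" twice, as two separate tokens, at turn 1.
sampled = [0, 0]
# Preparing the turn-2 prompt round-trips through text...
retokenized = encode(decode(sampled))
# ...and yields a single "11" token: the exact-prefix invariant is violated.
assert retokenized == [1] and retokenized != sampled
```

Token-in requests avoid the round-trip entirely, which is what the `/v1/chat/completions/tokens` route in this PR provides.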
### Arbitrary Chat Templates

There exist chat templates which add, modify, or remove tokens across turns. One good example is the chat template of the Qwen3 series of models, which strips thinking across user turns.
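The Qwen3-style behavior can be sketched with a hypothetical, heavily simplified template (the `<|role|>` rendering and `render` helper are illustrative, not the real template):

```python
import re

THINK = re.compile(r"<think>.*?</think>\s*", re.DOTALL)

def render(messages):
    # Hypothetical, simplified template: like Qwen3's, it strips <think>
    # blocks from assistant messages when re-rendering conversation history.
    parts = []
    for m in messages:
        content = m["content"]
        if m["role"] == "assistant":
            content = THINK.sub("", content)
        parts.append(f"<|{m['role']}|>{content}")
    return "\n".join(parts)

# What the model actually generated at turn 1, thinking included:
raw_assistant = "<think>plan the answer</think>A1"
history = [
    {"role": "user", "content": "Q1"},
    {"role": "assistant", "content": raw_assistant},
]

# The turn-2 prompt no longer contains the tokens produced at turn 1,
# so the trainer cannot line up logprobs with the inference prefix.
assert "<think>" not in render(history)
```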
3 changes: 1 addition & 2 deletions examples/wordle/rl.toml
@@ -27,5 +27,4 @@ max_tokens = 1024
 [trainer] # Default trainer config
 
 [inference.parallel]
-dp = 3
-tp = 2
+dp = 6
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -70,7 +70,7 @@ prerelease = "allow"
 [tool.uv.sources]
 torch = { index = "pytorch-cu128" }
 reverse-text = { index = "primeintellect" }
-verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers.git", rev = "ca6aca7" }
+verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers.git", rev = "ca75d04" }
 dion = { git = "https://github.com/samsja/dion.git", rev = "main" }
 torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" }

11 changes: 9 additions & 2 deletions src/prime_rl/inference/config.py
@@ -202,10 +202,17 @@ def round_up_max_lora_rank(self):
         return self
 
     @model_validator(mode="after")
-    def ensure_api_server_count_is_at_least_dp_size(self):
-        """Ensures that we have at least as many API servers as data parallel size."""
+    def auto_setup_api_server_count(self):
+        """
+        Ensures that we have at least as many API servers as the data parallel
+        size, unless LoRA is enabled, in which case only one API server is
+        supported (vLLM limitation).
+        """
         if self.api_server_count < self.parallel.dp:
             self.api_server_count = self.parallel.dp
 
+        if self.enable_lora:
+            self.api_server_count = 1  # LoRA requires only one API server
         return self
 
     def to_vllm(self) -> Namespace:
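The validator's logic can be mimicked with a stdlib-only sketch. Field names mirror the diff above; everything else (a dataclass instead of the real pydantic model, the defaults) is a simplification for illustration:

```python
from dataclasses import dataclass, field

@dataclass
class Parallel:
    dp: int = 1

@dataclass
class InferenceConfig:
    # Sketch of the auto_setup_api_server_count validator (the real config
    # is a pydantic model; this dataclass only reproduces the logic).
    parallel: Parallel = field(default_factory=Parallel)
    api_server_count: int = 1
    enable_lora: bool = False

    def __post_init__(self):
        # At least one API server per data-parallel rank...
        if self.api_server_count < self.parallel.dp:
            self.api_server_count = self.parallel.dp
        # ...unless LoRA is enabled: vLLM only supports hotloading
        # adapters through a single API server.
        if self.enable_lora:
            self.api_server_count = 1

assert InferenceConfig(parallel=Parallel(dp=4)).api_server_count == 4
assert InferenceConfig(parallel=Parallel(dp=4), enable_lora=True).api_server_count == 1
```

Note that the LoRA branch runs last, so it wins even when `dp > 1`, which matches the CHANGELOG entry for #1422.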
96 changes: 93 additions & 3 deletions src/prime_rl/inference/vllm/server.py
@@ -1,9 +1,20 @@
from argparse import Namespace
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Any, Optional

from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse
from vllm.entrypoints.utils import load_aware_call, with_cancellation

from prime_rl.inference.patches import monkey_patch_prometheus_stat_logger_for_lora_in_dp_mode
from prime_rl.inference.vllm.serving_chat_with_tokens import (
ChatCompletionRequestWithTokens,
OpenAIServingChatWithTokens,
)

# Monkeypatch PrometheusStatLogger to avoid NotImplementedError for LoRA in DP mode
monkey_patch_prometheus_stat_logger_for_lora_in_dp_mode()
@@ -13,17 +24,20 @@

import uvloop
import vllm.envs as envs
-from fastapi import Request
-from vllm.config import LogprobsMode
+from fastapi import Depends, HTTPException, Request
+from vllm.config import LogprobsMode, VllmConfig
+from starlette.datastructures import State
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.openai.api_server import (
base,
build_app,
build_async_engine_client_from_engine_args,
init_app_state,
load_log_config,
maybe_register_tokenizer_info_endpoint,
validate_json_request,
)
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -60,6 +74,49 @@ async def custom_build_async_engine_client(
yield engine


async def custom_init_app_state(engine_client: EngineClient, vllm_config: VllmConfig, state: State, args: Namespace):
    await init_app_state(engine_client, vllm_config, state, args)

    # Repeat the relevant setup from init_app_state so the request logger,
    # supported tasks, and chat template are available below
    if args.enable_log_requests:
        request_logger = RequestLogger(max_log_len=args.max_log_len)
    else:
        request_logger = None

    model_config = vllm_config.model_config

    if envs.VLLM_USE_V1:
        supported_tasks = await engine_client.get_supported_tasks()  # type: ignore
    else:
        supported_tasks = model_config.supported_tasks

    resolved_chat_template = load_chat_template(args.chat_template)

    # Also serve OAI chat completion tokens
    state.openai_serving_chat_with_tokens = (
        OpenAIServingChatWithTokens(
            engine_client,
            model_config,
            state.openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
            tool_parser=args.tool_call_parser,
            reasoning_parser=args.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            enable_log_outputs=args.enable_log_outputs,
            log_error_stack=args.log_error_stack,
        )
        if "generate" in supported_tasks
        else None
    )


# Copied from vllm/entrypoints/openai/api_server.py
# Only difference is that we inject custom routes and build async engine client differently
async def custom_run_server_worker(listen_address, sock, args, client_config=None, **uvicorn_kwargs) -> None:
@@ -106,8 +163,41 @@ async def _init_broadcaster(request: Request):
)
return {"status": "ok"}

    def chat_with_tokens(request: Request) -> Optional[OpenAIServingChatWithTokens]:
        return request.app.state.openai_serving_chat_with_tokens

    @app.post(
        "/v1/chat/completions/tokens",
        dependencies=[Depends(validate_json_request)],
        responses={
            HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
            HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
            HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
            HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
        },
    )
    @with_cancellation
    @load_aware_call
    async def _chat_with_tokens(request: ChatCompletionRequestWithTokens, raw_request: Request):
        handler = chat_with_tokens(raw_request)
        if handler is None:
            return base(raw_request).create_error_response(
                message="The model does not support Chat Completions API"
            )
Review comment:

Bug: Error response not wrapped in JSONResponse with status code

When handler is None, the code returns base(raw_request).create_error_response(...) directly, which returns an ErrorResponse Pydantic model. This is inconsistent with lines 191-192, which properly wrap ErrorResponse in JSONResponse(content=generator.model_dump(), status_code=generator.error.code). The direct return of ErrorResponse causes FastAPI to serialize it with a 200 OK status code instead of an appropriate error status code, making the error response appear successful to clients.

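The mismatch the review comment describes can be shown without a running server. This is a stdlib-only sketch: `ErrorResponse` here is a hypothetical stand-in for vLLM's model (simplified, but it nests the code the same way the diff reads it via `generator.error.code`), and `to_http` plays the role of wrapping results the way the other branches do with `JSONResponse`:

```python
from http import HTTPStatus

class ErrorResponse:
    """Hypothetical stand-in for vLLM's ErrorResponse (simplified)."""
    def __init__(self, message: str, code: int = HTTPStatus.BAD_REQUEST.value):
        self.message = message
        self.code = code
    def model_dump(self) -> dict:
        return {"error": {"message": self.message, "code": self.code}}

def to_http(result):
    """Wrap handler results as (body, status): errors carry their own
    status code, everything else is 200 OK. Returning the ErrorResponse
    object itself would lose the status and serialize as 200."""
    if isinstance(result, ErrorResponse):
        # Analogue of JSONResponse(content=..., status_code=...)
        return result.model_dump(), result.model_dump()["error"]["code"]
    return result, HTTPStatus.OK.value

body, status = to_http(ErrorResponse("The model does not support Chat Completions API"))
assert status == HTTPStatus.BAD_REQUEST.value  # not 200 OK
```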
        try:
            generator = await handler.create_chat_completion_with_tokens(request, raw_request)
        except Exception as e:
            raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)) from e
        if isinstance(generator, ErrorResponse):
            return JSONResponse(content=generator.model_dump(), status_code=generator.error.code)
        elif isinstance(generator, ChatCompletionResponse):
            return JSONResponse(content=generator.model_dump())

        return StreamingResponse(content=generator, media_type="text/event-stream")

     vllm_config = await engine_client.get_vllm_config()
-    await init_app_state(engine_client, vllm_config, app.state, args)
+    await custom_init_app_state(engine_client, vllm_config, app.state, args)

# This hack allows us to update lora adapters in-place by skipping the check for already loaded adapters.
async def do_nothing(*args, **kwargs):