Merged
84 commits
All 84 commits are by mikasenghaas.

0406268 (Dec 12, 2025): skeleton of token in multi-turn conversation
b584f0e (Dec 12, 2025): fix url parsing
bf88a32 (Dec 12, 2025): allow configuring in eval entrypoint
4a2c8a8 (Dec 12, 2025): raise for status
046e433 (Dec 12, 2025): set sampling args and print warning
c0e680f (Dec 12, 2025): correctly process tools
16d68c5 (Dec 12, 2025): fix eval cli test
1d004fe (Dec 12, 2025): use setter
6389aff (Dec 12, 2025): remove prints
eb4d702 (Dec 12, 2025): create a oai client copy with diff base url and use post from there t…
866384a (Dec 12, 2025): also use oai post routes for tokenize
94e394b (Dec 13, 2025): use incremental tokenization trick again
b404171 (Dec 13, 2025): suffix trick
56efbfb (Dec 14, 2025): only call env_response once
c9c28d0 (Dec 14, 2025): only check suffix if and eom token present
1c32891 (Dec 14, 2025): parallelize tokenization calls
f16797c (Dec 14, 2025): avoid redundant concat
d8ccee2 (Dec 14, 2025): fix typo
96b9a98 (Dec 14, 2025): move use token prompts into state
1276711 (Dec 14, 2025): add use token prompts to metadata
1e94571 (Dec 14, 2025): removed docstring
6ecbae5 (Dec 14, 2025): merge logic into get_model_response to avoid redundant code
0b3fca6 (Dec 14, 2025): abstract tokenize_vllm and do not use token prompts on initial turn
1e0bf29 (Dec 14, 2025): allow local tokenization to save http overhead
2eb4518 (Dec 14, 2025): rename to utp
3d12cf4 (Dec 14, 2025): better best case complexity for find_lst_index
73bf1d9 (Dec 14, 2025): configure exact tokenization
b427069 (Dec 14, 2025): fix caches + one run local tokenizer in process pool
5631365 (Dec 14, 2025): add debug logs
d71199b (Dec 14, 2025): generate client
fbd9396 (Dec 14, 2025): reverse conditional
79f397d (Dec 14, 2025): use thread pool
4e847b8 (Dec 14, 2025): fix caching bug
22daf23 (Dec 14, 2025): cleanup
7b25189 (Dec 14, 2025): more workers in threadpool
41c7a29 (Dec 14, 2025): do exact tokenization
e0da0fa (Dec 14, 2025): bring back setter
bf19bd3 (Dec 14, 2025): match signature
f6ed5e2 (Dec 14, 2025): setup state from class attr
8c6c4e2 (Dec 14, 2025): also set tokenize method
ee4e858 (Dec 14, 2025): avoid mutation
1a5d193 (Dec 15, 2025): larger default thread pool
cb33caf (Dec 15, 2025): make exact tokenization configurable
9df126d (Dec 15, 2025): fix passing exact tokenization
e4f4636 (Dec 15, 2025): use client directly on /v1/chat/completions/tokens route
dd504cc (Dec 15, 2025): add timing [revert later]
08e8f07 (Dec 15, 2025): Revert "add timing [revert later]"
c942c44 (Dec 15, 2025): add exact tokenization in rollout
d931786 (Dec 15, 2025): make tokens prompt arg none by default
67bd8cc (Dec 15, 2025): compute + cache suffix ids in non-exact tokenization and make it the …
7edddd9 (Dec 15, 2025): fix tests
13be6c9 (Dec 15, 2025): fix ty
329a7ec (Dec 15, 2025): fix caching edge case with truncated turns
237fb8b (Dec 15, 2025): correctly support completions tokenization req
ca1dc2d (Dec 15, 2025): shorter warning
71bb1a3 (Dec 15, 2025): fix ty
bc113be (Dec 15, 2025): only support setting token prompt args via class attrs/ setters
fd3b726 (Dec 15, 2025): do not tokenize with tools in non exact mode
3eb3643 (Dec 15, 2025): fix overlap
00e39e7 (Dec 15, 2025): fix adding suffix
5d1c7a7 (Dec 16, 2025): remove tokenize_method
9b4d2d6 (Dec 16, 2025): move tokenize method
d5317e5 (Dec 16, 2025): deprecate exact tokenization
2864e8e (Dec 16, 2025): move building prompt_ids into get_model_respsonse
b1125d1 (Dec 16, 2025): using generic extra env kwargs setter
d7a0a7a (Dec 16, 2025): fix tests
405ce8c (Dec 16, 2025): revert env group changes
ace0ec1 (Dec 16, 2025): revert moving line
d416a6e (Dec 16, 2025): remove unused logger
8a1c2a5 (Dec 16, 2025): revert changing state
d4b44b6 (Dec 16, 2025): abstract overlong prompt handling into decorator
3926b10 (Dec 16, 2025): rename to get_model_response_with_messages
1b1e0c1 (Dec 16, 2025): fix tokens_client caching
c596bb9 (Dec 16, 2025): fix typo
129fb43 (Dec 16, 2025): more typo
451d1a2 (Dec 16, 2025): more accurate comment
3759f7f (Dec 16, 2025): allow setter
e36460b (Dec 16, 2025): fix url parsing edge case
67d98b1 (Dec 16, 2025): support generic setters
3233ed4 (Dec 16, 2025): move method
ad15055 (Dec 16, 2025): more readable find_last_index
f4c44e5 (Dec 16, 2025): fix
4185685 (Dec 16, 2025): remove the is not recommended for general use
75fa695 (Dec 16, 2025): fix find last index again
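Several commit messages above refer to a `find_last_index` helper used by the suffix-matching tokenization trick, and to improving its best-case complexity and readability. The implementation itself is not shown in this diff; the following is a hypothetical sketch of what such a helper might look like, scanning from the right so the best case (the needle sits near the end of the token list) exits quickly:

```python
def find_last_index(haystack: list[int], needle: list[int]) -> int:
    """Return the start index of the last occurrence of `needle` in
    `haystack`, or -1 if it does not occur. Scans right-to-left, so the
    common case of a suffix match near the end is cheap."""
    if not needle:
        return len(haystack)
    n, m = len(haystack), len(needle)
    for start in range(n - m, -1, -1):
        if haystack[start : start + m] == needle:
            return start
    return -1
```

Note this is only an illustration of the general technique; the actual helper in the PR may differ in signature and behavior.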
1 change: 1 addition & 0 deletions tests/test_eval_cli.py
@@ -50,6 +50,7 @@ def _run_cli(monkeypatch, overrides):
"save_every": -1,
"save_to_hf_hub": False,
"hf_hub_dataset_name": "",
"extra_env_kwargs": {},
}
base_args.update(overrides)
args_namespace = SimpleNamespace(**base_args)
6 changes: 6 additions & 0 deletions verifiers/envs/env_group.py
@@ -285,3 +285,9 @@ def set_max_seq_len(self, max_seq_len: int | None) -> None:
self.max_seq_len = max_seq_len
for env in self.envs:
env.set_max_seq_len(max_seq_len)

def set_interleaved_rollouts(self, interleaved_rollouts: bool) -> None:
"""Set the interleaved rollouts flag for this environment group and all sub-environments."""
self.interleaved_rollouts = interleaved_rollouts
for env in self.envs:
env.set_interleaved_rollouts(interleaved_rollouts)
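The `EnvGroup` setter above follows a fan-out pattern: set the flag on the group, then propagate it to every sub-environment. A minimal self-contained sketch of that pattern, using simplified stand-in classes rather than the real verifiers API:

```python
class Env:
    """Stand-in for a single environment."""

    def set_interleaved_rollouts(self, interleaved_rollouts: bool) -> None:
        self.interleaved_rollouts = interleaved_rollouts


class EnvGroup(Env):
    """Stand-in for a group that wraps several sub-environments."""

    def __init__(self, envs: list[Env]):
        self.envs = envs

    def set_interleaved_rollouts(self, interleaved_rollouts: bool) -> None:
        # set the flag on the group itself, then fan out to every sub-env
        self.interleaved_rollouts = interleaved_rollouts
        for env in self.envs:
            env.set_interleaved_rollouts(interleaved_rollouts)


group = EnvGroup([Env(), Env()])
group.set_interleaved_rollouts(True)  # group and both sub-envs now carry the flag
```

The same shape is used for `set_max_seq_len` just above it, which is why the generic `set_kwargs` dispatcher later in the PR prefers `set_{key}` methods over plain attribute assignment.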
243 changes: 194 additions & 49 deletions verifiers/envs/environment.py
@@ -1,5 +1,6 @@
import asyncio
import atexit
import functools
import inspect
import json
import logging
@@ -24,6 +25,7 @@

from datasets import Dataset
from openai import AsyncOpenAI, BadRequestError, OpenAI
from openai.types.chat import ChatCompletion

import verifiers as vf
from verifiers.parsers.parser import Parser
@@ -48,6 +50,10 @@
strip_nones_from_content,
)
from verifiers.utils.path_utils import get_results_path
from verifiers.utils.token_utils import (
get_prompt_ids,
prepare_sampling_args_for_token_prompts,
)

if TYPE_CHECKING:
pass
@@ -74,6 +80,7 @@ def __init__(
env_args: dict | None = None,
map_kwargs: dict = {},
max_seq_len: int | None = None,
interleaved_rollouts: bool = False,
**kwargs,
):
self.logger = logging.getLogger(f"verifiers.envs.{self.__class__.__name__}")
@@ -91,6 +98,9 @@
self.env_id = env_id or ""
self.env_args = env_args or {}
self.max_seq_len = max_seq_len

self.set_interleaved_rollouts(interleaved_rollouts)

if self.message_type == "chat":
if dataset is not None:
self.dataset = self.format_dataset(
@@ -326,33 +336,98 @@ async def get_model_response(

Convenience function for wrapping (chat, completion) API calls.
Returns special error messages for context length issues.

If interleaved_rollouts is set, the model response is obtained by
calling a custom token-in endpoint. Note that this only works if the
inference server implements this endpoint. Currently, this is a
hand-crafted feature for PRIME-RL's vLLM server extension, and is not
recommended for general use.
"""
# resolve optional argument, fallback to state or class defaults
client = client or state["client"]
model = model or state["model"]
oai_tools = oai_tools or state["oai_tools"]
sampling_args = cast(
SamplingArgs, sampling_args or state["sampling_args"] or {}
)
message_type = message_type or self.message_type
assert model is not None
assert client is not None

# normalize sampling args:
# - if max_tokens is provided for chat, rename to max_completion_tokens
# - drop any None-valued entries to avoid sending to the client
if "max_tokens" in sampling_args:
if sampling_args["max_tokens"] is None:
sampling_args.pop("max_tokens")
elif message_type == "chat":
sampling_args["max_completion_tokens"] = sampling_args.pop("max_tokens")
if (
"max_completion_tokens" in sampling_args
and sampling_args["max_completion_tokens"] is None
):
sampling_args.pop("max_completion_tokens")
clean_sampling_args = {k: v for k, v in sampling_args.items() if v is not None}
try:

def resolve_optional_args(
client: AsyncOpenAI | None,
model: str | None,
oai_tools: list[ChatCompletionToolParam] | None,
sampling_args: SamplingArgs | None,
message_type: MessageType | None,
) -> tuple[
AsyncOpenAI,
str,
list[ChatCompletionToolParam] | None,
SamplingArgs,
MessageType,
]:
"""Resolve optional arguments, fallback to state or class defaults."""
client = client or state["client"]
model = model or state["model"]
assert client is not None and model is not None
oai_tools = oai_tools or state["oai_tools"]
sampling_args = cast(
SamplingArgs, sampling_args or state["sampling_args"] or {}
)
message_type = message_type or self.message_type
return client, model, oai_tools, sampling_args, message_type

def normalize_sampling_args(sampling_args: SamplingArgs) -> SamplingArgs:
"""
Normalize sampling arguments. This does two things:
- if max_tokens is provided for chat, rename to max_completion_tokens
- drop any None-valued entries to avoid sending to the client
"""
if "max_tokens" in sampling_args:
if sampling_args["max_tokens"] is None:
sampling_args.pop("max_tokens")
elif message_type == "chat":
sampling_args["max_completion_tokens"] = sampling_args.pop(
"max_tokens"
)
if (
"max_completion_tokens" in sampling_args
and sampling_args["max_completion_tokens"] is None
):
sampling_args.pop("max_completion_tokens")
return {k: v for k, v in sampling_args.items() if v is not None}

def handle_overlong_prompt(func):
"""Decorator to handle overlong prompt errors from the model API."""

@functools.wraps(func)
async def wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except Exception as e:
# if the request failed because the prompt was too long,
# raise a special overlong prompt error
if isinstance(e, BadRequestError):
error_text = e.response.text.lower()
context_length_phrases = [
"this model's maximum context length is",
"is longer than the model's context length",
"exceeds the model's context length",
"exceed the configured limit",
"exceeds the configured limit",
"exceeded model",
]
if any(
phrase in error_text for phrase in context_length_phrases
):
self.logger.debug("Caught overlong prompt.")
raise vf.OverlongPromptError(e)
# in all other cases we raise a generic model error
raise vf.ModelError(e)

return wrapper

@handle_overlong_prompt
async def get_model_response_with_messages(
client: AsyncOpenAI,
model: str,
prompt: Messages,
oai_tools: list[ChatCompletionToolParam] | None,
sampling_args: SamplingArgs,
message_type: MessageType,
) -> ModelResponse:
"""Convenience function for wrapping (chat, completion) API calls."""
if message_type == "chat":
assert isinstance(prompt, list)
prompt = strip_nones_from_content(prompt)
@@ -372,9 +447,9 @@ async def get_model_response(
break
except Exception:
has_audio = False
if has_audio and "modalities" not in clean_sampling_args:
clean_sampling_args = {
**clean_sampling_args,
if has_audio and "modalities" not in sampling_args:
sampling_args = {
**sampling_args,
"modalities": ["text"],
}

@@ -383,13 +458,13 @@
model=model,
messages=prompt,
tools=oai_tools,
**clean_sampling_args,
**sampling_args,
)
else:
response = await client.chat.completions.create(
model=model,
messages=prompt,
**clean_sampling_args,
**sampling_args,
)
return response
elif message_type == "completion":
@@ -399,26 +474,72 @@
)
assert isinstance(prompt, str)
response = await client.completions.create(
model=model, prompt=prompt, **clean_sampling_args
model=model, prompt=prompt, **sampling_args
)
return response
except Exception as e:
# in case of making a request with an overlong prompt, e.g. from a too-long
# environment response, we return a dummy response with finish_reason "length"
if isinstance(e, BadRequestError):
error_text = e.response.text.lower()
context_length_phrases = [
"this model's maximum context length is",
"is longer than the model's context length",
"exceeds the model's context length",
"exceed the configured limit",
"exceeds the configured limit",
"exceeded model",
]
if any(phrase in error_text for phrase in context_length_phrases):
self.logger.debug("Caught overlong prompt.")
raise vf.OverlongPromptError(e)
raise vf.ModelError(e)

@handle_overlong_prompt
async def get_model_response_with_tokens(
client: AsyncOpenAI,
model: str,
prompt: Messages,
prompt_ids: list[int],
oai_tools: list[ChatCompletionToolParam] | None,
sampling_args: SamplingArgs,
message_type: MessageType,
) -> ModelResponse:
"""
Get a model response with a pre-tokenized prompt via the custom
/v1/chat/completions/tokens endpoint (only available in PRIME-RL's
vLLM server extension).
"""
assert message_type == "chat", (
"get_model_response_with_tokens is only supported for chat tasks."
)

extra_body = sampling_args.pop("extra_body", {})
body = dict(
model=model,
messages=prompt,
tools=oai_tools,
tokens=prompt_ids,
**sampling_args,
**extra_body,
)

return await client.post(
"/chat/completions/tokens",
body=body,
cast_to=ChatCompletion,
)

client, model, oai_tools, sampling_args, message_type = resolve_optional_args(
client, model, oai_tools, sampling_args, message_type
)
sampling_args = normalize_sampling_args(sampling_args)
if self.interleaved_rollouts:
sampling_args = prepare_sampling_args_for_token_prompts(sampling_args)

if self.interleaved_rollouts and len(state["trajectory"]) > 0:
prompt_ids = await get_prompt_ids(state, prompt, client)
return await get_model_response_with_tokens(
client=client,
model=model,
prompt=prompt,
prompt_ids=prompt_ids,
oai_tools=oai_tools,
sampling_args=sampling_args,
message_type=message_type,
)
else:
return await get_model_response_with_messages(
client=client,
model=model,
prompt=prompt,
oai_tools=oai_tools,
sampling_args=sampling_args,
message_type=message_type,
)

async def init_state(
self,
@@ -896,10 +1017,34 @@ def evaluate_sync(
save_every=save_every,
)

def set_kwargs(self, **kwargs) -> None:
"""
Set environment attributes, using setter methods when available.

For each kwarg, checks if a `set_{key}` method exists and calls it,
otherwise falls back to setattr. This ensures proper propagation for
attributes like `interleaved_rollouts` in EnvGroup.
"""
for key, value in kwargs.items():
setter_name = f"set_{key}"
setter = getattr(self, setter_name, None)
if setter is not None and callable(setter):
setter(value)
else:
setattr(self, key, value)

def set_max_seq_len(self, max_seq_len: int | None) -> None:
"""Set the maximum sequence length for this environment."""
self.max_seq_len = max_seq_len

def set_interleaved_rollouts(self, interleaved_rollouts: bool) -> None:
"""Set the interleaved rollouts flag for this environment."""
self.interleaved_rollouts = interleaved_rollouts
if self.interleaved_rollouts:
self.logger.warning(
f"{self.__class__.__name__} is configured to use interleaved rollouts. All model responses after the first turn will be pre-tokenized before being sent to the model. Currently, this is a hand-crafted feature for PRIME-RL's vLLM server extension."
)

make_dataset = make_dataset


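The `handle_overlong_prompt` decorator in this file classifies API failures by matching substrings in the raw error body. That matching logic can be isolated into a small standalone sketch; the phrase list is copied from the diff, while the exception classes here are simplified stand-ins for `vf.OverlongPromptError` and `vf.ModelError`:

```python
class OverlongPromptError(Exception): ...
class ModelError(Exception): ...

# phrases that inference servers use to report a context-length overflow
CONTEXT_LENGTH_PHRASES = [
    "this model's maximum context length is",
    "is longer than the model's context length",
    "exceeds the model's context length",
    "exceed the configured limit",
    "exceeds the configured limit",
    "exceeded model",
]


def classify_api_error(error_text: str) -> type[Exception]:
    """Map a raw error body to the exception type the rollout loop should raise."""
    text = error_text.lower()
    if any(phrase in text for phrase in CONTEXT_LENGTH_PHRASES):
        return OverlongPromptError
    return ModelError
```

Keeping the classification separate from the decorator makes it easy to extend the phrase list as new server backends surface different error wordings.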
8 changes: 8 additions & 0 deletions verifiers/scripts/eval.py
@@ -218,6 +218,13 @@ def main():
default="",
help="Name of dataset to save to Hugging Face Hub",
)
parser.add_argument(
"--extra-env-kwargs",
"-x",
type=json.loads,
default={},
help='Extra environment kwargs as JSON object (e.g., \'{"key": "value", "num": 42}\'). Passed to the environment constructor.',
)
args = parser.parse_args()

setup_logging("DEBUG" if args.verbose else os.getenv("VF_LOG_LEVEL", "INFO"))
@@ -296,6 +303,7 @@ def main():
env_id=args.env_id,
env_args=args.env_args,
env_dir_path=args.env_dir_path,
extra_env_kwargs=args.extra_env_kwargs,
# evaluation
model=args.model,
client_config=client_config,
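The new `--extra-env-kwargs`/`-x` flag uses `json.loads` directly as the argparse `type`, so the JSON string on the command line arrives in `main()` as a plain dict. A small sketch of that behavior (the key names and values below are illustrative):

```python
import argparse
import json

parser = argparse.ArgumentParser()
# mirrors the flag added in eval.py: the value is parsed as JSON
parser.add_argument("--extra-env-kwargs", "-x", type=json.loads, default={})

# the JSON object string becomes a dict with Python types (true -> True)
args = parser.parse_args(["-x", '{"interleaved_rollouts": true, "max_seq_len": 4096}'])
```

One consequence of this design: malformed JSON fails at argument-parsing time with a clear usage error, rather than deep inside environment setup.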
5 changes: 3 additions & 2 deletions verifiers/types.py
@@ -94,8 +94,8 @@ class State(dict):
INPUT_FIELDS = ["prompt", "answer", "task", "info", "example_id"]
# rollout inputs
input: RolloutInput
client: AsyncOpenAI | None
model: str | None
client: AsyncOpenAI
model: str
sampling_args: SamplingArgs | None
# created during rollout
is_completed: bool
@@ -228,6 +228,7 @@ class EvalConfig(BaseModel):
max_concurrent: int
max_concurrent_generation: int | None = None
max_concurrent_scoring: int | None = None
extra_env_kwargs: dict = {}
# logging
print_results: bool = False
verbose: bool = False
5 changes: 5 additions & 0 deletions verifiers/utils/eval_utils.py
@@ -110,6 +110,11 @@ async def run_evaluation(config: EvalConfig) -> GenerateOutputs:
# load environment
vf_env = vf.load_environment(env_id=config.env_id, **config.env_args)

# set extra environment kwargs
if config.extra_env_kwargs:
logger.info(f"Setting extra environment kwargs: {config.extra_env_kwargs}")
vf_env.set_kwargs(**config.extra_env_kwargs)

# run evaluation
results_path = get_eval_results_path(config)
logger.info(f"Starting evaluation with model: {config.model}")