Skip to content
43 changes: 42 additions & 1 deletion verifiers/envs/env_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from openai import AsyncOpenAI

import verifiers as vf
from verifiers.types import RolloutInput, SamplingArgs
from verifiers.types import ClientConfig, RolloutInput, SamplingArgs
from verifiers.workers.client.env_client import EnvClient

if TYPE_CHECKING:
from datasets import Dataset
Expand Down Expand Up @@ -263,6 +264,46 @@ def add_example_id(example, i):
)
return dataset

@final
async def run_rollout( # type: ignore[override]
self,
input: RolloutInput,
client: AsyncOpenAI | ClientConfig,
model: str,
sampling_args: SamplingArgs,
max_retries: int = 0,
state_columns: list[str] | None = None,
env_client: EnvClient | None = None,
) -> vf.RolloutOutput:
env = self.get_env_for_task(input["task"])
env_client = env_client or env.env_client or self.env_client
return await env.run_rollout(
input, client, model, sampling_args, max_retries, state_columns, env_client
)

@final
async def run_group( # type: ignore[override]
self,
group_inputs: list[RolloutInput],
client: AsyncOpenAI | ClientConfig,
model: str,
sampling_args: SamplingArgs,
max_retries: int = 0,
state_columns: list[str] | None = None,
env_client: EnvClient | None = None,
) -> list[vf.RolloutOutput]:
env = self.get_env_for_task(group_inputs[0]["task"])
env_client = env_client or env.env_client or self.env_client
return await env.run_group(
group_inputs,
client,
model,
sampling_args,
max_retries,
state_columns,
env_client,
)

@final
async def rollout(
self,
Expand Down
14 changes: 10 additions & 4 deletions verifiers/envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,19 +845,21 @@ async def run_rollout(
sampling_args: SamplingArgs,
max_retries: int = 0,
state_columns: list[str] | None = None,
env_client: EnvClient | None = None,
) -> RolloutOutput:
"""Generate and, optionally, score a rollout."""

resolved_client_config: ClientConfig | None = None
if isinstance(client, ClientConfig):
resolved_client_config = resolve_client_config(client)

if self.env_client is not None: # in server mode
env_client = env_client or self.env_client
if env_client is not None: # in server mode
if resolved_client_config is None:
raise ValueError(
f"client must be have type ClientConfig in server mode, got {type(client)}"
)
return await self.env_client.run_rollout(
return await env_client.run_rollout(
input,
resolved_client_config,
model,
Expand Down Expand Up @@ -901,6 +903,7 @@ async def run_group(
sampling_args: SamplingArgs,
max_retries: int = 0,
state_columns: list[str] | None = None,
env_client: EnvClient | None = None,
**kwargs,
) -> list[RolloutOutput]:
"""Generate and, optionally, score one group."""
Expand All @@ -909,12 +912,13 @@ async def run_group(
if isinstance(client, ClientConfig):
resolved_client_config = resolve_client_config(client)

if self.env_client is not None: # in server mode
env_client = env_client or self.env_client
if env_client is not None: # in server mode
if resolved_client_config is None:
raise ValueError(
f"client must be have type ClientConfig in server mode, got {type(client)}"
)
return await self.env_client.run_group(
return await env_client.run_group(
group_inputs,
resolved_client_config,
model,
Expand Down Expand Up @@ -1400,6 +1404,7 @@ async def start_server(
extra_env_kwargs: dict[str, Any] = {},
log_level: str | None = None,
log_file: str | None = None,
log_file_level: str | None = None,
startup_timeout: float = 3600, # 1h
) -> None:
"""Start a ZMQ server process for this environment.
Expand All @@ -1417,6 +1422,7 @@ async def start_server(
extra_env_kwargs,
log_level,
log_file,
log_file_level,
),
kwargs=dict(address=address),
daemon=True, # ensure server process is terminated when parent exits
Expand Down
2 changes: 2 additions & 0 deletions verifiers/envs/integrations/openenv_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,15 @@ async def start_server(
extra_env_kwargs: dict[str, Any] | None = None,
log_level: str | None = None,
log_file: str | None = None,
log_file_level: str | None = None,
startup_timeout: float = 120.0,
) -> None:
await super().start_server(
address=address,
extra_env_kwargs=extra_env_kwargs or {},
log_level=log_level,
log_file=log_file,
log_file_level=log_file_level,
startup_timeout=startup_timeout,
)

Expand Down
33 changes: 4 additions & 29 deletions verifiers/workers/client/zmq_env_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,7 @@ def _fail_all_pending(self, reason: str):
"""Fail all pending futures with the given reason."""
pending_count = len(self.pending)
if pending_count:
self.logger.warning(
"Failing %d pending request(s): %s",
pending_count,
reason,
)
self.logger.warning(f"Failing {pending_count} pending request(s): {reason}")
for _, future in list(self.pending.items()):
if not future.done():
future.set_exception(RuntimeError(reason))
Expand Down Expand Up @@ -100,11 +96,6 @@ async def _receive_loop(self):
try:
response = msgpack.unpackb(response_data, raw=False)
future.set_result(response)
self.logger.debug(
"Resolved request_id=%s (pending=%d)",
request_id,
len(self.pending),
)
except Exception as unpack_error:
# Unpacking failed - fail the specific future
self.logger.error(
Expand All @@ -117,9 +108,7 @@ async def _receive_loop(self):
)
else:
self.logger.warning(
"Received response for unknown request_id=%s (pending=%d)",
request_id,
len(self.pending),
f"Received response for unknown request_id={request_id} (pending={len(self.pending)})"
)

except asyncio.CancelledError:
Expand All @@ -138,7 +127,6 @@ async def _receive_loop(self):
async def _start(self):
self._receiver_task = asyncio.create_task(self._receive_loop())
self.socket.connect(self.address)
self.logger.debug("ZMQ client started")

async def _send_request(
self,
Expand Down Expand Up @@ -170,26 +158,14 @@ async def _send_request(

future: asyncio.Future[dict] = asyncio.Future()
self.pending[request_id] = future
self.logger.debug(
"Sending %s request_id=%s timeout=%.1fs (pending=%d)",
request.request_type,
request_id,
effective_timeout,
len(self.pending),
)

await self.socket.send_multipart([request_id.encode(), payload_bytes])

try:
raw_response = await asyncio.wait_for(future, timeout=effective_timeout)
except asyncio.TimeoutError:
self.pending.pop(request_id, None)
self.logger.error(
"Timed out waiting for request_id=%s type=%s after %.1fs (pending=%d)",
request_id,
request.request_type,
effective_timeout,
len(self.pending),
f"Timed out waiting for request_id={request_id} type={request.request_type} after {effective_timeout:.1f}s (pending={len(self.pending)})"
)
raise TimeoutError(
f"Environment timeout for {request.request_type} request after {effective_timeout}s"
Expand All @@ -199,7 +175,7 @@ async def _send_request(
response = response_type.model_validate(raw_response)

if not response.success:
raise RuntimeError(f"Server error: {response.error}")
raise RuntimeError(response.error)

return response

Expand All @@ -220,4 +196,3 @@ async def close(self) -> None:
# Close socket and terminate context
self.socket.close()
self.ctx.term()
self.logger.debug("ZMQ client closed")
18 changes: 9 additions & 9 deletions verifiers/workers/server/env_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(
)

self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
self.logger.debug(
self.logger.info(
f"Initializing {self.__class__.__name__} to serve {env_id} ({env_args=}, {extra_env_kwargs=})"
)

Expand All @@ -56,13 +56,13 @@ def __init__(
self.pending_tasks: set[asyncio.Task] = set()

# load environment
with vf.quiet_verifiers():
self.env = vf.load_environment(self.env_id, **self.env_args)
if self.extra_env_kwargs:
self.logger.debug(
f"Setting extra environment kwargs: {self.extra_env_kwargs}"
)
self.env.set_kwargs(**self.extra_env_kwargs)
self.logger.info(f"Loading environment {env_id} with {env_args=}")
self.env = vf.load_environment(self.env_id, **self.env_args)
if self.extra_env_kwargs:
self.logger.info(
f"Setting extra environment kwargs: {self.extra_env_kwargs}"
)
self.env.set_kwargs(**self.extra_env_kwargs)

@abstractmethod
async def run(self, stop_event: asyncio.Event | None = None):
Expand All @@ -81,7 +81,7 @@ async def run_with_graceful_shutdown():
stop_event = asyncio.Event()

def signal_handler(sig):
server.logger.debug(
server.logger.info(
f"Received signal {sig.name}, initiating graceful shutdown"
)
stop_event.set()
Expand Down
20 changes: 8 additions & 12 deletions verifiers/workers/server/zmq_env_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, *args, address: str = "tcp://127.0.0.1:5000", **kwargs):
self.socket.bind(self.address)

async def run(self, stop_event: asyncio.Event | None = None):
self.logger.debug(f"{self.__class__.__name__} started on {self.address}")
self.logger.info(f"{self.__class__.__name__} started on {self.address}")

# Create a task to wait for stop signal
stop_task = asyncio.create_task(stop_event.wait()) if stop_event else None
Expand All @@ -39,7 +39,7 @@ async def run(self, stop_event: asyncio.Event | None = None):
while True:
# exit gracefully on stop signal
if stop_event and stop_event.is_set():
self.logger.debug("Stop event received, shutting down gracefully")
self.logger.info("Stop event received, shutting down gracefully")
break

try:
Expand Down Expand Up @@ -77,7 +77,7 @@ async def run(self, stop_event: asyncio.Event | None = None):
async def close(self):
# cancel and await all pending tasks
if self.pending_tasks:
self.logger.debug(f"Cancelling {len(self.pending_tasks)} pending tasks")
self.logger.info(f"Cancelling {len(self.pending_tasks)} pending tasks")
for task in self.pending_tasks:
task.cancel()
await asyncio.gather(*self.pending_tasks, return_exceptions=True)
Expand All @@ -87,7 +87,7 @@ async def close(self):

self.socket.close()
self.ctx.term()
self.logger.debug("Environment server shut down")
self.logger.info("Environment server shut down")

async def _process_request(
self,
Expand All @@ -103,7 +103,6 @@ async def _process_request(
raw = msgpack.unpackb(payload_bytes, raw=False)
request_type = raw.get("request_type")
request_id = raw.get("request_id", request_id)
self.logger.debug(f"Got {request_type} request (request_id={request_id})")

# validate and route to handler
if request_type == "health":
Expand All @@ -122,14 +121,15 @@ async def _process_request(
)

except asyncio.CancelledError:
self.logger.debug(f"Request {request_id} cancelled during shutdown")
return

except Exception as e:
self.logger.error(f"Error processing request: {e}", exc_info=True)
self.logger.error(
f"Error processing request {request_id}: {e}", exc_info=True
)
response = BaseResponse(
success=False,
error=str(e),
error=repr(e),
)

# serialize response using Pydantic
Expand All @@ -146,7 +146,3 @@ async def _process_request(
await self.socket.send_multipart(
[client_id, request_id.encode(), response_bytes]
)

self.logger.debug(
f"Sent {response.__class__.__name__} (request_id={request_id}, {len(response_bytes)} bytes)"
)
Loading