Merged
tests/test_rlm_env.py: 169 changes (111 additions, 58 deletions)
@@ -2,6 +2,7 @@

import json
import math
import signal
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
@@ -37,6 +38,7 @@ def mock_sandbox_client():
client.execute_command = AsyncMock(return_value=MagicMock(stdout="", stderr=""))
client.upload_file = AsyncMock()
client.upload_bytes = AsyncMock()
client.teardown = MagicMock()
return client


@@ -58,6 +60,8 @@ def rlm_env(mock_sandbox_client, mock_dataset):
with (
patch("verifiers.envs.sandbox_env.AsyncSandboxClient") as mock_client_cls,
patch("verifiers.envs.sandbox_env.CreateSandboxRequest"),
# Avoid registering SIGTERM handlers that call exit(143) during tests.
patch("verifiers.envs.environment.signal.signal"),
):
mock_client_cls.return_value = mock_sandbox_client
env = RLMEnv(
@@ -86,6 +90,8 @@ def another_tool(text: str) -> str:
with (
patch("verifiers.envs.sandbox_env.AsyncSandboxClient") as mock_client_cls,
patch("verifiers.envs.sandbox_env.CreateSandboxRequest"),
# Avoid registering SIGTERM handlers that call exit(143) during tests.
patch("verifiers.envs.environment.signal.signal"),
):
mock_client_cls.return_value = mock_sandbox_client
env = RLMEnv(
@@ -105,6 +111,8 @@ def rlm_env_local(mock_sandbox_client, mock_dataset):
with (
patch("verifiers.envs.sandbox_env.AsyncSandboxClient") as mock_client_cls,
patch("verifiers.envs.sandbox_env.CreateSandboxRequest"),
# Avoid registering SIGTERM handlers that call exit(143) during tests.
patch("verifiers.envs.environment.signal.signal"),
):
mock_client_cls.return_value = mock_sandbox_client
env = RLMEnv(
@@ -707,13 +715,61 @@ async def test_local_worker_exports_stagger_env_vars(rlm_env_local, tmp_path):
patch.object(executor, "_wait_for_ready", new=AsyncMock()),
patch("verifiers.envs.experimental.rlm_env.subprocess.Popen") as mock_popen,
):
mock_popen.return_value = MagicMock()
mock_popen.return_value = MagicMock(wait=MagicMock(), pid=1234)
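# wait and pid must exist on the mock: _start_worker presumably records
# the pid for later process-group signalling (see the killpg test below).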
await executor._start_worker(state, session)

_, kwargs = mock_popen.call_args
env = kwargs["env"]
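# The stagger settings are exported to the worker as env vars; by name,
# a base delay of 18 ms with up to 9 ms of jitter between sub-LLM calls.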
assert env["RLM_SUB_LLM_STAGGER_MS"] == "18"
assert env["RLM_SUB_LLM_STAGGER_JITTER_MS"] == "9"
session.worker_process = None
executor._sessions.pop(session.rollout_id, None)
session.temp_dir.cleanup()


@pytest.mark.asyncio
async def test_local_worker_starts_new_session(rlm_env_local, tmp_path):
executor = rlm_env_local._executor
state = {
"rollout_id": "rlm_test_start_session",
"interception_url": "http://test",
"model": "test-model",
}
session = executor._get_or_create_session(state)
session.venv_path = str(tmp_path / "venv")

with (
patch.object(executor, "_venv_python", return_value="python"),
patch.object(executor, "_wait_for_ready", new=AsyncMock()),
patch("verifiers.envs.experimental.rlm_env.subprocess.Popen") as mock_popen,
):
mock_popen.return_value = MagicMock()
await executor._start_worker(state, session)

_, kwargs = mock_popen.call_args
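# The worker must run in its own process group (pgid == pid) so that
# _stop_worker can signal the group as a whole.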
assert kwargs["start_new_session"] is True
session.worker_process = None
executor._sessions.pop(session.rollout_id, None)
session.temp_dir.cleanup()


def test_local_worker_stop_kills_process_group(rlm_env_local):
executor = rlm_env_local._executor
state = {"rollout_id": "rlm_test_stop_session"}
session = executor._get_or_create_session(state)
process = MagicMock()
process.pid = 4242
process.wait = MagicMock()
session.worker_process = process

with patch("verifiers.envs.experimental.rlm_env.os.killpg") as mock_killpg:
executor._stop_worker(session)

mock_killpg.assert_called_once_with(4242, signal.SIGTERM)
process.wait.assert_called_once()
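# killpg receives the worker pid as the process-group id; the two are
# equal only because the worker is launched with start_new_session=True.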
session.worker_process = None
executor._sessions.pop(session.rollout_id, None)
session.temp_dir.cleanup()
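
A minimal sketch of the stop path the test above pins down, reconstructed from
its assertions (names and structure are assumptions, not the verbatim source):

import os
import signal

def _stop_worker(self, session):
    process = session.worker_process
    if process is None:
        return
    # start_new_session=True at launch means pgid == pid, so the whole
    # worker process group can be signalled at once.
    os.killpg(process.pid, signal.SIGTERM)
    process.wait()
    session.worker_process = None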


# =============================================================================
@@ -1394,8 +1450,9 @@ async def test_completes_without_tool_calls(self, rlm_env_with_sub_tools):
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)

messages = [{"role": "user", "content": "Test"}]
state = {}
result = await rlm_env_with_sub_tools._run_sub_llm(
mock_client, "gpt-4", messages
state, mock_client, "gpt-4", messages
)

assert result["final_content"] == "Final answer"
@@ -1453,7 +1510,8 @@ async def test_executes_tool_calls(self, rlm_env_with_sub_tools):
)

messages = [{"role": "user", "content": "Add 2 and 3"}]
await rlm_env_with_sub_tools._run_sub_llm(mock_client, "gpt-4", messages)
state = {}
await rlm_env_with_sub_tools._run_sub_llm(state, mock_client, "gpt-4", messages)

assert mock_client.chat.completions.create.call_count == 2

@@ -1509,7 +1567,8 @@ async def test_respects_max_turns_limit(self, rlm_env_with_sub_tools):
mock_client.chat.completions.create = AsyncMock(side_effect=responses)

messages = [{"role": "user", "content": "Test"}]
await rlm_env_with_sub_tools._run_sub_llm(mock_client, "gpt-4", messages)
state = {}
await rlm_env_with_sub_tools._run_sub_llm(state, mock_client, "gpt-4", messages)

# Should be max_turns + 1 (final call without tools)
assert (
@@ -1519,76 +1578,67 @@


# =============================================================================
# 7. Sub-LLM Logprobs Handling
# 7. Sub-LLM Request Paths
# =============================================================================


class TestSubLLMLogprobs:
"""Tests for lazy logprobs detection in sub-LLM calls."""
class TestSubLLMRequestPaths:
"""Tests for sub-LLM request routing."""

@pytest.mark.asyncio
async def test_lazy_logprobs_fallback_on_param_error(self, rlm_env):
"""Retries without logprobs and marks support False on param error."""
async def test_interleaved_uses_tokens_endpoint(self, rlm_env):
"""Uses /chat/completions/tokens when interleaved_rollouts is True."""
mock_client = MagicMock()
mock_response = MagicMock()
mock_client.chat.completions.create = AsyncMock(
side_effect=[
Exception("Invalid request: logprobs not supported for this model"),
mock_response,
]
)
mock_client.post = AsyncMock(return_value=mock_response)
mock_client.chat.completions.create = AsyncMock()

rlm_env._sub_llm_supports_logprobs = None
rlm_env.interleaved_rollouts = True
messages = [{"role": "user", "content": "hi"}]
result = await rlm_env._call_sub_llm_api(mock_client, "gpt-4", messages)

assert result is mock_response
assert rlm_env._sub_llm_supports_logprobs is False
calls = mock_client.chat.completions.create.call_args_list
assert calls[0].kwargs["logprobs"] is True
assert calls[1].kwargs["logprobs"] is None
state = {"sampling_args": {"max_tokens": 7, "extra_body": {"foo": "bar"}}}

with patch(
"verifiers.envs.experimental.rlm_env.tokenize_vllm",
new=AsyncMock(return_value=[1, 2, 3]),
) as mock_tokenize:
await rlm_env._call_sub_llm_api(state, mock_client, "gpt-4", messages)

mock_tokenize.assert_awaited_once_with(
client=mock_client,
messages=messages,
tools=None,
model="gpt-4",
)
mock_client.post.assert_awaited_once()
args, kwargs = mock_client.post.call_args
assert args[0] == "/chat/completions/tokens"
body = kwargs["body"]
assert body["tokens"] == [1, 2, 3]
assert body["max_completion_tokens"] == 7
assert body["return_token_ids"] is True
assert body["foo"] == "bar"
assert "max_tokens" not in body
mock_client.chat.completions.create.assert_not_called()

@pytest.mark.asyncio
async def test_lazy_logprobs_success_sets_true(self, rlm_env):
"""Sets support True when the first logprobs call succeeds."""
async def test_non_interleaved_uses_chat_completions(self, rlm_env):
"""Uses chat.completions.create when interleaved_rollouts is False."""
mock_client = MagicMock()
mock_response = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
mock_client.post = AsyncMock()

rlm_env._sub_llm_supports_logprobs = None
rlm_env.interleaved_rollouts = False
messages = [{"role": "user", "content": "hi"}]
result = await rlm_env._call_sub_llm_api(mock_client, "gpt-4", messages)

assert result is mock_response
assert rlm_env._sub_llm_supports_logprobs is True
call_kwargs = mock_client.chat.completions.create.call_args.kwargs
assert call_kwargs["logprobs"] is True

@pytest.mark.asyncio
async def test_lazy_logprobs_fallback_if_flag_flips(self, rlm_env):
"""Retries without logprobs even if another call flips the flag."""
mock_client = MagicMock()
mock_response = MagicMock()
call_count = 0
state = {"sampling_args": {"max_tokens": 7}}
with patch(
"verifiers.envs.experimental.rlm_env.tokenize_vllm", new=AsyncMock()
) as mock_tokenize:
await rlm_env._call_sub_llm_api(state, mock_client, "gpt-4", messages)

async def side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
rlm_env._sub_llm_supports_logprobs = False
raise Exception("logprobs not supported")
return mock_response

mock_client.chat.completions.create = AsyncMock(side_effect=side_effect)

rlm_env._sub_llm_supports_logprobs = None
messages = [{"role": "user", "content": "hi"}]
result = await rlm_env._call_sub_llm_api(mock_client, "gpt-4", messages)

assert result is mock_response
calls = mock_client.chat.completions.create.call_args_list
assert calls[0].kwargs["logprobs"] is True
assert calls[1].kwargs["logprobs"] is None
mock_client.chat.completions.create.assert_awaited_once()
mock_client.post.assert_not_called()
mock_tokenize.assert_not_called()
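
A minimal sketch of the routing these two tests pin down, reconstructed from
the mocks and assertions above (signatures and the tokenize_vllm helper follow
the patch targets in the tests; the real rlm_env.py may differ):

async def _call_sub_llm_api(self, state, client, model, messages, tools=None):
    sampling = dict(state.get("sampling_args") or {})
    if self.interleaved_rollouts:
        # Token-level path: pre-tokenize the chat, then POST to the
        # /chat/completions/tokens endpoint directly.
        tokens = await tokenize_vllm(
            client=client, messages=messages, tools=tools, model=model
        )
        body = dict(sampling.pop("extra_body", None) or {})
        body["tokens"] = tokens
        body["return_token_ids"] = True
        if "max_tokens" in sampling:
            # This endpoint takes max_completion_tokens, not max_tokens.
            body["max_completion_tokens"] = sampling.pop("max_tokens")
        return await client.post("/chat/completions/tokens", body=body)
    # Plain path: an ordinary chat.completions.create call.
    return await client.chat.completions.create(
        model=model, messages=messages, **sampling
    )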


# =============================================================================
@@ -1643,6 +1693,7 @@ async def test_routes_to_correct_model(self, rlm_env):
"client": mock_client,
"model": "test-model",
"sub_model": "gpt-4",
"state": {},
}

mock_request = MagicMock()
@@ -1681,6 +1732,7 @@ async def test_uses_tool_loop_when_configured(self, rlm_env_with_sub_tools):
"client": mock_client,
"model": "test-model",
"sub_model": "gpt-4",
"state": {},
}

mock_request = MagicMock()
@@ -1818,8 +1870,9 @@ async def test_accumulates_tokens_across_tool_turns(self, rlm_env_with_sub_tools
)

messages = [{"role": "user", "content": "Add 2 and 3"}]
state = {}
result = await rlm_env_with_sub_tools._run_sub_llm(
mock_client, "gpt-4", messages
state, mock_client, "gpt-4", messages
)

# Should accumulate tokens from both calls
Expand Down