Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions environments/openenv_echo/openenv_echo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from pathlib import Path
from typing import Any, cast

import verifiers as vf
Expand All @@ -10,9 +9,7 @@ def render_echo_prompt(
*,
action_schema: dict[str, Any] | None = None,
context: str = "reset",
**kwargs: Any,
) -> ChatMessages:
del kwargs
if not isinstance(observation, dict):
raise RuntimeError(
f"openenv-echo prompt renderer expected dict observation, got {type(observation).__name__}."
Expand Down Expand Up @@ -50,7 +47,6 @@ def load_environment(
seed: int = 0,
):
return vf.OpenEnvEnv(
openenv_project=Path(__file__).parent / "proj",
num_train_examples=num_train_examples,
num_eval_examples=num_eval_examples,
seed=seed,
Expand Down
4 changes: 0 additions & 4 deletions environments/openenv_textarena/openenv_textarena.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from pathlib import Path
import re
from typing import Any, cast

Expand Down Expand Up @@ -39,9 +38,7 @@ def render_textarena_prompt(
observation: Any,
*,
context: str = "reset",
**kwargs: Any,
) -> ChatMessages:
del kwargs
if not isinstance(observation, dict):
raise RuntimeError(
f"openenv-textarena prompt renderer expected dict observation, got {type(observation).__name__}."
Expand Down Expand Up @@ -72,7 +69,6 @@ def load_environment(
seed: int = 0,
):
return vf.OpenEnvEnv(
openenv_project=Path(__file__).parent / "proj",
num_train_examples=num_train_examples,
num_eval_examples=num_eval_examples,
seed=seed,
Expand Down
5 changes: 1 addition & 4 deletions verifiers/envs/integrations/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,11 @@ uv run vf-build my-openenv

```python
# environments/my_openenv/my_openenv.py
from pathlib import Path
from typing import Any
import verifiers as vf
from verifiers.envs.integrations.openenv_env import OpenEnvEnv

def render_prompt(observation: Any, **kwargs: Any) -> list[dict[str, str]]:
del kwargs
def render_prompt(observation: Any) -> list[dict[str, str]]:
if not isinstance(observation, dict):
raise RuntimeError("Expected dict observation")
prompt = observation.get("prompt")
Expand All @@ -181,7 +179,6 @@ def load_environment(
seed: int = 0,
) -> vf.Environment:
return OpenEnvEnv(
openenv_project=Path(__file__).parent / "proj",
prompt_renderer=render_prompt,
num_train_examples=num_train_examples,
num_eval_examples=num_eval_examples,
Expand Down
16 changes: 14 additions & 2 deletions verifiers/envs/integrations/openenv_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class OpenEnvEnv(vf.MultiTurnEnv):

def __init__(
self,
openenv_project: str | Path,
openenv_project: str | Path | None = None,
num_train_examples: int = 100,
num_eval_examples: int = 50,
seed: int = 0,
Expand All @@ -93,7 +93,7 @@ def __init__(
jitter: float = 1e-3,
**kwargs: Any,
):
self.openenv_project = str(openenv_project)
self.openenv_project = self._resolve_openenv_project(openenv_project)
self.num_train_examples = num_train_examples
self.num_eval_examples = num_eval_examples
self.seed = seed
Expand Down Expand Up @@ -139,6 +139,18 @@ def __init__(
**kwargs,
)

def _resolve_openenv_project(self, openenv_project: str | Path | None) -> str:
if openenv_project is not None:
return str(openenv_project)

current_file = Path(__file__).resolve()
for frame_info in inspect.stack()[1:]:
frame_path = Path(frame_info.filename).resolve()
if frame_path != current_file:
return str(frame_path.parent / "proj")

return str(Path.cwd() / "proj")

async def start_server(
self,
address: str | None = None,
Expand Down
5 changes: 1 addition & 4 deletions verifiers/scripts/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,12 @@ def load_environment(**kwargs) -> vf.Environment:
'''

OPENENV_ENVIRONMENT_TEMPLATE = """\
from pathlib import Path
from typing import Any

import verifiers as vf


def render_prompt(observation: Any, **kwargs: Any) -> list[dict[str, Any]]:
del kwargs
def render_prompt(observation: Any) -> list[dict[str, Any]]:
if isinstance(observation, dict):
messages = observation.get("messages")
if isinstance(messages, list) and messages:
Expand All @@ -178,7 +176,6 @@ def load_environment(
seed: int = 0,
):
return vf.OpenEnvEnv(
openenv_project=Path(__file__).parent / "proj",
num_train_examples=num_train_examples,
num_eval_examples=num_eval_examples,
seed=seed,
Expand Down
Loading