Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions verifiers/envs/integrations/textarena_env.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
import random
from copy import deepcopy
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(
self.game = game
self.ta_env = ta.make(env_id=game)
self.ta_env.reset(num_players=1)
self.shared_memo = self.build_shared_memo(self.ta_env)
self.num_train_examples = num_train_examples
self.num_eval_examples = num_eval_examples
self.seed = seed
Expand All @@ -76,9 +78,31 @@ def __init__(
**kwargs,
)

@staticmethod
def build_shared_memo(ta_env) -> dict:
"""Build deepcopy memo to share immutable data across env copies.

The textarena EnglishDictionary holds ~430K strings in 4 sets (~38MB).
These are read-only after construction, so sharing them via the memo
dict avoids copying them on every rollout (~120ms and ~38MB saved each).
"""
memo: dict = {}
env = ta_env
while hasattr(env, "env"):
env = env.env
# Share the dictionary object (contains uk_words, us_words, nltk_words sets)
dictionary = getattr(env, "dictionary", None)
if dictionary is not None:
memo[id(dictionary)] = dictionary
# Share the word list (small but also immutable)
word_list = getattr(env, "word_list", None)
if word_list is not None:
memo[id(word_list)] = word_list
return memo

async def setup_state(self, state: vf.State, **kwargs) -> vf.State:
ta_env = deepcopy(self.ta_env)
ta_env.state.game_state["secret_word"] = state["answer"]
ta_env = await asyncio.to_thread(deepcopy, self.ta_env, self.shared_memo.copy())
ta_env.state.game_state["secret_word"] = state["answer"] # type: ignore[unresolved-attribute]
state["ta_env"] = ta_env
return state

Expand All @@ -92,22 +116,20 @@ async def env_response(
ta_env = state["ta_env"]
guess = self.parser.parse_answer(messages)
self.logger.debug(f"Parsed {guess=}")
ta_env.step(str(guess))
await asyncio.to_thread(ta_env.step, str(guess))

if ta_env.state.done:
self.logger.debug(f"Game completed! {ta_env.state.game_info=}")
response = cast(
vf.Messages,
[{"role": "user", "content": ta_env.state.game_info[0]["reason"]}],
)
state["final_env_response"] = response
return response
response = vf.UserMessage(content=ta_env.state.game_info[0]["reason"])
state["final_env_response"] = [response]
return [response]
else:
_, observation = ta_env.get_observation()
_, observation = await asyncio.to_thread(ta_env.get_observation)
self.logger.debug(f"Got {observation=}")
feedback = self.feedback_fn(observation)
self.logger.debug(f"Parsed {feedback=}")
return cast(vf.Messages, [{"role": "user", "content": str(feedback)}])
response = vf.UserMessage(content=str(feedback))
return [response]

def ta_to_hf(self) -> tuple[Dataset, Dataset | None]:
dataset_rows = []
Expand Down
Loading