
Commit 6cbcd94

Update OpenEnv example scripts (#4547)
1 parent 8510589 commit 6cbcd94

4 files changed: +101, -54 lines

docs/source/openenv.md

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ args = GRPOConfig(
 args = GRPOConfig(
     use_vllm=True,
     vllm_mode="server",
-    vllm_server_url="http://localhost:8000",
+    vllm_server_base_url="http://localhost:8000",
     # ... other args
 )
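
This is the documentation side of the rename from `vllm_server_url` to `vllm_server_base_url`. For quick reference, the snippet as it reads after the change is sketched below; the config fields come straight from the hunk, while the import line and the note that the URL must point at an already-running vLLM server are our additions, not quoted documentation.

```python
from trl import GRPOConfig

# Sketch of the documented server-mode setup after this commit. The URL must
# point at a vLLM server that is already running; localhost:8000 is the
# placeholder the docs use.
args = GRPOConfig(
    use_vllm=True,
    vllm_mode="server",
    vllm_server_base_url="http://localhost:8000",
    # ... other args
)
```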

examples/scripts/openenv/catch.py

Lines changed: 22 additions & 10 deletions
@@ -24,7 +24,7 @@

 Usage:

-# Start the docker container for the Catch environment (recommended). Alternatively, you can run it locally or directly from a HF Space.
+# Start the environment only if using --env-mode docker-local; In other modes, the env is automatically managed by the script.
 ```sh
 docker run -d -p 8001:8001 registry.hf.space/openenv-openspiel-env:latest
 ```
@@ -73,9 +73,9 @@ def parse_args():
     parser.add_argument("--env-port", type=int, default=8001, help="Port for the environment server.")
     parser.add_argument(
         "--env-mode",
-        choices=["local", "docker", "space"],
-        default="docker",
-        help="Where to run the environment: 'local', 'docker', or 'space'.",
+        choices=["local", "docker-local", "docker-image", "docker-hub", "space"],
+        default="docker-image",
+        help="Where to run the environment: 'local' to launch it, 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL.",
     )
     # --- Generation and model config ---
     parser.add_argument(
@@ -90,6 +90,9 @@ def parse_args():
         default=1000,
         help="Number of prompts to use for training dataset.",
     )
+    parser.add_argument(
+        "--env-image", type=str, default="openspiel-env:latest", help="Docker image for the OpenSpiel environment."
+    )
     parser.add_argument(
         "--vllm-mode",
         choices=["colocate", "server"],
@@ -183,25 +186,34 @@ def main():
     if args.env_mode == "local":
         env_url = f"http://{args.env_host}:{args.env_port}"
         server_process = start_env_server(args.env_host, args.env_port)
-    elif args.env_mode == "docker":
+    elif args.env_mode == "docker-local":
         env_url = f"http://{args.env_host}:{args.env_port}"
         server_process = None
-        print(f"🌍 Using existing Docker environment at {env_url}")
+        print(f"🌍 Using existing OpenSpiel Environment (Docker) at: {env_url}")
+    elif args.env_mode == "docker-image":
+        client = OpenSpielEnv.from_docker_image(args.env_image)
+        server_process = None
+        print("🌍 Using OpenSpiel Environment (Docker) from local Image")
+    elif args.env_mode == "docker-hub":
+        client = OpenSpielEnv.from_hub(args.env_image)
+        server_process = None
+        print("🌍 Using existing OpenSpiel Environment (Docker) from Hub Image")
     elif args.env_mode == "space":
         env_url = args.env_host
         server_process = None
-        print(f"🚀 Using Hugging Face Space environment at {env_url}")
+        print(f"🌍 Using Hugging Face Space environment at: {env_url}")
     else:
-        raise ValueError(f"Unknown env mode: {args.env_mode}")
+        raise ValueError(f"Unknown environment mode: {args.env_mode}")

-    client = OpenSpielEnv(base_url=env_url)
+    if args.env_mode != "docker-hub" and args.env_mode != "docker-image":
+        client = OpenSpielEnv(base_url=env_url)
     dataset = Dataset.from_dict({"prompt": [BASE_PROMPT] * args.dataset_size})

     training_args = GRPOConfig(
         output_dir=f"{args.model.split('/')[-1]}-GRPO-Catch",
         use_vllm=True,
         vllm_mode=args.vllm_mode,
-        vllm_server_url=args.vllm_server_url if args.vllm_mode == "server" else None,
+        vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None,
         logging_steps=1,
         report_to="trackio",
         num_train_epochs=1,
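
All three example scripts now pick the environment client the same way. Purely as a reading aid, here is a hedged sketch of catch.py's selection logic gathered into one helper; the helper name and factoring are ours, not part of the commit, while `OpenSpielEnv`, `from_docker_image`, `from_hub`, and `start_env_server` are the import and helper the script already uses.

```python
def make_openspiel_client(env_mode: str, env_host: str, env_port: int, env_image: str):
    """Return (client, server_process) the way the updated catch.py does."""
    server_process = None
    if env_mode == "local":
        # The script launches the environment server itself, then connects by URL.
        server_process = start_env_server(env_host, env_port)
        client = OpenSpielEnv(base_url=f"http://{env_host}:{env_port}")
    elif env_mode == "docker-local":
        # A container is already listening locally; just connect to it.
        client = OpenSpielEnv(base_url=f"http://{env_host}:{env_port}")
    elif env_mode == "docker-image":
        # Start a container from a local image (default: openspiel-env:latest).
        client = OpenSpielEnv.from_docker_image(env_image)
    elif env_mode == "docker-hub":
        # Pull and start the image from Docker Hub.
        client = OpenSpielEnv.from_hub(env_image)
    elif env_mode == "space":
        # In this mode, --env-host carries the full Space URL.
        client = OpenSpielEnv(base_url=env_host)
    else:
        raise ValueError(f"Unknown environment mode: {env_mode}")
    return client, server_process
```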

examples/scripts/openenv/echo.py

Lines changed: 20 additions & 8 deletions
@@ -24,7 +24,7 @@

 Usage:

-# Start the docker container for the Echo environment (recommended). Alternatively, you can run it locally or directly from a HF Space.
+# Start the environment only if using --env-mode docker-local; In other modes, the env is automatically managed by the script.
 ```sh
 docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest
 ```
@@ -71,9 +71,9 @@ def parse_args():
     parser.add_argument("--env-port", type=int, default=8001, help="Port for the Echo environment.")
     parser.add_argument(
         "--env-mode",
-        choices=["local", "docker", "space"],
-        default="docker",
-        help="Where to run the Echo environment: 'local' to launch it, 'docker' if already running, or 'space' to use a remote Space URL.",
+        choices=["local", "docker-local", "docker-image", "docker-hub", "space"],
+        default="docker-image",
+        help="Where to run the Echo environment: 'local' to launch it, 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL.",
     )
     parser.add_argument(
         "--model",
@@ -87,6 +87,9 @@ def parse_args():
         default="trl-lib/ultrafeedback-prompt",
         help="Dataset to use for training.",
     )
+    parser.add_argument(
+        "--env-image", type=str, default="echo-env:latest", help="Docker image for the Echo environment."
+    )
     parser.add_argument(
         "--vllm-mode",
         choices=["colocate", "server"],
@@ -146,25 +149,34 @@ def main():
     if args.env_mode == "local":
         env_url = f"http://{args.env_host}:{args.env_port}"
         server_process = start_env_server(args.env_host, args.env_port)
-    elif args.env_mode == "docker":
+    elif args.env_mode == "docker-local":
         env_url = f"http://{args.env_host}:{args.env_port}"
         server_process = None
         print(f"🌍 Using existing Echo Environment (Docker) at: {env_url}")
+    elif args.env_mode == "docker-image":
+        client = EchoEnv.from_docker_image(args.env_image)
+        server_process = None
+        print("🌍 Using Echo Environment (Docker) from local Image")
+    elif args.env_mode == "docker-hub":
+        client = EchoEnv.from_hub(args.env_image)
+        server_process = None
+        print("🌍 Using existing Echo Environment (Docker) from Hub Image")
     elif args.env_mode == "space":
         env_url = args.env_host
         server_process = None
-        print(f"🚀 Using Hugging Face Space environment at: {env_url}")
+        print(f"🌍 Using Hugging Face Space environment at: {env_url}")
     else:
         raise ValueError(f"Unknown environment mode: {args.env_mode}")

-    client = EchoEnv(base_url=env_url)
+    if args.env_mode != "docker-hub" and args.env_mode != "docker-image":
+        client = EchoEnv(base_url=env_url)
     dataset = load_dataset(args.dataset, split="train[:1000]")

     training_args = GRPOConfig(
         output_dir=f"{args.model.split('/')[-1]}-GRPO-Rollout",
         use_vllm=True,
         vllm_mode=args.vllm_mode,
-        vllm_server_url=args.vllm_server_url if args.vllm_mode == "server" else None,
+        vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None,
         logging_steps=1,
         report_to="trackio",
         num_train_epochs=1,
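
The new guard before `EchoEnv(base_url=env_url)` matters: in the 'docker-image' and 'docker-hub' branches the client is already created via `from_docker_image()` / `from_hub()` and `env_url` is never assigned, so an unconditional constructor call would raise a NameError. The snippet below is a sketch of an equivalent, slightly more compact form of the committed guard, not the committed code itself.

```python
# URL-based modes only: local, docker-local, space.
if args.env_mode not in ("docker-hub", "docker-image"):
    client = EchoEnv(base_url=env_url)
```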

examples/scripts/openenv/wordle.py

Lines changed: 58 additions & 35 deletions
@@ -23,7 +23,7 @@

 Usage:

-# Start the docker container for the Wordle environment (recommended). Alternatively, you can run it locally or directly from a HF Space.
+# Start the environment only if using --env-mode docker-local; In other modes, the env is automatically managed by the script.
 ```sh
 docker run -d -p 8001:8001 registry.hf.space/burtenshaw-textarena:latest
 # or TEXTARENA_ENV_ID=Wordle-v0 TEXTARENA_NUM_PLAYERS=1 python -m src.envs.textarena_env.server.app
@@ -85,11 +85,15 @@ def parse_args() -> argparse.Namespace:
         default="Qwen/Qwen3-1.7B",
         help="Model identifier passed to GRPOTrainer for fine-tuning.",
     )
+    parser.add_argument("--env-host", type=str, default="0.0.0.0", help="Host for the environment server.")
+    parser.add_argument("--env-port", type=int, default=8001, help="Port for the environment server.")
     parser.add_argument(
-        "--env-url",
-        default="https://burtenshaw-textarena.hf.space",
-        help="URL for the TextArena Wordle environment.",
+        "--env-mode",
+        choices=["docker-local", "docker-image", "docker-hub", "space"],
+        default="docker-image",
+        help="Where to run the environment: 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL.",
     )
+    parser.add_argument("--env-image", type=str, default="textarena-env:latest", help="Docker image for the TextArena environment.")
     parser.add_argument(
         "--system-prompt-path",
         default="wordle_prompt.txt",
@@ -411,46 +415,65 @@ def reward_repetition(completions: list[str], **kwargs) -> list[float]:


 def main() -> None:
-    cli_args = parse_args()
+    args = parse_args()

-    tokenizer = AutoTokenizer.from_pretrained(cli_args.tokenizer_id)
+    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
     tokenizer.pad_token = tokenizer.eos_token

-    env = TextArenaEnv(base_url=cli_args.env_url)
+    # Select environment mode
+    if args.env_mode == "docker-local":
+        env_url = f"http://{args.env_host}:{args.env_port}"
+        print(f"🌍 Using existing TextArena Environment (Docker) at: {env_url}")
+    elif args.env_mode == "docker-image":
+        client = TextArenaEnv.from_docker_image(args.env_image)
+        print(f"🌍 Using TextArena Environment (Docker) from local Image")
+    elif args.env_mode == "docker-hub":
+        client = TextArenaEnv.from_hub(args.env_image)
+        print(f"🌍 Using existing TextArena Environment (Docker) from Hub Image")
+    elif args.env_mode == "space":
+        env_url = args.env_host
+        print(f"🌍 Using Hugging Face Space environment at: {env_url}")
+    else:
+        raise ValueError(f"Unknown environment mode: {args.env_mode}")

-    system_prompt = resolve_system_prompt(cli_args.system_prompt_path)
+    if args.env_mode != "docker-hub" and args.env_mode != "docker-image":
+        client = TextArenaEnv(base_url=env_url)

-    dataset = Dataset.from_dict({"prompt": [cli_args.dataset_prompt] * cli_args.dataset_size})
+    #env = TextArenaEnv(base_url=args.env_url)
+
+    system_prompt = resolve_system_prompt(args.system_prompt_path)
+
+    dataset = Dataset.from_dict({"prompt": [args.dataset_prompt] * args.dataset_size})

     timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-    default_output_dir = Path("outputs") / f"wordle-grpo-{sanitize_name(cli_args.model_id)}-{timestamp}"
-    output_dir = Path(cli_args.output_dir or default_output_dir)
+    default_output_dir = Path("outputs") / f"wordle-grpo-{sanitize_name(args.model_id)}-{timestamp}"
+    output_dir = Path(args.output_dir or default_output_dir)

     grpo_config = GRPOConfig(
         use_vllm=True,
-        vllm_mode=cli_args.vllm_mode,
-        vllm_server_url=cli_args.vllm_server_url if cli_args.vllm_mode == "server" else None,
+        vllm_mode=args.vllm_mode,
+        vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None,
         output_dir=str(output_dir),
-        num_train_epochs=cli_args.num_epochs,
-        learning_rate=cli_args.learning_rate,
-        weight_decay=cli_args.weight_decay,
-        gradient_accumulation_steps=cli_args.gradient_accumulation_steps,
-        per_device_train_batch_size=cli_args.per_device_batch_size,
-        warmup_steps=cli_args.warmup_steps,
-        num_generations=cli_args.num_generations,
-        max_completion_length=cli_args.max_new_tokens,
-        logging_steps=cli_args.logging_steps,
+        num_train_epochs=args.num_epochs,
+        learning_rate=args.learning_rate,
+        weight_decay=args.weight_decay,
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        per_device_train_batch_size=args.per_device_batch_size,
+        warmup_steps=args.warmup_steps,
+        num_generations=args.num_generations,
+        max_completion_length=args.max_new_tokens,
+        logging_steps=args.logging_steps,
         save_strategy="steps",
-        save_steps=cli_args.save_interval,
-        save_total_limit=cli_args.save_total_limit,
-        temperature=cli_args.temperature,
-        top_k=cli_args.top_k,
-        top_p=cli_args.top_p,
+        save_steps=args.save_interval,
+        save_total_limit=args.save_total_limit,
+        temperature=args.temperature,
+        top_k=args.top_k,
+        top_p=args.top_p,
     )

-    grpo_config.run_name = cli_args.run_name or f"run-{timestamp}"
-    grpo_config.project = cli_args.project or f"group-{sanitize_name(cli_args.model_id)}"
-    grpo_config.trackio_space_id = cli_args.trackio_space_id
+    grpo_config.run_name = args.run_name or f"run-{timestamp}"
+    grpo_config.project = args.project or f"group-{sanitize_name(args.model_id)}"
+    grpo_config.trackio_space_id = args.trackio_space_id

     def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
         episode_prompt_ids: list[list[int]] = []
@@ -464,11 +487,11 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
         for prompt_text in prompts:
             episode = rollout_once(
                 trainer=trainer,
-                env=env,
+                env=client,
                 tokenizer=tokenizer,
                 dataset_prompt=prompt_text,
                 system_prompt=system_prompt,
-                max_turns=cli_args.max_turns,
+                max_turns=args.max_turns,
             )
             episode_prompt_ids.append(episode["prompt_ids"])
             episode_completion_ids.append(episode["completion_ids"])
@@ -489,7 +512,7 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
         }

     trainer = GRPOTrainer(
-        model=cli_args.model_id,
+        model=args.model_id,
         processing_class=tokenizer,
         reward_funcs=[
             reward_correct,
@@ -503,12 +526,12 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
     )

     print("Starting GRPO training with Wordle environment...")
-    print(f"Using {cli_args.num_generations} rollouts per dataset prompt")
+    print(f"Using {args.num_generations} rollouts per dataset prompt")

     try:
         trainer.train()
     finally:
-        env.close()
+        client.close()


 if __name__ == "__main__":
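
With the Docker-backed modes the environment client owns external resources, which is why the script keeps `client.close()` inside a `finally` block. Below is a minimal sketch of that lifecycle under stated assumptions: the import path is a guess (use whatever import the script itself has), the image tag is the `--env-image` default, and whether `close()` also stops the container is our assumption rather than something the diff states.

```python
from envs.textarena_env import TextArenaEnv  # import path assumed; use the script's own import

client = TextArenaEnv.from_docker_image("textarena-env:latest")  # --env-image default
try:
    # In the real script, GRPOTrainer.train() runs here and rollout_func()
    # drives the Wordle episodes through `client`.
    pass
finally:
    client.close()  # released even if training raises
```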
