2323
2424Usage:
2525
26- # Start the docker container for the Wordle environment (recommended). Alternatively, you can run it locally or directly from a HF Space .
26+ # Start the environment only if using --env-mode docker-local; in other modes, the env is automatically managed by the script.
2727```sh
2828docker run -d -p 8001:8001 registry.hf.space/burtenshaw-textarena:latest
2929# or TEXTARENA_ENV_ID=Wordle-v0 TEXTARENA_NUM_PLAYERS=1 python -m src.envs.textarena_env.server.app
@@ -85,11 +85,15 @@ def parse_args() -> argparse.Namespace:
8585 default = "Qwen/Qwen3-1.7B" ,
8686 help = "Model identifier passed to GRPOTrainer for fine-tuning." ,
8787 )
88+ parser .add_argument ("--env-host" , type = str , default = "0.0.0.0" , help = "Host for the environment server." )
89+ parser .add_argument ("--env-port" , type = int , default = 8001 , help = "Port for the environment server." )
8890 parser .add_argument (
89- "--env-url" ,
90- default = "https://burtenshaw-textarena.hf.space" ,
91- help = "URL for the TextArena Wordle environment." ,
91+ "--env-mode" ,
92+ choices = ["docker-local" , "docker-image" , "docker-hub" , "space" ],
93+ default = "docker-image" ,
94+ help = "Where to run the environment: 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL." ,
9295 )
96+ parser .add_argument ("--env-image" , type = str , default = "textarena-env:latest" , help = "Docker image for the TextArena environment." )
9397 parser .add_argument (
9498 "--system-prompt-path" ,
9599 default = "wordle_prompt.txt" ,
@@ -411,46 +415,65 @@ def reward_repetition(completions: list[str], **kwargs) -> list[float]:
411415
412416
413417def main () -> None :
414- cli_args = parse_args ()
418+ args = parse_args ()
415419
416- tokenizer = AutoTokenizer .from_pretrained (cli_args .tokenizer_id )
420+ tokenizer = AutoTokenizer .from_pretrained (args .tokenizer_id )
417421 tokenizer .pad_token = tokenizer .eos_token
418422
419- env = TextArenaEnv (base_url = cli_args .env_url )
423+ # Select environment mode
424+ if args .env_mode == "docker-local" :
425+ env_url = f"http://{ args .env_host } :{ args .env_port } "
426+ print (f"🌍 Using existing TextArena Environment (Docker) at: { env_url } " )
427+ elif args .env_mode == "docker-image" :
428+ client = TextArenaEnv .from_docker_image (args .env_image )
429+ print (f"🌍 Using TextArena Environment (Docker) from local Image" )
430+ elif args .env_mode == "docker-hub" :
431+ client = TextArenaEnv .from_hub (args .env_image )
432+ print (f"🌍 Using existing TextArena Environment (Docker) from Hub Image" )
433+ elif args .env_mode == "space" :
434+ env_url = args .env_host
435+ print (f"🌍 Using Hugging Face Space environment at: { env_url } " )
436+ else :
437+ raise ValueError (f"Unknown environment mode: { args .env_mode } " )
420438
421- system_prompt = resolve_system_prompt (cli_args .system_prompt_path )
439+ if args .env_mode != "docker-hub" and args .env_mode != "docker-image" :
440+ client = TextArenaEnv (base_url = env_url )
422441
423- dataset = Dataset .from_dict ({"prompt" : [cli_args .dataset_prompt ] * cli_args .dataset_size })
442+ #env = TextArenaEnv(base_url=args.env_url)
443+
444+ system_prompt = resolve_system_prompt (args .system_prompt_path )
445+
446+ dataset = Dataset .from_dict ({"prompt" : [args .dataset_prompt ] * args .dataset_size })
424447
425448 timestamp = datetime .now ().strftime ("%Y-%m-%d_%H-%M-%S" )
426- default_output_dir = Path ("outputs" ) / f"wordle-grpo-{ sanitize_name (cli_args .model_id )} -{ timestamp } "
427- output_dir = Path (cli_args .output_dir or default_output_dir )
449+ default_output_dir = Path ("outputs" ) / f"wordle-grpo-{ sanitize_name (args .model_id )} -{ timestamp } "
450+ output_dir = Path (args .output_dir or default_output_dir )
428451
429452 grpo_config = GRPOConfig (
430453 use_vllm = True ,
431- vllm_mode = cli_args .vllm_mode ,
432- vllm_server_url = cli_args .vllm_server_url if cli_args .vllm_mode == "server" else None ,
454+ vllm_mode = args .vllm_mode ,
455+ vllm_server_base_url = args .vllm_server_url if args .vllm_mode == "server" else None ,
433456 output_dir = str (output_dir ),
434- num_train_epochs = cli_args .num_epochs ,
435- learning_rate = cli_args .learning_rate ,
436- weight_decay = cli_args .weight_decay ,
437- gradient_accumulation_steps = cli_args .gradient_accumulation_steps ,
438- per_device_train_batch_size = cli_args .per_device_batch_size ,
439- warmup_steps = cli_args .warmup_steps ,
440- num_generations = cli_args .num_generations ,
441- max_completion_length = cli_args .max_new_tokens ,
442- logging_steps = cli_args .logging_steps ,
457+ num_train_epochs = args .num_epochs ,
458+ learning_rate = args .learning_rate ,
459+ weight_decay = args .weight_decay ,
460+ gradient_accumulation_steps = args .gradient_accumulation_steps ,
461+ per_device_train_batch_size = args .per_device_batch_size ,
462+ warmup_steps = args .warmup_steps ,
463+ num_generations = args .num_generations ,
464+ max_completion_length = args .max_new_tokens ,
465+ logging_steps = args .logging_steps ,
443466 save_strategy = "steps" ,
444- save_steps = cli_args .save_interval ,
445- save_total_limit = cli_args .save_total_limit ,
446- temperature = cli_args .temperature ,
447- top_k = cli_args .top_k ,
448- top_p = cli_args .top_p ,
467+ save_steps = args .save_interval ,
468+ save_total_limit = args .save_total_limit ,
469+ temperature = args .temperature ,
470+ top_k = args .top_k ,
471+ top_p = args .top_p ,
449472 )
450473
451- grpo_config .run_name = cli_args .run_name or f"run-{ timestamp } "
452- grpo_config .project = cli_args .project or f"group-{ sanitize_name (cli_args .model_id )} "
453- grpo_config .trackio_space_id = cli_args .trackio_space_id
474+ grpo_config .run_name = args .run_name or f"run-{ timestamp } "
475+ grpo_config .project = args .project or f"group-{ sanitize_name (args .model_id )} "
476+ grpo_config .trackio_space_id = args .trackio_space_id
454477
455478 def rollout_func (prompts : list [str ], trainer : GRPOTrainer ) -> dict [str , list ]:
456479 episode_prompt_ids : list [list [int ]] = []
@@ -464,11 +487,11 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
464487 for prompt_text in prompts :
465488 episode = rollout_once (
466489 trainer = trainer ,
467- env = env ,
490+ env = client ,
468491 tokenizer = tokenizer ,
469492 dataset_prompt = prompt_text ,
470493 system_prompt = system_prompt ,
471- max_turns = cli_args .max_turns ,
494+ max_turns = args .max_turns ,
472495 )
473496 episode_prompt_ids .append (episode ["prompt_ids" ])
474497 episode_completion_ids .append (episode ["completion_ids" ])
@@ -489,7 +512,7 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
489512 }
490513
491514 trainer = GRPOTrainer (
492- model = cli_args .model_id ,
515+ model = args .model_id ,
493516 processing_class = tokenizer ,
494517 reward_funcs = [
495518 reward_correct ,
@@ -503,12 +526,12 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
503526 )
504527
505528 print ("Starting GRPO training with Wordle environment..." )
506- print (f"Using { cli_args .num_generations } rollouts per dataset prompt" )
529+ print (f"Using { args .num_generations } rollouts per dataset prompt" )
507530
508531 try :
509532 trainer .train ()
510533 finally :
511- env .close ()
534+ client .close ()
512535
513536
514537if __name__ == "__main__" :
0 commit comments