Add grpo job example #3589
Conversation
Compare: 7193c51 to 7e212f5
Pull Request Overview
This PR adds a GRPO job example with trainer scripts, configs, callbacks, environment setup, and sample datasets for a medical MCQA task in AzureML.
- Adds reward functions and trainer code for reasoning tasks
- Introduces YAML/JSON configs and Docker/requirements for environment
- Includes sample datasets (50 records each) and AML setup script
Reviewed Changes
Copilot reviewed 21 out of 21 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
sdk/python/jobs/grpo/src/grpo_trainer_rewards.py | Implements format and accuracy reward functions |
sdk/python/jobs/grpo/src/BldDemo_Reasoning_Train.py | Main GRPO training script with CLI arg parsing |
sdk/python/jobs/grpo/aml_setup.py | AML workspace setup for dataset, model, compute, env |
sdk/python/jobs/grpo/src/grpo_trainer_callbacks.py | Callback to save HF-transformers models in MLflow |
sdk/python/jobs/grpo/src/grpo_trainer_config.yaml | Trainer configuration for vLLM, reward weighting |
sdk/python/jobs/grpo/environment/Dockerfile | Docker image build and package installs |
sdk/python/jobs/grpo/environment/requirements.txt | Python dependencies |
sdk/python/jobs/grpo/datasets/med_mcqa/*.jsonl | Sample dataset splits for training/eval/testing |
Comments suppressed due to low confidence (4)
sdk/python/jobs/grpo/aml_setup.py:13
- The `Model` import at line 13 duplicates the earlier import from `azure.ai.ml.entities`. Remove one of the imports to prevent confusion and improve code clarity.
from mlflow.models import Model
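If both classes are genuinely needed, an alternative to removing one import is to alias them. A minimal sketch (the AML-side usage in the comment is assumed for illustration, not taken from the script):

```python
# Sketch only: keep both imports but disambiguate the names so later code is explicit.
from azure.ai.ml.entities import Model as AmlModel   # AML model asset (per the review comment)
from mlflow.models import Model as MlflowModel       # MLflow model metadata, if actually needed

# Registration code can then refer to AmlModel unambiguously, e.g.:
# model_asset = AmlModel(path="...", name="qwen2_5-7b-instruct", type="custom_model")
```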
sdk/python/jobs/grpo/aml_setup.py:100
- Fix the typo in the comment: change "falsh" to "flash".
# This job requires falsh attention and needs A100 or H100 GPUs.
sdk/python/jobs/grpo/src/grpo_trainer_rewards.py:14
- There are no tests covering the new reward functions (`format_reward` and `_medmcqa_match_fn`). Consider adding unit tests to validate correct behavior and edge cases.
def format_reward(completions, **kwargs):
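For reference, a minimal pytest-style sketch of such tests. It assumes the TRL-style reward signature shown above (one score per completion, completions given as lists of chat messages) and an assumed `<think>...</think><answer>...</answer>` output format; both assumptions would need to be checked against the PR's actual implementation:

```python
# Sketch only: the completion structure and the expected tag format are assumptions.
from grpo_trainer_rewards import format_reward


def _completion(text):
    # Helper mirroring the TRL chat-completion structure assumed above.
    return [{"role": "assistant", "content": text}]


def test_format_reward_returns_one_score_per_completion():
    completions = [
        _completion("<think>reasoning</think><answer>B</answer>"),
        _completion("no tags at all"),
    ]
    rewards = format_reward(completions)
    assert len(rewards) == len(completions)
    assert all(isinstance(r, (int, float)) for r in rewards)


def test_format_reward_prefers_well_formed_output():
    good = format_reward([_completion("<think>x</think><answer>A</answer>")])[0]
    bad = format_reward([_completion("just an answer")])[0]
    assert good >= bad
```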
sdk/python/jobs/grpo/src/BldDemo_Reasoning_Train.py:186
- The `SaveMLflowModelCallback` constructor does not accept a `preprocessor` argument but is called with it here, leading to a runtime `TypeError`. Change the argument name or update the callback signature to match.
preprocessor=tokenizer,
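A hedged sketch of the second option (updating the callback signature), assuming the callback subclasses `transformers.TrainerCallback` and simply stores the preprocessor for later use; the other constructor parameter is illustrative:

```python
# Sketch only: accept the preprocessor keyword so the existing call site
# (preprocessor=tokenizer) stops raising a TypeError. Other parameters are illustrative.
from transformers import TrainerCallback


class SaveMLflowModelCallback(TrainerCallback):
    def __init__(self, mlflow_model_save_path=None, preprocessor=None):
        self.mlflow_model_save_path = mlflow_model_save_path
        # Tokenizer (or other preprocessor) to save alongside the model artifacts.
        self.preprocessor = preprocessor
```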
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Pull Request Overview
This PR adds a GRPO example for running a policy optimization job on AzureML, including new reward functions, callbacks, dataset samples, environment specs, and setup scripts.
- Introduce reward computation and accuracy matching functions.
- Provide a full training script and AML workspace setup.
- Add sample datasets, Dockerfile, and requirements for reproducible runs.
Reviewed Changes
Copilot reviewed 21 out of 21 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
sdk/python/jobs/grpo/src/grpo_trainer_rewards.py | Implemented format and accuracy reward functions |
sdk/python/jobs/grpo/src/grpo_trainer_config.yaml | Added YAML config for GRPO training parameters |
sdk/python/jobs/grpo/src/grpo_trainer_callbacks.py | Added MLflow save callback for model artifact management |
sdk/python/jobs/grpo/src/BldDemo_Reasoning_Train.py | Example training script using GRPOTrainer |
sdk/python/jobs/grpo/environment/requirements.txt | Pinned Python package dependencies |
sdk/python/jobs/grpo/environment/Dockerfile | Docker image build instructions |
sdk/python/jobs/grpo/datasets/med_mcqa/validation.jsonl | Sample validation split for medical MCQA |
sdk/python/jobs/grpo/datasets/med_mcqa/train.jsonl | Sample training split for medical MCQA |
sdk/python/jobs/grpo/datasets/med_mcqa/test.jsonl | Sample test split for medical MCQA |
sdk/python/jobs/grpo/aml_setup.py | Script to register dataset, model, compute, environment in AML workspace |
Comments suppressed due to low confidence (2)
sdk/python/jobs/grpo/src/grpo_trainer_rewards.py:14
- These new reward functions lack associated unit tests; consider adding tests to verify that `format_reward`, `_medmcqa_match_fn`, and `accuracy` produce expected outputs for a variety of inputs.
def format_reward(completions, **kwargs):
sdk/python/jobs/grpo/datasets/med_mcqa/validation.jsonl:1
- The dataset prompt contains a typo (`Murphy&;s`); it should be `Murphy's sign` for clarity and correctness.
"Murphy&;s sign is seen in?"
It would be good to add the path to the CODEOWNERS file as well.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Done.
LGTM
…d.ipynb Co-authored-by: Gayatri Penumetsa <181455625+gpenumetsa-msft@users.noreply.github.com>
…d.ipynb Co-authored-by: Gayatri Penumetsa <181455625+gpenumetsa-msft@users.noreply.github.com>
…d.ipynb Co-authored-by: Gayatri Penumetsa <181455625+gpenumetsa-msft@users.noreply.github.com>
In `launch_grpo_command_job-med-mcqa-commented.ipynb`, when describing the setup (data/model) in Section 1, add some more details. Here is the snippet you can use:
The Azure Machine Learning (AML) **setup process is encapsulated** into a script that provisions all required resources in the workspace. By the end of the setup, the AML workspace will be fully configured with the following resources:
- **Dataset** : [MedMCQA](https://medmcqa.github.io): A Large-scale Multi-Subject Multi-Choice Dataset for Medical domain Question Answering. We use a modified version of the MedMCQA dataset, restricting our experiments to question/answer pairs having only a single correct answer. The modified dataset used in the demo can be found in `datasets/med_mcqa`
- **Model** : [Qwen2_5-7B-Instruct_base](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct)
- **Compute Cluster**: STANDARD_ND96ISR_H100_V5 cluster with at least 2 nodes
- **Environment**: Designed for GRPO-specific, large-scale, distributed training and inference of reasoning models using Azure Machine Learning, TRL, DeepSpeed, vLLM, and LoRA.
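For reference, a rough sketch of what such a setup script might do with the `azure-ai-ml` v2 SDK; the resource names, local paths, and cluster name below mirror the description above but are illustrative, not the actual contents of `aml_setup.py`:

```python
# Sketch only: provisions the data asset, model asset, compute cluster, and environment
# described above. Names and paths are assumptions; see aml_setup.py for the real script.
from azure.ai.ml import MLClient
from azure.ai.ml.entities import AmlCompute, BuildContext, Data, Environment, Model
from azure.identity import DefaultAzureCredential

ml_client = MLClient.from_config(credential=DefaultAzureCredential())

dataset = ml_client.data.create_or_update(
    Data(name="med_mcqa", path="datasets/med_mcqa", type="uri_folder")
)
model = ml_client.models.create_or_update(
    Model(name="Qwen2_5-7B-Instruct_base", path="models/Qwen2.5-7B-Instruct", type="custom_model")
)
compute = ml_client.compute.begin_create_or_update(
    AmlCompute(name="h100-cluster", size="STANDARD_ND96ISR_H100_V5", min_instances=0, max_instances=2)
).result()
environment = ml_client.environments.create_or_update(
    Environment(name="grpo-training-env", build=BuildContext(path="environment"))
)
```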
Done.
Reviewed and suggested changes. LGTM.