Commit eae46dd
[GRPO][Eval] Add letter counting eval (#1574)
1 parent 9edcbd9 commit eae46dd

File tree

11 files changed: +272 -16 lines changed

configs/examples/grpo_tldr/gcp_job.yaml

+1-1
@@ -31,7 +31,7 @@ envs:

 setup: |
   set -e
-  pip install uv && uv pip install oumi[gpu] vllm
+  pip install uv && uv pip install oumi[gpu] "vllm>=0.7.3,<0.8.0"
   pip install -U flash-attn --no-build-isolation

 run: |
configs/examples/letter_counting/evaluation/eval.yaml

+35

@@ -0,0 +1,35 @@
+# Config to eval an LLM's ability to count letters in words.
+#
+# Requirements:
+# - Run `pip install vllm`
+# - Log into HF: `huggingface-cli login`
+#
+# Usage:
+# oumi evaluate -c oumi://configs/examples/letter_counting/evaluation/eval.yaml
+#
+# See Also:
+# - Documentation: https://oumi.ai/docs/en/latest/user_guides/evaluate/evaluate.html
+# - Config class: oumi.core.configs.EvaluationConfig
+# - Config source: https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/evaluation_config.py
+# - Other eval configs: configs/**/evaluation/
+
+model:
+  model_name: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+  model_max_length: 131072
+  torch_dtype_str: "bfloat16"
+  attn_implementation: "sdpa"
+  trust_remote_code: True
+
+generation:
+  max_new_tokens: 2048
+  # This isn't used by vLLM, but is used for the NATIVE inference engine.
+  batch_size: 4
+
+tasks:
+  - evaluation_backend: custom
+    task_name: count_letters
+    num_samples: 1000
+
+inference_engine: VLLM # Can also use NATIVE if not running on GPUs
+
+output_dir: "output/letter_counting/evaluation"
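
For a quick local sanity check, the same config can also be driven from Python. This is a minimal sketch only: `EvaluationConfig.from_yaml` and the `Evaluator` API are assumptions based on the config class and docs referenced in the comments above, so verify the exact signatures against the linked documentation.

```python
# Minimal sketch: run the letter-counting eval from Python instead of the CLI.
# ASSUMPTION: `EvaluationConfig.from_yaml` and `Evaluator().evaluate(...)`
# exist as suggested by the docs linked above; verify before relying on this.
from oumi.core.configs import EvaluationConfig
from oumi.core.evaluation.evaluator import Evaluator

config = EvaluationConfig.from_yaml(
    "configs/examples/letter_counting/evaluation/eval.yaml"
)
config.tasks[0].num_samples = 10  # Shrink the run for a smoke test.

results = Evaluator().evaluate(config)
print(results)
```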
configs/examples/letter_counting/evaluation/gcp_job.yaml

+59

@@ -0,0 +1,59 @@
+# Job config to eval an LLM's ability to count letters in words.
+#
+# Requirements:
+# - Set up SkyPilot GCP: https://oumi.ai/docs/en/latest/user_guides/launch/launch.html#setup
+# - Log into HF: `huggingface-cli login`
+#
+# Usage:
+# oumi launch up -c oumi://configs/examples/letter_counting/evaluation/gcp_job.yaml --cluster letter-counting-eval
+#
+# See Also:
+# - Documentation: https://oumi.ai/docs/en/latest/user_guides/launch/launch.html
+# - Config class: oumi.core.configs.JobConfig
+# - Config source: https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/job_config.py
+# - Other job configs: configs/**/*job.yaml
+
+name: letter-counting-eval
+
+resources:
+  cloud: gcp
+  accelerators: "A100"
+  use_spot: false
+
+working_dir: .
+
+file_mounts:
+  ~/.netrc: ~/.netrc # WandB credentials
+  ~/.cache/huggingface/token: ~/.cache/huggingface/token # HF credentials
+
+envs:
+  # NOTE: For SFT, update this to point to your model checkpoint.
+  # NOTE: For LoRA, instead update this to point to your LoRA adapter.
+  #       The base model will be inferred automatically.
+  MODEL_CHECKPOINT_DIR: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
+  WANDB_PROJECT: oumi-eval
+  OUMI_RUN_NAME: letter-counting.eval
+
+setup: |
+  set -e
+  pip install uv && uv pip install oumi[gpu,evaluation] "vllm>=0.7.3,<0.8.0"
+
+run: |
+  set -e  # Exit if any command failed.
+  source ./configs/examples/misc/sky_init.sh
+
+  if test ${OUMI_NUM_NODES} -ne 1; then
+    echo "LM Harness supports max 1 node. Actual: ${OUMI_NUM_NODES} nodes."
+    exit 1
+  fi
+
+  echo "Starting evaluation for ${MODEL_CHECKPOINT_DIR} ..."
+  set -x
+
+  accelerate launch \
+    -m oumi evaluate \
+    -c oumi://configs/examples/letter_counting/evaluation/eval.yaml \
+    --run_name "${OUMI_RUN_NAME}.${SKYPILOT_TASK_ID}" \
+    --model.model_name "${MODEL_CHECKPOINT_DIR}"
+
+  echo "Node ${SKYPILOT_NODE_RANK} is all done!"

configs/examples/letter_counting/grpo/gcp_job.yaml

+2-1
@@ -3,6 +3,7 @@
 # Requirements:
 # - Set up SkyPilot GCP: https://oumi.ai/docs/en/latest/user_guides/launch/launch.html#setup
 # - Log into WandB (`wandb login`) or disable `enable_wandb`
+# - Log into HF: `huggingface-cli login`
 #
 # Usage:
 # oumi launch up -c oumi://configs/examples/letter_counting/grpo/gcp_job.yaml --cluster letter-counting-grpo
@@ -33,7 +34,7 @@ envs:
 setup: |
   set -e
   # vLLM needed for vLLM-powered generation during GRPO training.
-  pip install uv && uv pip install oumi[gpu] vllm
+  pip install uv && uv pip install oumi[gpu] "vllm>=0.7.3,<0.8.0"
   pip install -U flash-attn --no-build-isolation

 run: |

configs/examples/letter_counting/grpo/train.yaml

+1
@@ -2,6 +2,7 @@
 #
 # Requirements:
 # - Log into WandB (`wandb login`) or disable `enable_wandb`
+# - Log into HF: `huggingface-cli login`
 #
 # Usage:
 # oumi train -c oumi://configs/examples/letter_counting/grpo/train.yaml

configs/recipes/phi3/evaluation/eval.yaml

-1
@@ -16,7 +16,6 @@ model:
   model_name: "microsoft/Phi-3-mini-4k-instruct"
   trust_remote_code: True
   torch_dtype_str: "bfloat16"
-  shard_for_eval: True

 # HuggingFace Leaderboard V1
 tasks:

src/oumi/core/datasets/base_grpo_dataset.py

+37-13
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from abc import abstractmethod
 from typing import Optional, Union

 import pandas as pd
 from typing_extensions import override

 from oumi.core.datasets.base_map_dataset import BaseMapDataset
-from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
+from oumi.core.types.conversation import Conversation

 _PROMPT_KEY = "prompt"
 _COMPLETION_KEY = "completion"
@@ -37,8 +38,6 @@ def __init__(
         dataset_name: Optional[str] = None,
         dataset_path: Optional[str] = None,
         split: Optional[str] = None,
-        tokenizer: Optional[BaseTokenizer] = None,
-        return_tensors: bool = False,
         **kwargs,
     ) -> None:
         """Initializes a new instance of the BaseExperimentalGrpoDataset class."""
@@ -49,14 +48,6 @@
             **kwargs,
         )

-        if return_tensors:
-            raise NotImplementedError(
-                "return_tensors=True is not implemented for this class"
-            )
-
-        self._tokenizer = tokenizer
-        self._return_tensors = return_tensors
-
         self._data = self._load_data()

     @staticmethod
@@ -65,7 +56,7 @@ def _process_text_value(s: str) -> str:
         # of text values. Let's strip them.
         return s.strip() if s else ""

-    def transform_grpo_example(self, example: Union[dict, pd.Series]) -> dict:
+    def _transform_grpo_example(self, example: Union[dict, pd.Series]) -> dict:
         """Validate and transform the GRPO sample into Python `dict`."""
         for required_key in (_PROMPT_KEY, _COMPLETION_KEY):
             if required_key not in example:
@@ -95,4 +86,37 @@ def transform_grpo_example(self, example: Union[dict, pd.Series]) -> dict:
     @override
     def transform(self, sample: pd.Series) -> dict:
         """Validate and transform the sample into Python `dict`."""
-        return self.transform_grpo_example(sample)
+        return self._transform_grpo_example(sample)
+
+    def conversation(self, idx: int) -> Conversation:
+        """Returns the conversation at the specified index.
+
+        Args:
+            idx (int): The index of the conversation to retrieve.
+
+        Returns:
+            Conversation: The conversation at the specified index.
+        """
+        sample = self.raw(idx)
+        return self.transform_conversation(sample)
+
+    def conversations(self) -> list[Conversation]:
+        """Returns a list of all conversations."""
+        indexes = range(len(self))
+        return [self.conversation(index) for index in indexes]
+
+    #
+    # Abstract Methods
+    #
+    @abstractmethod
+    def transform_conversation(self, sample: Union[dict, pd.Series]) -> Conversation:
+        """Converts the input sample to a Conversation.
+
+        Args:
+            sample (Union[dict, pd.Series]): The input example.
+
+        Returns:
+            Conversation: The resulting conversation.
+
+        """
+        raise NotImplementedError
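
The new abstract `transform_conversation` is what lets any GRPO dataset feed conversation-based evaluation. Below is a minimal sketch of a subclass satisfying the contract, assuming each raw sample already carries chat-style `messages`; the dataset name and data layout are hypothetical, and loading details are omitted.

```python
from typing import Union

import pandas as pd
from typing_extensions import override

from oumi.core.datasets.base_grpo_dataset import BaseExperimentalGrpoDataset
from oumi.core.registry import register_dataset
from oumi.core.types.conversation import Conversation


@register_dataset("my-org/my-grpo-dataset")  # Hypothetical dataset name.
class MyGrpoDataset(BaseExperimentalGrpoDataset):
    """Toy subclass illustrating the `transform_conversation` contract."""

    @override
    def transform_conversation(self, sample: Union[dict, pd.Series]) -> Conversation:
        # ASSUMPTION: each sample has a chat-style `messages` list,
        # like the letter-counting dataset below.
        sample_dict = sample if isinstance(sample, dict) else sample.to_dict()
        return Conversation.from_dict({"messages": list(sample_dict["messages"])})
```

With that in place, the base class's `conversation(idx)` and `conversations()` work unchanged.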

src/oumi/core/evaluation/evaluator.py

+2
@@ -238,6 +238,8 @@ def _get_custom_evaluation_fn(task_name: Optional[str]) -> Callable:
         "task name, which should correspond to a registered evaluation "
         "function, using the decorator `@register_evaluation_function`."
     )
+    # Import to ensure custom evaluation functions are added to REGISTRY.
+    import oumi.evaluation.registry as evaluation_registry  # noqa: F401

     if evaluation_fn := REGISTRY.get_evaluation_function(task_name):
         return evaluation_fn
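
The lazy import above ensures everything decorated with `@register_evaluation_function` inside `oumi.evaluation.registry` lands in `REGISTRY` before the lookup runs; user-defined functions just need to be imported before evaluation starts. A minimal sketch of such a function, mirroring the `count_letters` signature added in this commit (the task name is hypothetical):

```python
from typing import Any

from oumi.core.configs.params.evaluation_params import EvaluationTaskParams
from oumi.core.inference.base_inference_engine import BaseInferenceEngine
from oumi.core.registry import register_evaluation_function


@register_evaluation_function("my_custom_task")  # Hypothetical task name.
def my_custom_task(
    task_params: EvaluationTaskParams,
    inference_engine: BaseInferenceEngine,
) -> dict[str, Any]:
    """Returns a metrics dict; a config selects it via `task_name: my_custom_task`."""
    # A real implementation would build conversations, call
    # inference_engine.infer(...), and score the responses.
    return {"num_samples": task_params.num_samples}
```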

src/oumi/datasets/grpo/letter_count.py

+28
@@ -17,6 +17,13 @@

 from oumi.core.datasets.base_grpo_dataset import BaseExperimentalGrpoDataset
 from oumi.core.registry import register_dataset
+from oumi.core.types.conversation import Conversation
+
+_SYSTEM_PROMPT = (
+    "Your final answer should be written as digits and formatted as "
+    r'"\boxed{your_answer}". For example, if the answer is 42, '
+    r'make sure to output "\boxed{42}".'
+)


 @register_dataset("oumi-ai/oumi-letter-count")
@@ -47,7 +54,28 @@ class LetterCountGrpoDataset(BaseExperimentalGrpoDataset):
     @override
     def transform(self, sample: pd.Series) -> dict:
         """Validate and transform the sample into Python `dict`."""
+        # TODO: OPE-1122: Add system prompt to training.
+        # OPE-1158 seems to affect this, as the type of the input isn't consistent.
         return {
             "prompt": sample["messages"],
             "letter_count": sample["metadata"]["letter_count_integer"],
         }
+
+    @override
+    def transform_conversation(self, sample: pd.Series) -> Conversation:
+        """Converts the input sample to a Conversation.
+
+        Args:
+            sample (pd.Series): The input example.
+
+        Returns:
+            Conversation: The resulting conversation.
+
+        """
+        # Example is already in conversation format and only needs light processing.
+        sample_dict = sample.to_dict()
+        # Convert messages from np.ndarray to list.
+        sample_dict["messages"] = sample_dict["messages"].tolist()
+        # Add system prompt.
+        sample_dict["messages"].append({"content": _SYSTEM_PROMPT, "role": "system"})
+        return Conversation.from_dict(sample_dict)
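
To make the "light processing" concrete, here is what `transform_conversation` does to a single row, sketched with a fabricated `pd.Series` shaped like this dataset's rows; the prompt text and count are made up for illustration.

```python
import numpy as np
import pandas as pd

from oumi.core.types.conversation import Conversation

# Fabricated row: messages stored as an np.ndarray, plus integer metadata,
# mirroring the structure transform_conversation expects.
row = pd.Series(
    {
        "messages": np.array(
            [{"content": 'How many "r"s are in "strawberry"?', "role": "user"}]
        ),
        "metadata": {"letter_count_integer": 3},
    }
)

sample_dict = row.to_dict()
sample_dict["messages"] = sample_dict["messages"].tolist()  # ndarray -> list
sample_dict["messages"].append(
    {"content": "Your final answer should be written as digits ...", "role": "system"}
)
conversation = Conversation.from_dict(sample_dict)
```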
src/oumi/evaluation/registry/__init__.py

+21

@@ -0,0 +1,21 @@
+# Copyright 2025 - Oumi
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Evaluation registry module."""
+
+from oumi.evaluation.registry.count_letters_task import count_letters
+
+__all__ = [
+    "count_letters",
+]
src/oumi/evaluation/registry/count_letters_task.py

+86

@@ -0,0 +1,86 @@
+# Copyright 2025 - Oumi
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import Any, Optional
+
+from oumi.core.configs.params.evaluation_params import EvaluationTaskParams
+from oumi.core.inference.base_inference_engine import BaseInferenceEngine
+from oumi.core.registry import register_evaluation_function
+from oumi.datasets.grpo.letter_count import LetterCountGrpoDataset
+from oumi.utils.logging import logger
+
+
+def _extract_prediction(response: str) -> Optional[int]:
+    r"""Returns the numeric answer extracted from `\boxed{...}`, or None otherwise."""
+    regex_result = re.findall(r"\\boxed\{(\d+)\}", response)
+    if not regex_result or len(regex_result) != 1:
+        return None
+    number_str = regex_result[0]
+    # Except clause shouldn't trigger because the regex should only find ints.
+    try:
+        return int(number_str)
+    except ValueError:
+        return None
+
+
+@register_evaluation_function("count_letters")
+def count_letters(
+    task_params: EvaluationTaskParams,
+    inference_engine: BaseInferenceEngine,
+) -> dict[str, Any]:
+    """Custom evaluation function registered as `count_letters`."""
+    dataset = LetterCountGrpoDataset(split="test")
+    # TODO: OPE-1155: Add support for using Oumi dataset code to create the dataset.
+    # dataset = build_dataset("oumi-ai/oumi-letter-count", tokenizer=None, sample_count=10)  # noqa: E501
+    # dataset = build_dataset("oumi-ai/berrybench-v0.1.0", tokenizer=None, sample_count=10)  # noqa: E501
+    num_samples = task_params.num_samples
+    if num_samples is None:
+        num_samples = len(dataset)
+    input_conversations = [dataset.conversation(i) for i in range(num_samples)]
+    conversations = inference_engine.infer(input_conversations)
+    logger.info(f"Finished inference on {len(conversations)} conversations!")
+    if len(conversations) > 0:
+        logger.info(f"Sample conversation: {conversations[0]}")
+
+    count = 0  # The number of examples with correct answers extracted.
+    total = 0  # All examples.
+    valid_count = 0  # The number of examples with valid answers extracted.
+    for i, conversation in enumerate(conversations):
+        total += 1
+        # Grab the model's response.
+        response = conversation.last_message()
+        # Ignore cases where model didn't respond or it's a multimodal response.
+        # For now, we focus on text-only responses.
+        if not response or not isinstance(response.content, str):
+            continue
+        # Count the example as correct if the extracted prediction is correct.
+        prediction = _extract_prediction(response.content)
+        if prediction is None:
+            continue
+        valid_count += 1
+        if prediction == conversation.metadata["letter_count_integer"]:
+            count += 1
+
+    return {
+        # Accuracy across all examples.
+        "accuracy": count / total,
+        # Accuracy when only counting examples with properly extracted answers.
+        "properly_extracted_accuracy": count / valid_count,
+        "num_samples": num_samples,
+        # These three values sum up to num_samples.
+        "num_correct_answers": count,
+        "num_incorrect_answers": valid_count - count,
+        "num_invalid_answers": total - valid_count,
+    }
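
The extraction rule is strict: exactly one `\boxed{<digits>}` in the response, otherwise the answer counts as invalid. A few illustrative cases using the same regex, plus the bookkeeping identity noted in the returned metrics:

```python
import re
from typing import Optional


def extract_prediction(response: str) -> Optional[int]:
    r"""Same rule as `_extract_prediction`: exactly one \boxed{<digits>}."""
    matches = re.findall(r"\\boxed\{(\d+)\}", response)
    return int(matches[0]) if len(matches) == 1 else None


assert extract_prediction(r"The count is \boxed{3}.") == 3
assert extract_prediction("no boxed answer") is None  # Invalid: nothing extracted.
assert extract_prediction(r"\boxed{1} or \boxed{2}") is None  # Invalid: ambiguous.
assert extract_prediction(r"\boxed{-1}") is None  # Invalid: \d+ allows digits only.

# Metrics identity from the return dict above:
# num_correct_answers + num_incorrect_answers + num_invalid_answers == num_samples.
count, valid_count, total = 7, 9, 10
assert count + (valid_count - count) + (total - valid_count) == total
```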
