[Evaluations] Custom evals: Adding support for eval_kwargs #1557

Merged · 3 commits · Mar 20, 2025
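This change lets a custom evaluation task forward extra keyword arguments to its registered evaluation function via `task_params.eval_kwargs`; these are merged with any kwargs passed directly to `Evaluator.evaluate()`, and overlapping keys are rejected. A minimal usage sketch under assumptions (the config import path and the `threshold`/`num_samples` parameter names are illustrative, not taken from this diff):

# Hypothetical usage sketch of the feature added in this PR.
from oumi.core.configs import (  # assumed import path for the config classes
    EvaluationBackend,
    EvaluationConfig,
    EvaluationTaskParams,
)
from oumi.core.evaluation.evaluator import Evaluator

task_params = EvaluationTaskParams(
    task_name="my_registered_eval_fn",           # name of a registered custom eval function (assumed)
    evaluation_backend=EvaluationBackend.CUSTOM.value,
    eval_kwargs={"threshold": 0.5},              # task-level kwargs forwarded to the function
)
config = EvaluationConfig(tasks=[task_params])

# Call-time kwargs are merged with `eval_kwargs`; duplicate keys raise a ValueError.
results = Evaluator().evaluate(config, num_samples=100)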
17 changes: 16 additions & 1 deletion src/oumi/core/evaluation/evaluator.py
@@ -145,11 +145,12 @@ def evaluate_task(
         elif evaluation_backend == EvaluationBackend.CUSTOM:
             evaluation_fn = Evaluator._get_custom_evaluation_fn(task_params.task_name)
             self._add_inference_engine_if_needed(evaluation_fn, kwargs, config)
+            custom_kwargs = Evaluator._merge_kwargs(kwargs, task_params.eval_kwargs)

             evaluation_result = evaluation_fn(
                 task_params=task_params,
                 config=config,
-                **kwargs,
+                **custom_kwargs,
             )
             if not isinstance(evaluation_result, EvaluationResult):
                 raise ValueError(
@@ -305,6 +306,20 @@ def _get_init_kwargs_for_task_params_class(

         return init_kwargs

+    @staticmethod
+    def _merge_kwargs(
+        kwargs_1: dict[str, Any],
+        kwargs_2: dict[str, Any],
+    ) -> dict[str, Any]:
+        """Merges two keyword argument dictionaries."""
+        if overlapping_keys := kwargs_1.keys() & kwargs_2.keys():
+            raise ValueError(
+                "The two keyword argument dictionaries contain overlapping keys: "
+                f"{overlapping_keys}. Please ensure that the keys in the following "
+                f"dictionaries are unique: `{kwargs_1.keys()}` and `{kwargs_2.keys()}`"
+            )
+        return kwargs_1 | kwargs_2
+
     def _add_inference_engine_if_needed(
         self,
         evaluation_function: Callable,
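The new `_merge_kwargs` helper is a plain dict union that refuses to merge when both sources define the same key. A standalone sketch of the same behavior (names here are illustrative, not the PR's code verbatim):

from typing import Any

def merge_kwargs(kwargs_1: dict[str, Any], kwargs_2: dict[str, Any]) -> dict[str, Any]:
    """Merge two keyword-argument dicts, rejecting overlapping keys."""
    if overlapping_keys := kwargs_1.keys() & kwargs_2.keys():
        raise ValueError(f"Overlapping keys: {overlapping_keys}")
    return kwargs_1 | kwargs_2  # PEP 584 dict union (Python 3.9+)

# Caller kwargs merged with task-level eval_kwargs:
assert merge_kwargs({"num_samples": 100}, {"threshold": 0.5}) == {
    "num_samples": 100,
    "threshold": 0.5,
}
# merge_kwargs({"threshold": 0.1}, {"threshold": 0.5})  # raises ValueError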
58 changes: 55 additions & 3 deletions tests/unit/core/evaluation/test_evaluator.py
@@ -149,6 +149,7 @@ def test_evaluate_custom_task(
     task_params = EvaluationTaskParams(
         task_name="evaluation_fn_reg_name",
         evaluation_backend=EvaluationBackend.CUSTOM.value,
+        eval_kwargs={"optional_param_2": "optional_param_2_value"},
     )
     evaluation_config = EvaluationConfig(
         tasks=[task_params],
@@ -160,11 +161,13 @@ def test_evaluate_custom_task(
     def evaluation_fn(
         task_params: EvaluationTaskParams,
         config: EvaluationConfig,
-        optional_param: str,
+        optional_param_1: str,
+        optional_param_2: str,
     ) -> EvaluationResult:
         assert task_params.evaluation_backend == EvaluationBackend.CUSTOM.value
         assert task_params.task_name == "evaluation_fn_reg_name"
-        assert optional_param == "optional_param_value"
+        assert optional_param_1 == "optional_param_1_value"
+        assert optional_param_2 == "optional_param_2_value"
         return EvaluationResult(
             task_name=task_params.task_name,
             task_result={"test_metric": 1.0},
@@ -179,7 +182,7 @@ def evaluation_fn(
     # Run the test.
     evaluator = Evaluator()
     result = evaluator.evaluate(
-        evaluation_config, optional_param="optional_param_value"
+        evaluation_config, optional_param_1="optional_param_1_value"
     )

     # Check the results.
@@ -397,6 +400,55 @@ def evaluation_fn(task_params, config, inference_engine):
     mock_get_evaluation_function.assert_called_once()


+@patch("oumi.core.evaluation.evaluator.REGISTRY.get_evaluation_function")
+@patch("oumi.core.evaluation.evaluator.check_prerequisites")
+@patch("oumi.core.evaluation.evaluator.save_evaluation_output")
+@patch("oumi.core.evaluation.evaluator.build_inference_engine")
+def test_evaluate_custom_task_duplicate_optional_param(
+    mock_build_inference_engine,
+    mock_save_evaluation_output,
+    mock_check_prerequisites,
+    mock_get_evaluation_function,
+):
+    # Inputs.
+    task_params = EvaluationTaskParams(
+        task_name="evaluation_fn_reg_name",
+        evaluation_backend=EvaluationBackend.CUSTOM.value,
+        eval_kwargs={"optional_param": "value"},
+    )
+    evaluation_config = EvaluationConfig(tasks=[task_params])
+
+    def evaluation_fn(task_params, config):
+        pass
+
+    # Mocks.
+    mock_build_inference_engine.return_value = MagicMock()
+    mock_save_evaluation_output.return_value = None
+    mock_check_prerequisites.return_value = None
+    mock_get_evaluation_function.return_value = evaluation_fn
+
+    # Run the test.
+    evaluator = Evaluator()
+
+    with pytest.raises(
+        ValueError,
+        match=(
+            r"^The two keyword argument dictionaries contain overlapping keys: "
+            "{'optional_param'}."
+        ),
+    ):
+        _ = evaluator.evaluate(
+            evaluation_config,
+            optional_param="value",  # NOT allowed, already set in `eval_kwargs`.
+        )
+
+    # Check the results.
+    mock_build_inference_engine.assert_not_called()
+    mock_save_evaluation_output.assert_not_called()
+    mock_check_prerequisites.assert_called_once()
+    mock_get_evaluation_function.assert_called_once()
+
+
 @patch("oumi.core.evaluation.evaluator.evaluate_lm_harness")
 @patch("oumi.core.evaluation.evaluator.evaluate_alpaca_eval")
 @patch("oumi.core.evaluation.evaluator.check_prerequisites")