Commit 0a64295

jsondai authored and copybara-github committed

feat: GenAI Client(evals) - support setting autorater generation config for predefined rubric metrics

PiperOrigin-RevId: 833487047

1 parent c8a5f96 · commit 0a64295
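
In practice, this change lets callers tune the judge model's decoding behavior when using predefined rubric metrics. A minimal sketch of the new call pattern, lifted from the test added below; the values shown are the ones the test uses, and any client setup is outside this diff:

from google.genai import types as genai_types
from vertexai import types

# Predefined rubric metric with an autorater (judge model) generation config.
# temperature and max_output_tokens are the values used in the new test.
metric = types.RubricMetric.GENERAL_QUALITY(
    judge_model_generation_config=genai_types.GenerationConfig(
        temperature=0.1,
        max_output_tokens=1024,
    )
)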

File tree

3 files changed: +75 -2 lines changed

tests/unit/vertexai/genai/replays/test_evaluate_predefined_metrics.py
vertexai/_genai/_evals_metric_handlers.py
vertexai/_genai/types/common.py

tests/unit/vertexai/genai/replays/test_evaluate_predefined_metrics.py

Lines changed: 45 additions & 0 deletions

@@ -16,6 +16,7 @@
 
 from tests.unit.vertexai.genai.replays import pytest_helper
 from vertexai import types
+from google.genai import types as genai_types
 import pandas as pd
 
 
@@ -60,6 +61,50 @@ def test_evaluation_result(client):
         assert case_result.response_candidate_results is not None
 
 
+def test_evaluation_result_with_autorater_config(client):
+    """Tests that evaluate() produces a correctly structured EvaluationResult."""
+    prompts_df = pd.DataFrame(
+        {
+            "prompt": ["Explain the concept of machine learning in simple terms."],
+            "response": [
+                "Machine learning is a type of artificial intelligence that allows"
+                " computers to learn from data without being explicitly programmed."
+            ],
+        }
+    )
+
+    eval_dataset = types.EvaluationDataset(
+        eval_dataset_df=prompts_df,
+        candidate_name="gemini-2.5-flash",
+    )
+
+    predefined_metric_with_autorater_config = types.RubricMetric.GENERAL_QUALITY(
+        judge_model_generation_config=genai_types.GenerationConfig(
+            temperature=0.1,
+            max_output_tokens=1024,
+        )
+    )
+
+    evaluation_result = client.evals.evaluate(
+        dataset=eval_dataset,
+        metrics=[predefined_metric_with_autorater_config],
+    )
+
+    assert isinstance(evaluation_result, types.EvaluationResult)
+
+    assert evaluation_result.summary_metrics is not None
+    for summary in evaluation_result.summary_metrics:
+        assert isinstance(summary, types.AggregatedMetricResult)
+        assert summary.metric_name == "general_quality_v1"
+        assert summary.mean_score is not None
+
+    assert evaluation_result.eval_case_results is not None
+    for case_result in evaluation_result.eval_case_results:
+        assert isinstance(case_result, types.EvalCaseResult)
+        assert case_result.eval_case_index is not None
+        assert case_result.response_candidate_results is not None
+
+
 def test_multi_turn_predefined_metric(client):
     """Tests that evaluate works with multi-turn predefined metrics."""
     prompts_data = {
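
The new field composes with the judge settings the handler below already forwards (judge_model and judge_model_sampling_count). A sketch under the assumption that RubricMetric.GENERAL_QUALITY accepts those keyword arguments the same way the test passes judge_model_generation_config; the model name and sampling count are illustrative, not taken from this commit:

from google.genai import types as genai_types
from vertexai import types

# All three judge settings together; each is optional and only forwarded
# to the autorater config when set (see the handler diff below).
metric = types.RubricMetric.GENERAL_QUALITY(
    judge_model="gemini-2.5-pro",  # illustrative judge model name
    judge_model_sampling_count=4,  # illustrative sampling count
    judge_model_generation_config=genai_types.GenerationConfig(
        temperature=0.0,
        top_p=0.95,
    ),
)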

vertexai/_genai/_evals_metric_handlers.py

Lines changed: 23 additions & 2 deletions

@@ -620,6 +620,10 @@ def _add_autorater_config(self, payload: dict[str, Any]) -> None:
         autorater_config = {}
         if self.metric.judge_model:
             autorater_config["autorater_model"] = self.metric.judge_model
+        if self.metric.judge_model_generation_config:
+            autorater_config["generation_config"] = (
+                self.metric.judge_model_generation_config
+            )
         if self.metric.judge_model_sampling_count:
             autorater_config["sampling_count"] = self.metric.judge_model_sampling_count  # type: ignore[assignment]
 
@@ -986,10 +990,25 @@ def _build_request_payload(
             agent_data=PredefinedMetricHandler._eval_case_to_agent_data(eval_case),
         )
 
-        return {
+        request_payload = {
            "instance": instance_payload,
        }
 
+        autorater_config = {}
+        if self.metric.judge_model:
+            autorater_config["autorater_model"] = self.metric.judge_model
+        if self.metric.judge_model_generation_config:
+            autorater_config["generation_config"] = (
+                self.metric.judge_model_generation_config
+            )
+        if self.metric.judge_model_sampling_count:
+            autorater_config["sampling_count"] = self.metric.judge_model_sampling_count
+        if autorater_config:
+            request_payload["autorater_config"] = genai_types.AutoraterConfig(
+                **autorater_config
+            )
+        return request_payload
+
     @override
     def get_metric_result(
         self, eval_case: types.EvalCase, response_index: int
@@ -1001,7 +1020,9 @@ def get_metric_result(
         for attempt in range(_MAX_RETRIES):
             try:
                 api_response = self.module._evaluate_instances(
-                    metrics=[self.metric], instance=payload.get("instance")
+                    metrics=[self.metric],
+                    instance=payload.get("instance"),
+                    autorater_config=payload.get("autorater_config"),
                 )
                 break
             except genai_errors.ClientError as e:
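
The payload assembly above reduces to a conditional merge: only judge settings actually set on the metric are forwarded, and the autorater config is omitted entirely when none are. A standalone sketch of the same shape, with metric standing in for self.metric; the helper name build_autorater_config is hypothetical, while AutoraterConfig and the field names come straight from the diff:

from typing import Any, Optional

from google.genai import types as genai_types


def build_autorater_config(metric: Any) -> Optional[genai_types.AutoraterConfig]:
    """Mirrors the conditional merge in _build_request_payload above."""
    config: dict[str, Any] = {}
    if metric.judge_model:
        config["autorater_model"] = metric.judge_model
    if metric.judge_model_generation_config:
        config["generation_config"] = metric.judge_model_generation_config
    if metric.judge_model_sampling_count:
        config["sampling_count"] = metric.judge_model_sampling_count
    # Returning None matches the diff's `if autorater_config:` guard, so
    # payload.get("autorater_config") stays None for metrics without any
    # judge settings.
    return genai_types.AutoraterConfig(**config) if config else None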

vertexai/_genai/types/common.py

Lines changed: 7 additions & 0 deletions

@@ -2622,6 +2622,10 @@ class Metric(_common.BaseModel):
     judge_model: Optional[str] = Field(
         default=None, description="""The judge model for the metric."""
     )
+    judge_model_generation_config: Optional[genai_types.GenerationConfig] = Field(
+        default=None,
+        description="""The generation config for the judge LLM (temperature, top_k, top_p, etc).""",
+    )
     judge_model_sampling_count: Optional[int] = Field(
         default=None, description="""The sampling count for the judge model."""
     )
@@ -2825,6 +2829,9 @@ class MetricDict(TypedDict, total=False):
     judge_model: Optional[str]
     """The judge model for the metric."""
 
+    judge_model_generation_config: Optional[genai_types.GenerationConfigDict]
+    """The generation config for the judge LLM (temperature, top_k, top_p, etc)."""
+
     judge_model_sampling_count: Optional[int]
     """The sampling count for the judge model."""
 
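
For callers constructing a Metric directly rather than through RubricMetric, the new field sits alongside the existing judge fields as an ordinary Pydantic field, with a TypedDict mirror accepting the dict form of the config. A sketch, assuming Metric and MetricDict are exported from vertexai.types and that Metric takes a name argument:

from google.genai import types as genai_types
from vertexai import types

# Direct construction; "general_quality_v1" matches the metric name
# asserted in the new test, though predefined rubric metrics are
# normally obtained via types.RubricMetric instead.
metric = types.Metric(
    name="general_quality_v1",
    judge_model_generation_config=genai_types.GenerationConfig(
        temperature=0.2,
    ),
)

# TypedDict mirror: the generation config is given in dict form.
metric_dict: types.MetricDict = {
    "judge_model_generation_config": {"temperature": 0.2},
}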
