Commit ce35218

vertex-sdk-bot authored and copybara-github committed
feat: GenAI Client(evals) - Add metrics to create_evaluation_run method in Vertex AI GenAI SDK evals
PiperOrigin-RevId: 822249192
1 parent b05e5b3 commit ce35218

File tree

4 files changed: +440 -135 lines

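This change adds a metrics parameter to client.evals.create_evaluation_run() and a helper that resolves prebuilt RubricMetric entries into EvaluationRunMetric objects. The sketch below is adapted from the commented-out replay test in this commit and shows the intended call shape; the client construction, import paths, project, evaluation set, and bucket values are assumptions or placeholders, not working resources. Note that, per the TODO in this commit, the unified-metrics path is not yet enabled in production, which is why the replay test itself is disabled.

# Sketch adapted from the commented-out replay test in this commit.
# Import paths follow the module layout in this repo; project, location,
# evaluation set, and bucket names below are placeholders.
import vertexai
from vertexai._genai import types

client = vertexai.Client(project="my-project", location="us-central1")

# A unified metric wrapping a predefined (prebuilt) metric spec.
universal_ar_metric = types.EvaluationRunMetric(
    metric="universal_ar_v1",
    metric_config=types.UnifiedMetric(
        predefined_metric_spec=types.PredefinedMetricSpec(
            metric_spec_name="universal_ar_v1",
        )
    ),
)

# A custom LLM-judged metric with an inline prompt template.
llm_metric = types.EvaluationRunMetric(
    metric="llm_metric",
    metric_config=types.UnifiedMetric(
        llm_based_metric_spec=types.LLMBasedMetricSpec(
            metric_prompt_template=(
                "Evaluate the fluency of the response. Provide a score from 1-5."
            )
        )
    ),
)

evaluation_run = client.evals.create_evaluation_run(
    name="my-eval-run",
    display_name="my-eval-run",
    dataset=types.EvaluationRunDataSource(
        evaluation_set="projects/PROJECT/locations/us-central1/evaluationSets/SET_ID"
    ),
    dest="gs://my-bucket/eval_run_output",
    metrics=[
        universal_ar_metric,
        types.RubricMetric.FINAL_RESPONSE_QUALITY,  # prebuilt metric, resolved by the SDK
        llm_metric,
    ],
)

Prebuilt entries such as types.RubricMetric.FINAL_RESPONSE_QUALITY are resolved by the new _resolve_evaluation_run_metrics helper into EvaluationRunMetric objects with a PredefinedMetricSpec before the run is created.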

tests/unit/vertexai/genai/replays/test_create_evaluation_run.py

Lines changed: 101 additions & 47 deletions
@@ -19,53 +19,96 @@
 from google.genai import types as genai_types
 import pytest
 
-
-def test_create_eval_run_data_source_evaluation_set(client):
-    """Tests that create_evaluation_run() creates a correctly structured EvaluationRun."""
-    client._api_client._http_options.api_version = "v1beta1"
-    tool = genai_types.Tool(
-        function_declarations=[
-            genai_types.FunctionDeclaration(
-                name="get_weather",
-                description="Get weather in a location",
-                parameters={
-                    "type": "object",
-                    "properties": {"location": {"type": "string"}},
-                },
+GCS_DEST = "gs://lakeyk-test-limited/eval_run_output"
+UNIVERSAL_AR_METRIC = types.EvaluationRunMetric(
+    metric="universal_ar_v1",
+    metric_config=types.UnifiedMetric(
+        predefined_metric_spec=types.PredefinedMetricSpec(
+            metric_spec_name="universal_ar_v1",
+        )
+    ),
+)
+FINAL_RESPONSE_QUALITY_METRIC = types.EvaluationRunMetric(
+    metric="final_response_quality_v1",
+    metric_config=types.UnifiedMetric(
+        predefined_metric_spec=types.PredefinedMetricSpec(
+            metric_spec_name="final_response_quality_v1",
+        )
+    ),
+)
+LLM_METRIC = types.EvaluationRunMetric(
+    metric="llm_metric",
+    metric_config=types.UnifiedMetric(
+        llm_based_metric_spec=types.LLMBasedMetricSpec(
+            metric_prompt_template=(
+                "\nEvaluate the fluency of the response. Provide a score from 1-5."
             )
-        ]
-    )
-    evaluation_run = client.evals.create_evaluation_run(
-        name="test4",
-        display_name="test4",
-        dataset=types.EvaluationRunDataSource(
-            evaluation_set="projects/503583131166/locations/us-central1/evaluationSets/6619939608513740800"
-        ),
-        agent_info=types.AgentInfo(
-            name="agent-1",
-            instruction="agent-1 instruction",
-            tool_declarations=[tool],
-        ),
-        dest="gs://lakeyk-limited-bucket/eval_run_output",
-    )
-    assert isinstance(evaluation_run, types.EvaluationRun)
-    assert evaluation_run.display_name == "test4"
-    assert evaluation_run.state == types.EvaluationRunState.PENDING
-    assert isinstance(evaluation_run.data_source, types.EvaluationRunDataSource)
-    assert evaluation_run.data_source.evaluation_set == (
-        "projects/503583131166/locations/us-central1/evaluationSets/6619939608513740800"
-    )
-    assert evaluation_run.inference_configs[
-        "agent-1"
-    ] == types.EvaluationRunInferenceConfig(
-        agent_config=types.EvaluationRunAgentConfig(
-            developer_instruction=genai_types.Content(
-                parts=[genai_types.Part(text="agent-1 instruction")]
-            ),
-            tools=[tool],
         )
-    )
-    assert evaluation_run.error is None
+    ),
+)
+
+
+# TODO(b/431231205): Re-enable once Unified Metrics are in prod.
+# def test_create_eval_run_data_source_evaluation_set(client):
+#     """Tests that create_evaluation_run() creates a correctly structured EvaluationRun."""
+#     client._api_client._http_options.base_url = (
+#         "https://us-central1-autopush-aiplatform.sandbox.googleapis.com/"
+#     )
+#     client._api_client._http_options.api_version = "v1beta1"
+#     tool = genai_types.Tool(
+#         function_declarations=[
+#             genai_types.FunctionDeclaration(
+#                 name="get_weather",
+#                 description="Get weather in a location",
+#                 parameters={
+#                     "type": "object",
+#                     "properties": {"location": {"type": "string"}},
+#                 },
+#             )
+#         ]
+#     )
+#     evaluation_run = client.evals.create_evaluation_run(
+#         name="test4",
+#         display_name="test4",
+#         dataset=types.EvaluationRunDataSource(
+#             evaluation_set="projects/503583131166/locations/us-central1/evaluationSets/6619939608513740800"
+#         ),
+#         dest=GCS_DEST,
+#         metrics=[
+#             UNIVERSAL_AR_METRIC,
+#             types.RubricMetric.FINAL_RESPONSE_QUALITY,
+#             LLM_METRIC
+#         ],
+#         agent_info=types.AgentInfo(
+#             name="agent-1",
+#             instruction="agent-1 instruction",
+#             tool_declarations=[tool],
+#         ),
+#     )
+#     assert isinstance(evaluation_run, types.EvaluationRun)
+#     assert evaluation_run.display_name == "test4"
+#     assert evaluation_run.state == types.EvaluationRunState.PENDING
+#     assert isinstance(evaluation_run.data_source, types.EvaluationRunDataSource)
+#     assert evaluation_run.data_source.evaluation_set == (
+#         "projects/503583131166/locations/us-central1/evaluationSets/6619939608513740800"
+#     )
+#     assert evaluation_run.evaluation_config == types.EvaluationRunConfig(
+#         output_config=genai_types.OutputConfig(
+#             gcs_destination=genai_types.GcsDestination(output_uri_prefix=GCS_DEST)
+#         ),
+#         metrics=[UNIVERSAL_AR_METRIC, FINAL_RESPONSE_QUALITY_METRIC, LLM_METRIC],
+#     )
+#     assert evaluation_run.inference_configs[
+#         "agent-1"
+#     ] == types.EvaluationRunInferenceConfig(
+#         agent_config=types.EvaluationRunAgentConfig(
+#             developer_instruction=genai_types.Content(
+#                 parts=[genai_types.Part(text="agent-1 instruction")]
+#             ),
+#             tools=[tool],
+#         )
+#     )
+#     assert evaluation_run.error is None
 
 
 def test_create_eval_run_data_source_bigquery_request_set(client):
@@ -84,7 +127,7 @@ def test_create_eval_run_data_source_bigquery_request_set(client):
                 },
             )
         ),
-        dest="gs://lakeyk-limited-bucket/eval_run_output",
+        dest=GCS_DEST,
     )
     assert isinstance(evaluation_run, types.EvaluationRun)
     assert evaluation_run.display_name == "test5"
@@ -101,6 +144,11 @@ def test_create_eval_run_data_source_bigquery_request_set(client):
             },
         )
     )
+    assert evaluation_run.evaluation_config == types.EvaluationRunConfig(
+        output_config=genai_types.OutputConfig(
+            gcs_destination=genai_types.GcsDestination(output_uri_prefix=GCS_DEST)
+        ),
+    )
     assert evaluation_run.inference_configs is None
     assert evaluation_run.error is None

@@ -220,7 +268,7 @@ async def test_create_eval_run_async(client):
                 },
             )
         ),
-        dest="gs://lakeyk-limited-bucket/eval_run_output",
+        dest=GCS_DEST,
     )
     assert isinstance(evaluation_run, types.EvaluationRun)
     assert evaluation_run.display_name == "test8"
@@ -233,6 +281,12 @@ async def test_create_eval_run_async(client):
             "checkpoint_2": "checkpoint_2",
         },
     )
+    assert evaluation_run.evaluation_config == types.EvaluationRunConfig(
+        output_config=genai_types.OutputConfig(
+            gcs_destination=genai_types.GcsDestination(output_uri_prefix=GCS_DEST)
+        ),
+    )
+    assert evaluation_run.error is None
     assert evaluation_run.inference_configs is None
     assert evaluation_run.error is None

vertexai/_genai/_evals_common.py

Lines changed: 67 additions & 0 deletions
@@ -933,6 +933,73 @@ def _resolve_dataset_inputs(
     return processed_eval_dataset, num_response_candidates
 
 
+def _resolve_evaluation_run_metrics(
+    metrics: list[types.EvaluationRunMetric], api_client: Any
+) -> list[types.EvaluationRunMetric]:
+    """Resolves a list of evaluation run metric instances, loading RubricMetric if necessary."""
+    if not metrics:
+        return []
+    resolved_metrics_list = []
+    for metric_instance in metrics:
+        if isinstance(metric_instance, types.EvaluationRunMetric):
+            resolved_metrics_list.append(metric_instance)
+        elif isinstance(metric_instance, _evals_utils.LazyLoadedPrebuiltMetric):
+            try:
+                resolved_metric = metric_instance.resolve(api_client=api_client)
+                if resolved_metric.name:
+                    resolved_metrics_list.append(
+                        types.EvaluationRunMetric(
+                            metric=resolved_metric.name,
+                            metric_config=types.UnifiedMetric(
+                                predefined_metric_spec=types.PredefinedMetricSpec(
+                                    metric_spec_name=resolved_metric.name,
+                                )
+                            ),
+                        )
+                    )
+            except Exception as e:
+                logger.error(
+                    "Failed to resolve RubricMetric %s@%s: %s",
+                    metric_instance.name,
+                    metric_instance.version,
+                    e,
+                )
+                raise
+        else:
+            try:
+                metric_name_str = str(metric_instance)
+                lazy_metric_instance = getattr(
+                    _evals_utils.RubricMetric, metric_name_str.upper()
+                )
+                if isinstance(
+                    lazy_metric_instance, _evals_utils.LazyLoadedPrebuiltMetric
+                ):
+                    resolved_metric = lazy_metric_instance.resolve(
+                        api_client=api_client
+                    )
+                    if resolved_metric.name:
+                        resolved_metrics_list.append(
+                            types.EvaluationRunMetric(
+                                metric=resolved_metric.name,
+                                metric_config=types.UnifiedMetric(
+                                    predefined_metric_spec=types.PredefinedMetricSpec(
+                                        metric_spec_name=resolved_metric.name,
+                                    )
+                                ),
+                            )
+                        )
+                else:
+                    raise TypeError(
+                        f"RubricMetric.{metric_name_str.upper()} cannot be resolved."
+                    )
+            except AttributeError as exc:
+                raise TypeError(
+                    "Unsupported metric type or invalid RubricMetric name:"
+                    f" {metric_instance}"
+                ) from exc
+    return resolved_metrics_list
+
+
 def _resolve_metrics(
     metrics: list[types.Metric], api_client: Any
 ) -> list[types.Metric]:
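
For reference, a rough illustration of the three input shapes the new resolver accepts. This calls a private helper directly, which only the SDK itself does in practice; the client object and the bare-string form are assumptions for illustration, not part of this commit.

# Illustration only: the SDK calls this helper internally from
# create_evaluation_run(). `client` stands in for an existing vertexai Client.
from vertexai._genai import _evals_common, types

explicit = types.EvaluationRunMetric(
    metric="universal_ar_v1",
    metric_config=types.UnifiedMetric(
        predefined_metric_spec=types.PredefinedMetricSpec(
            metric_spec_name="universal_ar_v1",
        )
    ),
)

resolved = _evals_common._resolve_evaluation_run_metrics(
    [
        explicit,  # EvaluationRunMetric instances pass through unchanged
        types.RubricMetric.FINAL_RESPONSE_QUALITY,  # lazy prebuilt metric, resolved via the API
        "final_response_quality",  # any other value is upper-cased and looked up on RubricMetric
    ],
    client._api_client,
)
# Every resolved entry is an EvaluationRunMetric; prebuilt metrics come back
# wrapped in a UnifiedMetric carrying a PredefinedMetricSpec named after the metric.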

vertexai/_genai/evals.py

Lines changed: 25 additions & 6 deletions
@@ -230,6 +230,9 @@ def _EvaluationRun_from_vertex(
         getv(from_object, ["evaluationResults"]),
     )
 
+    if getv(from_object, ["evaluationConfig"]) is not None:
+        setv(to_object, ["evaluation_config"], getv(from_object, ["evaluationConfig"]))
+
     if getv(from_object, ["inferenceConfigs"]) is not None:
         setv(to_object, ["inference_configs"], getv(from_object, ["inferenceConfigs"]))

@@ -460,7 +463,7 @@ def _create_evaluation_run(
         name: Optional[str] = None,
         display_name: Optional[str] = None,
         data_source: types.EvaluationRunDataSourceOrDict,
-        evaluation_config: genai_types.EvaluationConfigOrDict,
+        evaluation_config: types.EvaluationRunConfigOrDict,
         config: Optional[types.CreateEvaluationRunConfigOrDict] = None,
         inference_configs: Optional[
             dict[str, types.EvaluationRunInferenceConfigOrDict]
@@ -1306,9 +1309,12 @@ def create_evaluation_run(
         self,
         *,
         name: str,
-        display_name: Optional[str] = None,
         dataset: Union[types.EvaluationRunDataSource, types.EvaluationDataset],
         dest: str,
+        display_name: Optional[str] = None,
+        metrics: Optional[
+            list[types.EvaluationRunMetricOrDict]
+        ] = None,  # TODO: Make required unified metrics available in prod.
         agent_info: Optional[types.AgentInfo] = None,
         config: Optional[types.CreateEvaluationRunConfigOrDict] = None,
     ) -> types.EvaluationRun:
@@ -1328,7 +1334,12 @@ def create_evaluation_run(
         output_config = genai_types.OutputConfig(
             gcs_destination=genai_types.GcsDestination(output_uri_prefix=dest)
         )
-        evaluation_config = genai_types.EvaluationConfig(output_config=output_config)
+        resolved_metrics = _evals_common._resolve_evaluation_run_metrics(
+            metrics, self._api_client
+        )
+        evaluation_config = types.EvaluationRunConfig(
+            output_config=output_config, metrics=resolved_metrics
+        )
         inference_configs = {}
         if agent_info:
             logger.warning(
@@ -1554,7 +1565,7 @@ async def _create_evaluation_run(
         name: Optional[str] = None,
         display_name: Optional[str] = None,
         data_source: types.EvaluationRunDataSourceOrDict,
-        evaluation_config: genai_types.EvaluationConfigOrDict,
+        evaluation_config: types.EvaluationRunConfigOrDict,
         config: Optional[types.CreateEvaluationRunConfigOrDict] = None,
         inference_configs: Optional[
             dict[str, types.EvaluationRunInferenceConfigOrDict]
@@ -2103,9 +2114,12 @@ async def create_evaluation_run(
         self,
         *,
         name: str,
-        display_name: Optional[str] = None,
         dataset: Union[types.EvaluationRunDataSource, types.EvaluationDataset],
         dest: str,
+        display_name: Optional[str] = None,
+        metrics: Optional[
+            list[types.EvaluationRunMetricOrDict]
+        ] = None,  # TODO: Make required unified metrics available in prod.
         agent_info: Optional[types.AgentInfo] = None,
         config: Optional[types.CreateEvaluationRunConfigOrDict] = None,
     ) -> types.EvaluationRun:
@@ -2125,7 +2139,12 @@ async def create_evaluation_run(
         output_config = genai_types.OutputConfig(
             gcs_destination=genai_types.GcsDestination(output_uri_prefix=dest)
         )
-        evaluation_config = genai_types.EvaluationConfig(output_config=output_config)
+        resolved_metrics = _evals_common._resolve_evaluation_run_metrics(
+            metrics, self._api_client
+        )
+        evaluation_config = types.EvaluationRunConfig(
+            output_config=output_config, metrics=resolved_metrics
+        )
         inference_configs = {}
         if agent_info:
             logger.warning(
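
Taken together, both the sync and async paths now assemble the run's evaluation config from the dest URI plus the resolved metrics, and _EvaluationRun_from_vertex maps the API's evaluationConfig field back onto EvaluationRun.evaluation_config. Roughly, as a sketch reusing the types shown in this diff (dest and resolved_metrics stand in for the caller's values):

# What the SDK builds internally before delegating to _create_evaluation_run().
from vertexai._genai import types
from google.genai import types as genai_types

dest = "gs://my-bucket/eval_run_output"  # placeholder output prefix
# Output of _resolve_evaluation_run_metrics(...); empty here for brevity.
resolved_metrics: list[types.EvaluationRunMetric] = []

evaluation_config = types.EvaluationRunConfig(
    output_config=genai_types.OutputConfig(
        gcs_destination=genai_types.GcsDestination(output_uri_prefix=dest)
    ),
    metrics=resolved_metrics,
)

The replay tests above assert exactly this object on evaluation_run.evaluation_config once the run is created.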
