Skip to content

Commit a0cf97e

Browse files
google-genai-bot authored and copybara-github committed
feat: Some small infra fixes to the gepa demo colab
PiperOrigin-RevId: 829240716
1 parent d118479 commit a0cf97e

File tree

5 files changed

+76
-50
lines changed

5 files changed

+76
-50
lines changed

contributing/samples/gepa/experiment.py

Lines changed: 2 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,11 @@
2626
import random
2727
import traceback
2828
from typing import Any
29-
from typing import Callable
3029
from typing import TypedDict
3130

3231
import gepa
3332
from gepa.core.adapter import EvaluationBatch
3433
from gepa.core.adapter import GEPAAdapter
35-
from google.genai import types
3634
from litellm import provider_list
3735
import rater_lib
3836
from retry import retry
@@ -46,20 +44,7 @@
4644
from tau_bench.types import RunConfig
4745
import tau_bench_agent as tau_bench_agent_lib
4846

49-
from google import genai
50-
51-
52-
class FilterInferenceWarnings(logging.Filter):
53-
"""Filters out Vertex inference warning about non-text parts in response."""
54-
55-
def filter(self, record: logging.LogRecord) -> bool:
56-
"""Filters out Vertex inference warning about non-text parts in response."""
57-
if record.levelname != 'WARNING':
58-
return True
59-
message_identifier = record.getMessage()
60-
return not message_identifier.startswith(
61-
'Warning: there are non-text parts in the response:'
62-
)
47+
import utils
6348

6449

6550
def run_tau_bench_rollouts(
@@ -494,26 +479,6 @@ def _get_datasets(
494479
)
495480

496481

497-
def reflection_inference_fn(model: str) -> Callable[[str], str]:
498-
"""Returns an inference function on VertexAI based on provided model."""
499-
client = genai.Client()
500-
501-
@retry(tries=3, delay=10, backoff=2)
502-
def _fn(prompt):
503-
return client.models.generate_content(
504-
model=model,
505-
contents=prompt,
506-
config=types.GenerateContentConfig(
507-
candidate_count=1,
508-
thinking_config=types.ThinkingConfig(
509-
include_thoughts=True, thinking_budget=-1
510-
),
511-
),
512-
).text
513-
514-
return _fn
515-
516-
517482
SEED_SYSTEM_INSTRUCTION = (
518483
'you are a customer support agent helping customers resolve their '
519484
'issues by using the right tools'
@@ -618,7 +583,7 @@ def run_gepa(
618583
task_lm=None, # this must be None when a custom adapter is used
619584
adapter=tau_bench_adapter,
620585
max_metric_calls=config.max_metric_calls,
621-
reflection_lm=reflection_inference_fn(config.reflection_model),
586+
reflection_lm=utils.reflection_inference_fn(config.reflection_model),
622587
reflection_minibatch_size=config.reflection_minibatch_size,
623588
run_dir=output_dir,
624589
)

contributing/samples/gepa/gepa_tau_bench.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
"\n",
9999
"import experiment as experiment_lib\n",
100100
"from google.genai import types\n",
101+
"import utils\n",
101102
"\n",
102103
"\n",
103104
"# @markdown ### ☁️ Configure Vertex AI Access\n",
@@ -140,7 +141,7 @@
140141
"\n",
141142
"# Set a logging verbosity suited for this experiment. See\n",
142143
"# https://github.com/google/adk-python/issues/1852 for context\n",
143-
"types.logger.addFilter(experiment_lib.FilterInferenceWarnings())"
144+
"types.logger.addFilter(utils.FilterInferenceWarnings())"
144145
]
145146
},
146147
{

contributing/samples/gepa/run_experiment.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import experiment
2727
from google.genai import types
2828

29+
import utils
30+
2931
_OUTPUT_DIR = flags.DEFINE_string(
3032
'output_dir',
3133
None,
@@ -104,7 +106,7 @@ def main(argv: Sequence[str]) -> None:
104106
for logger in loggers:
105107
logger.setLevel(logging.WARNING)
106108

107-
types.logger.addFilter(experiment.FilterInferenceWarnings())
109+
types.logger.addFilter(utils.FilterInferenceWarnings())
108110
output_dir = os.path.join(
109111
_OUTPUT_DIR.value, datetime.now().strftime('%Y%m%d%H%M%S%f')
110112
)

contributing/samples/gepa/utils.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Defines utility for GEPA experiments."""
16+
17+
import logging
18+
from typing import Callable
19+
20+
from google.genai import types
21+
from retry import retry
22+
23+
from google import genai
24+
25+
26+
class FilterInferenceWarnings(logging.Filter):
  """Suppresses the Vertex inference warning about non-text response parts.

  Attach to `google.genai.types.logger` to hide the noisy warning emitted
  when a model response contains non-text parts (e.g. thought parts).
  """

  def filter(self, record: logging.LogRecord) -> bool:
    """Returns False only for the targeted Vertex warning, True otherwise."""
    # Anything that is not a WARNING passes through untouched.
    if record.levelname != 'WARNING':
      return True
    # Drop the record when it is the known non-text-parts warning.
    return not record.getMessage().startswith(
        'Warning: there are non-text parts in the response:'
    )
37+
38+
39+
def reflection_inference_fn(model: str) -> Callable[[str], str]:
  """Returns an inference function on VertexAI based on provided model.

  Args:
    model: Name of the model to query (e.g. a Gemini model id).

  Returns:
    A callable mapping a prompt string to the model's text response,
    retried up to 3 times with exponential backoff on failure.
  """
  client = genai.Client()
  # Built once and reused for every call: a single candidate, with
  # thinking enabled and an unbounded thinking budget (-1).
  config = types.GenerateContentConfig(
      candidate_count=1,
      thinking_config=types.ThinkingConfig(
          include_thoughts=True, thinking_budget=-1
      ),
  )

  @retry(tries=3, delay=10, backoff=2)
  def _infer(prompt):
    response = client.models.generate_content(
        model=model, contents=prompt, config=config
    )
    return response.text

  return _infer

contributing/samples/gepa/voter_agent/gepa.ipynb

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
"#@title Install GEPA\n",
6666
"!git clone https://github.com/google/adk-python.git\n",
6767
"!pip install gepa --quiet\n",
68+
"!pip install litellm --quiet\n",
6869
"!pip install retry --quiet"
6970
]
7071
},
@@ -112,7 +113,7 @@
112113
"import os\n",
113114
"\n",
114115
"from google.genai import types\n",
115-
"import experiment as experiment_lib\n",
116+
"import utils\n",
116117
"\n",
117118
"\n",
118119
"# @markdown ### ☁️ Configure Vertex AI Access\n",
@@ -139,7 +140,7 @@
139140
"for logger in loggers:\n",
140141
" logger.setLevel(logging.WARNING)\n",
141142
"\n",
142-
"types.logger.addFilter(experiment_lib.FilterInferenceWarnings())"
143+
"types.logger.addFilter(utils.FilterInferenceWarnings())"
143144
]
144145
},
145146
{
@@ -179,7 +180,7 @@
179180
"from google.adk.agents import base_agent\n",
180181
"from google.adk.agents import llm_agent\n",
181182
"\n",
182-
"import tools\n",
183+
"from voter_agent import tools\n",
183184
"\n",
184185
"\n",
185186
"# @markdown ### 🧠 Configure our ADK LLM Agent\n",
@@ -368,7 +369,10 @@
368369
" return [line.strip() for line in open(filename) if line.strip()]\n",
369370
"\n",
370371
"\n",
371-
"voter_data = _read_prompts('prompts.txt')\n",
372+
"_AGENT_DIR = 'adk-python/contributing/samples/gepa/voter_agent'\n",
373+
"\n",
374+
"\n",
375+
"voter_data = _read_prompts(f'{_AGENT_DIR}/prompts.txt')\n",
372376
"voter_data"
373377
]
374378
},
@@ -392,7 +396,8 @@
392396
"execution_count": null,
393397
"metadata": {
394398
"id": "9bHh93RuKVMu",
395-
"outputId": "489761d4-da39-43ca-cd08-225c44bb3027"
399+
"outputId": "489761d4-da39-43ca-cd08-225c44bb3027",
400+
"cellView": "form"
396401
},
397402
"outputs": [
398403
{
@@ -714,7 +719,7 @@
714719
" tool_declarations=TOOLS_DESCRIPTION,\n",
715720
" developer_instructions='',\n",
716721
" rubric=FILTER_RUBRIC,\n",
717-
"\n",
722+
" validation_template_path=f'{_AGENT_DIR}/rubric_validation_template.txt',\n",
718723
")\n",
719724
"\n",
720725
"print(rater(EXAMPLE_TRACE))"
@@ -813,7 +818,7 @@
813818
"source": [
814819
"#@title Let's define an evaluation dataset from sample prompts\n",
815820
"\n",
816-
"eval_dataset = _read_prompts('eval_prompts.txt')\n",
821+
"eval_dataset = _read_prompts(f'{_AGENT_DIR}/eval_prompts.txt')\n",
817822
"eval_dataset"
818823
]
819824
},
@@ -2723,7 +2728,7 @@
27232728
" task_lm=None, # this must be None when a custom adapter is used\n",
27242729
" adapter=adapter,\n",
27252730
" max_metric_calls=MAX_METRIC_CALLS,\n",
2726-
" reflection_lm=experiment_lib.reflection_inference_fn(REFLECTION_MODEL_NAME),\n",
2731+
" reflection_lm=utils.reflection_inference_fn(REFLECTION_MODEL_NAME),\n",
27272732
" reflection_minibatch_size=MINI_BATCH_SIZE,\n",
27282733
")\n",
27292734
"list(enumerate(gepa_results.val_aggregate_scores))"
@@ -2955,9 +2960,6 @@
29552960
],
29562961
"metadata": {
29572962
"colab": {
2958-
"collapsed_sections": [
2959-
"rIFFNqYoXp6v"
2960-
],
29612963
"last_runtime": {
29622964
"build_target": "//learning/language/tunelab/tunekit/colab:colab_notebook",
29632965
"kind": "private"

0 commit comments

Comments (0)