
Commit b6ec343

vertex-sdk-bot authored and copybara-github committed

feat: GenAI SDK client - add zero-shot prompt optimizer (streaming): an option to quickly improve or generate system instructions or a single prompt.

PiperOrigin-RevId: 788830431

1 parent 8cfd9ba · commit b6ec343

File tree

3 files changed: +836 -692 lines changed
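For orientation, here is a minimal sketch of how the new streaming surface is consumed, modeled on the replay test added below. The vertexai.Client construction and the project/location values are assumptions for illustration, not part of this commit:

    import vertexai

    # Hypothetical setup; replace with a real project and location.
    client = vertexai.Client(project="my-project", location="us-central1")

    # optimize_prompt streams OptimizeResponse chunks as the service produces them.
    for chunk in client.prompt_optimizer.optimize_prompt(
        prompt="Generate system instructions for analyzing medical articles"
    ):
        print(chunk)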
Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# pylint: disable=protected-access,bad-continuation,missing-function-docstring
+
+import logging
+
+from tests.unit.vertexai.genai.replays import pytest_helper
+
+# from vertexai._genai import types
+
+logger = logging.getLogger("vertexai_genai.promptoptimizer")
+logging.basicConfig(encoding="utf-8", level=logging.INFO, force=True)
+
+
+def test_optimize_prompt(client):
+    """Tests the optimize request parameters method."""
+
+    client._api_client._http_options.base_url = (
+        "https://us-central1-autopush-aiplatform.sandbox.googleapis.com"
+    )
+    test_prompt = "Generate system instructions for analyzing medical articles"
+    for chunk in client.prompt_optimizer.optimize_prompt(prompt=test_prompt):
+        logger.info("chunk: %s", chunk)
+    # logger.info("response: %s", response)
+    # assert isinstance(response, types.OptimizeResponse)
+
+
+pytestmark = pytest_helper.setup(
+    file=__file__,
+    globals_for_file=globals(),
+    test_method="prompt_optimizer.optimize_prompt",
+)
vertexai/_genai/prompt_optimizer.py

Lines changed: 88 additions & 31 deletions

@@ -19,7 +19,7 @@
 import json
 import logging
 import time
-from typing import Any, Optional, Union
+from typing import Any, AsyncIterator, Awaitable, Iterator, Optional, Union
 from urllib.parse import urlencode
 
 from google.genai import _api_module
@@ -39,6 +39,9 @@ def _OptimizeRequestParameters_to_vertex(
     parent_object: Optional[dict[str, Any]] = None,
 ) -> dict[str, Any]:
     to_object: dict[str, Any] = {}
+    if getv(from_object, ["content"]) is not None:
+        setv(to_object, ["content"], getv(from_object, ["content"]))
+
     if getv(from_object, ["config"]) is not None:
         setv(to_object, ["config"], getv(from_object, ["config"]))
 
@@ -229,6 +232,8 @@ def _OptimizeResponse_from_vertex(
     parent_object: Optional[dict[str, Any]] = None,
 ) -> dict[str, Any]:
     to_object: dict[str, Any] = {}
+    if getv(from_object, ["content"]) is not None:
+        setv(to_object, ["content"], getv(from_object, ["content"]))
 
     return to_object
 
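Both converters extend the generated getv/setv passthrough pattern to the new content field. For readers unfamiliar with those helpers, here is a simplified, hypothetical sketch of what path-based get/set functions like these do (the SDK's real implementations live in google.genai's internals):

    from typing import Any, Optional

    def getv(obj: Any, path: list[str]) -> Optional[Any]:
        # Walk nested dict keys; return None if any segment is missing.
        for key in path:
            if not isinstance(obj, dict) or key not in obj:
                return None
            obj = obj[key]
        return obj

    def setv(obj: dict[str, Any], path: list[str], value: Any) -> None:
        # Create intermediate dicts as needed, then assign the leaf value.
        for key in path[:-1]:
            obj = obj.setdefault(key, {})
        obj[path[-1]] = value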
@@ -383,12 +388,16 @@ def _CustomJob_from_vertex(
 class PromptOptimizer(_api_module.BaseModule):
     """Prompt Optimizer"""
 
-    def _optimize_dummy(
-        self, *, config: Optional[types.OptimizeConfigOrDict] = None
-    ) -> types.OptimizeResponse:
-        """Optimize multiple prompts."""
+    def _optimize_prompt(
+        self,
+        *,
+        content: Optional[types.ContentOrDict] = None,
+        config: Optional[types.OptimizeConfigOrDict] = None,
+    ) -> Iterator[types.OptimizeResponse]:
+        """Optimize a single prompt."""
 
         parameter_model = types._OptimizeRequestParameters(
+            content=content,
             config=config,
         )
 
@@ -399,9 +408,9 @@ def _optimize_dummy(
         request_dict = _OptimizeRequestParameters_to_vertex(parameter_model)
         request_url_dict = request_dict.get("_url")
         if request_url_dict:
-            path = ":optimize".format_map(request_url_dict)
+            path = "tuningJobs:optimizePrompt".format_map(request_url_dict)
         else:
-            path = ":optimize"
+            path = "tuningJobs:optimizePrompt"
 
         query_params = request_dict.get("_query")
         if query_params:
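Both branches now address the tuningJobs:optimizePrompt custom method rather than the placeholder :optimize path. Assuming the standard Vertex AI URL scheme (not shown in this diff), the resolved request would look roughly like:

    POST https://{location}-aiplatform.googleapis.com/{api_version}/projects/{project}/locations/{location}/tuningJobs:optimizePrompt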
@@ -419,19 +428,27 @@ def _optimize_dummy(
         request_dict = _common.convert_to_dict(request_dict)
         request_dict = _common.encode_unserializable_types(request_dict)
 
-        response = self._api_client.request("post", path, request_dict, http_options)
+        if config is not None and getattr(config, "should_return_http_response", None):
+            raise ValueError(
+                "Accessing the raw HTTP response is not supported in streaming"
+                " methods."
+            )
 
-        response_dict = "" if not response.body else json.loads(response.body)
+        for response in self._api_client.request_streamed(
+            "post", path, request_dict, http_options
+        ):
 
-        if self._api_client.vertexai:
-            response_dict = _OptimizeResponse_from_vertex(response_dict)
+            response_dict = "" if not response.body else json.loads(response.body)
 
-        return_value = types.OptimizeResponse._from_response(
-            response=response_dict, kwargs=parameter_model.model_dump()
-        )
+            if self._api_client.vertexai:
+                response_dict = _OptimizeResponse_from_vertex(response_dict)
 
-        self._api_client._verify_response(return_value)
-        return return_value
+            return_value = types.OptimizeResponse._from_response(
+                response=response_dict, kwargs=parameter_model.model_dump()
+            )
+
+            self._api_client._verify_response(return_value)
+            yield return_value
 
     def _create_custom_job_resource(
         self,
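One behavioral consequence worth noting: because _optimize_prompt is now a generator, the streaming request is not issued when the method is called, only when iteration begins. A tiny, SDK-independent illustration of that Python semantics point:

    def stream():
        print("request issued")  # runs only once iteration starts
        yield "chunk"

    it = stream()   # returns a generator; nothing printed yet
    next(it)        # prints "request issued" and yields "chunk"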
@@ -660,16 +677,45 @@ def optimize(
         job = self._wait_for_completion(job_id)
         return job
 
+    def optimize_prompt(
+        self, *, prompt: str, config: Optional[types.OptimizeConfig] = None
+    ) -> Iterator[types.OptimizeResponse]:
+        """Makes an API request to _optimize_prompt and yields the optimized prompt in chunks."""
+        if config is not None:
+            raise ValueError(
+                "Currently, config is not supported for a single prompt"
+                " optimization."
+            )
+
+        prompt = types.Content(parts=[types.Part(text=prompt)], role="user")
+        # response = self._optimize_prompt(content=prompt)
+        # logger.info(type(response))
+        # logger.info(response)
+
+        # for chunk in response:
+        #     yield chunk
+        for chunk in self._optimize_prompt(content=prompt):
+            # logger.info(chunk)
+            # if chunk.content and chunk.content.parts[0].text:
+            #     logger.info('chunk has content text %s', chunk.content.parts[0].text)
+            # if chunk.parts[0]:
+            #     logger.info('chunk has parts %s', chunk.parts[0])
+            yield chunk
+
 
 class AsyncPromptOptimizer(_api_module.BaseModule):
     """Prompt Optimizer"""
 
-    async def _optimize_dummy(
-        self, *, config: Optional[types.OptimizeConfigOrDict] = None
-    ) -> types.OptimizeResponse:
-        """Optimize multiple prompts."""
+    async def _optimize_prompt(
+        self,
+        *,
+        content: Optional[types.ContentOrDict] = None,
+        config: Optional[types.OptimizeConfigOrDict] = None,
+    ) -> Awaitable[AsyncIterator[types.OptimizeResponse]]:
+        """Optimize a single prompt."""
 
         parameter_model = types._OptimizeRequestParameters(
+            content=content,
             config=config,
         )
 
@@ -680,9 +726,9 @@ async def _optimize_dummy(
         request_dict = _OptimizeRequestParameters_to_vertex(parameter_model)
         request_url_dict = request_dict.get("_url")
         if request_url_dict:
-            path = ":optimize".format_map(request_url_dict)
+            path = "tuningJobs:optimizePrompt".format_map(request_url_dict)
         else:
-            path = ":optimize"
+            path = "tuningJobs:optimizePrompt"
 
         query_params = request_dict.get("_query")
         if query_params:
@@ -700,21 +746,32 @@ async def _optimize_dummy(
         request_dict = _common.convert_to_dict(request_dict)
         request_dict = _common.encode_unserializable_types(request_dict)
 
-        response = await self._api_client.async_request(
+        if config is not None and getattr(config, "should_return_http_response", None):
+            raise ValueError(
+                "Accessing the raw HTTP response is not supported in streaming"
+                " methods."
+            )
+
+        response_stream = await self._api_client.async_request_streamed(
             "post", path, request_dict, http_options
         )
 
-        response_dict = "" if not response.body else json.loads(response.body)
+        async def async_generator():  # type: ignore[no-untyped-def]
+            async for response in response_stream:
 
-        if self._api_client.vertexai:
-            response_dict = _OptimizeResponse_from_vertex(response_dict)
+                response_dict = "" if not response.body else json.loads(response.body)
 
-        return_value = types.OptimizeResponse._from_response(
-            response=response_dict, kwargs=parameter_model.model_dump()
-        )
+                if self._api_client.vertexai:
+                    response_dict = _OptimizeResponse_from_vertex(response_dict)
 
-        self._api_client._verify_response(return_value)
-        return return_value
+                return_value = types.OptimizeResponse._from_response(
+                    response=response_dict, kwargs=parameter_model.model_dump()
+                )
+
+                self._api_client._verify_response(return_value)
+                yield return_value
+
+        return async_generator()  # type: ignore[no-untyped-call, no-any-return]
 
     async def _create_custom_job_resource(
         self,
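Note the async variant returns Awaitable[AsyncIterator[...]]: the coroutine must first be awaited to obtain the iterator, which is then consumed with async for. This commit adds no public async optimize_prompt, so the following consumption sketch against the private method (and the client.aio access path, which mirrors the sync/async split used elsewhere in the SDK) is illustrative only:

    import asyncio

    from vertexai._genai import types

    async def main() -> None:
        # `client` as constructed in the earlier sketch.
        content = types.Content(parts=[types.Part(text="Improve this prompt")], role="user")
        stream = await client.aio.prompt_optimizer._optimize_prompt(content=content)
        async for chunk in stream:
            print(chunk)

    asyncio.run(main())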