Skip to content

Commit 43495ee

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: GenAI SDK client - add zero-shot prompt optimizer (streaming): an option to quickly improve or generate system instructions or a single prompt.
PiperOrigin-RevId: 788830431
1 parent 8cfd9ba commit 43495ee

File tree

3 files changed

+845
-695
lines changed

3 files changed

+845
-695
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
17+
import logging
18+
19+
from tests.unit.vertexai.genai.replays import pytest_helper
20+
21+
# from vertexai._genai import types
22+
23+
# Module-level logger for the prompt-optimizer replay tests.
# basicConfig(force=True) re-installs the root handler at import time so the
# streamed-chunk INFO output is visible even if logging was configured earlier.
logger = logging.getLogger("vertexai_genai.promptoptimizer")
logging.basicConfig(level=logging.INFO, encoding="utf-8", force=True)
27+
def test_optimize_prompt(client):
    """Streams optimized-prompt chunks for a single prompt via the sandbox endpoint.

    Args:
        client: Replay-aware GenAI client fixture supplied by pytest_helper.
    """
    # Point the client at the autopush sandbox where the streaming
    # tuningJobs:optimizePrompt endpoint is deployed.
    client._api_client._http_options.base_url = (
        "https://us-central1-autopush-aiplatform.sandbox.googleapis.com"
    )
    test_prompt = "Generate system instructions for analyzing medical articles"
    # The method streams; consume every chunk so the replay records the full
    # response sequence. Log (lazy %-args) instead of print-ing.
    for chunk in client.prompt_optimizer.optimize_prompt(prompt=test_prompt):
        logger.info("chunk: %s", chunk)
# Route every test in this module through the replay harness, binding the
# replayed API call to prompt_optimizer.optimize_prompt.
pytestmark = pytest_helper.setup(
    test_method="prompt_optimizer.optimize_prompt",
    file=__file__,
    globals_for_file=globals(),
)

vertexai/_genai/prompt_optimizer.py

Lines changed: 97 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import json
2020
import logging
2121
import time
22-
from typing import Any, Optional, Union
22+
from typing import Any, AsyncIterator, Awaitable, Iterator, Optional, Union
2323
from urllib.parse import urlencode
2424

2525
from google.genai import _api_module
@@ -32,13 +32,17 @@
3232

3333

3434
logger = logging.getLogger("vertexai_genai.promptoptimizer")
35+
logging.basicConfig(encoding="utf-8", level=logging.INFO, force=True)
3536

3637

3738
def _OptimizeRequestParameters_to_vertex(
3839
from_object: Union[dict[str, Any], object],
3940
parent_object: Optional[dict[str, Any]] = None,
4041
) -> dict[str, Any]:
4142
to_object: dict[str, Any] = {}
43+
if getv(from_object, ["content"]) is not None:
44+
setv(to_object, ["content"], getv(from_object, ["content"]))
45+
4246
if getv(from_object, ["config"]) is not None:
4347
setv(to_object, ["config"], getv(from_object, ["config"]))
4448

@@ -229,6 +233,8 @@ def _OptimizeResponse_from_vertex(
229233
parent_object: Optional[dict[str, Any]] = None,
230234
) -> dict[str, Any]:
231235
to_object: dict[str, Any] = {}
236+
if getv(from_object, ["content"]) is not None:
237+
setv(to_object, ["content"], getv(from_object, ["content"]))
232238

233239
return to_object
234240

@@ -383,12 +389,16 @@ def _CustomJob_from_vertex(
383389
class PromptOptimizer(_api_module.BaseModule):
384390
"""Prompt Optimizer"""
385391

386-
def _optimize_dummy(
387-
self, *, config: Optional[types.OptimizeConfigOrDict] = None
388-
) -> types.OptimizeResponse:
389-
"""Optimize multiple prompts."""
392+
def _optimize_prompt(
393+
self,
394+
*,
395+
content: Optional[types.ContentOrDict] = None,
396+
config: Optional[types.OptimizeConfigOrDict] = None,
397+
) -> Iterator[types.OptimizeResponse]:
398+
"""Optimize a single prompt."""
390399

391400
parameter_model = types._OptimizeRequestParameters(
401+
content=content,
392402
config=config,
393403
)
394404

@@ -399,9 +409,9 @@ def _optimize_dummy(
399409
request_dict = _OptimizeRequestParameters_to_vertex(parameter_model)
400410
request_url_dict = request_dict.get("_url")
401411
if request_url_dict:
402-
path = ":optimize".format_map(request_url_dict)
412+
path = "tuningJobs:optimizePrompt".format_map(request_url_dict)
403413
else:
404-
path = ":optimize"
414+
path = "tuningJobs:optimizePrompt"
405415

406416
query_params = request_dict.get("_query")
407417
if query_params:
@@ -419,19 +429,32 @@ def _optimize_dummy(
419429
request_dict = _common.convert_to_dict(request_dict)
420430
request_dict = _common.encode_unserializable_types(request_dict)
421431

422-
response = self._api_client.request("post", path, request_dict, http_options)
423-
424-
response_dict = "" if not response.body else json.loads(response.body)
425-
426-
if self._api_client.vertexai:
427-
response_dict = _OptimizeResponse_from_vertex(response_dict)
428-
429-
return_value = types.OptimizeResponse._from_response(
430-
response=response_dict, kwargs=parameter_model.model_dump()
431-
)
432+
if config is not None and getattr(config, "should_return_http_response", None):
433+
raise ValueError(
434+
"Accessing the raw HTTP response is not supported in streaming"
435+
" methods."
436+
)
432437

433-
self._api_client._verify_response(return_value)
434-
return return_value
438+
for response in self._api_client.request_streamed(
439+
"post", path, request_dict, http_options
440+
):
441+
# print("response: %s" % response)
442+
logger.info("response: %s", response)
443+
response_dict = "" if not response.body else json.loads(response.body)
444+
# print("response_dict: %s" % response_dict)
445+
logger.info("response_dict: %s", response_dict)
446+
if self._api_client.vertexai:
447+
response_dict = _OptimizeResponse_from_vertex(response_dict)
448+
# print("response_dict vertexai: %s" % response_dict)
449+
logger.info("response_dict vertexai: %s", response_dict)
450+
451+
return_value = types.OptimizeResponse._from_response(
452+
response=response_dict, kwargs=parameter_model.model_dump()
453+
)
454+
# print("return_value: %s" % return_value)
455+
logger.info("return_value: %s", return_value)
456+
self._api_client._verify_response(return_value)
457+
yield return_value
435458

436459
def _create_custom_job_resource(
437460
self,
@@ -660,16 +683,45 @@ def optimize(
660683
job = self._wait_for_completion(job_id)
661684
return job
662685

686+
def optimize_prompt(
    self, *, prompt: str, config: Optional[types.OptimizeConfig] = None
) -> Iterator[types.OptimizeResponse]:
    """Optimizes a single prompt, yielding the improved prompt in chunks.

    Wraps the prompt text in a user `Content` and streams the response
    from the internal `_optimize_prompt` method.

    Args:
        prompt: The prompt text to optimize.
        config: Currently unsupported for single-prompt optimization;
            must be None.

    Yields:
        types.OptimizeResponse chunks streamed from the service.

    Raises:
        ValueError: If a config is provided.
    """
    if config is not None:
        raise ValueError(
            "Currently, config is not supported for a single prompt"
            " optimization."
        )

    # Use a distinct local instead of rebinding `prompt` (str) to a Content.
    content = types.Content(parts=[types.Part(text=prompt)], role="user")
    # Delegate streaming to the internal method; yield chunks as they arrive.
    yield from self._optimize_prompt(content=content)
710+
663711

664712
class AsyncPromptOptimizer(_api_module.BaseModule):
665713
"""Prompt Optimizer"""
666714

667-
async def _optimize_dummy(
668-
self, *, config: Optional[types.OptimizeConfigOrDict] = None
669-
) -> types.OptimizeResponse:
670-
"""Optimize multiple prompts."""
715+
async def _optimize_prompt(
716+
self,
717+
*,
718+
content: Optional[types.ContentOrDict] = None,
719+
config: Optional[types.OptimizeConfigOrDict] = None,
720+
) -> Awaitable[AsyncIterator[types.OptimizeResponse]]:
721+
"""Optimize a single prompt."""
671722

672723
parameter_model = types._OptimizeRequestParameters(
724+
content=content,
673725
config=config,
674726
)
675727

@@ -680,9 +732,9 @@ async def _optimize_dummy(
680732
request_dict = _OptimizeRequestParameters_to_vertex(parameter_model)
681733
request_url_dict = request_dict.get("_url")
682734
if request_url_dict:
683-
path = ":optimize".format_map(request_url_dict)
735+
path = "tuningJobs:optimizePrompt".format_map(request_url_dict)
684736
else:
685-
path = ":optimize"
737+
path = "tuningJobs:optimizePrompt"
686738

687739
query_params = request_dict.get("_query")
688740
if query_params:
@@ -700,21 +752,32 @@ async def _optimize_dummy(
700752
request_dict = _common.convert_to_dict(request_dict)
701753
request_dict = _common.encode_unserializable_types(request_dict)
702754

703-
response = await self._api_client.async_request(
755+
if config is not None and getattr(config, "should_return_http_response", None):
756+
raise ValueError(
757+
"Accessing the raw HTTP response is not supported in streaming"
758+
" methods."
759+
)
760+
761+
response_stream = await self._api_client.async_request_streamed(
704762
"post", path, request_dict, http_options
705763
)
706764

707-
response_dict = "" if not response.body else json.loads(response.body)
765+
async def async_generator(): # type: ignore[no-untyped-def]
766+
async for response in response_stream:
708767

709-
if self._api_client.vertexai:
710-
response_dict = _OptimizeResponse_from_vertex(response_dict)
768+
response_dict = "" if not response.body else json.loads(response.body)
711769

712-
return_value = types.OptimizeResponse._from_response(
713-
response=response_dict, kwargs=parameter_model.model_dump()
714-
)
770+
if self._api_client.vertexai:
771+
response_dict = _OptimizeResponse_from_vertex(response_dict)
715772

716-
self._api_client._verify_response(return_value)
717-
return return_value
773+
return_value = types.OptimizeResponse._from_response(
774+
response=response_dict, kwargs=parameter_model.model_dump()
775+
)
776+
777+
self._api_client._verify_response(return_value)
778+
yield return_value
779+
780+
return async_generator() # type: ignore[no-untyped-call, no-any-return]
718781

719782
async def _create_custom_job_resource(
720783
self,

0 commit comments

Comments (0)