19
19
import json
20
20
import logging
21
21
import time
22
- from typing import Any , Optional , Union
22
+ from typing import Any , AsyncIterator , Awaitable , Iterator , Optional , Union
23
23
from urllib .parse import urlencode
24
24
25
25
from google .genai import _api_module
@@ -39,6 +39,9 @@ def _OptimizeRequestParameters_to_vertex(
39
39
parent_object : Optional [dict [str , Any ]] = None ,
40
40
) -> dict [str , Any ]:
41
41
to_object : dict [str , Any ] = {}
42
+ if getv (from_object , ["content" ]) is not None :
43
+ setv (to_object , ["content" ], getv (from_object , ["content" ]))
44
+
42
45
if getv (from_object , ["config" ]) is not None :
43
46
setv (to_object , ["config" ], getv (from_object , ["config" ]))
44
47
@@ -229,6 +232,8 @@ def _OptimizeResponse_from_vertex(
229
232
parent_object : Optional [dict [str , Any ]] = None ,
230
233
) -> dict [str , Any ]:
231
234
to_object : dict [str , Any ] = {}
235
+ if getv (from_object , ["content" ]) is not None :
236
+ setv (to_object , ["content" ], getv (from_object , ["content" ]))
232
237
233
238
return to_object
234
239
@@ -383,12 +388,16 @@ def _CustomJob_from_vertex(
383
388
class PromptOptimizer (_api_module .BaseModule ):
384
389
"""Prompt Optimizer"""
385
390
386
- def _optimize_dummy (
387
- self , * , config : Optional [types .OptimizeConfigOrDict ] = None
388
- ) -> types .OptimizeResponse :
389
- """Optimize multiple prompts."""
391
+ def _optimize_prompt (
392
+ self ,
393
+ * ,
394
+ content : Optional [types .ContentOrDict ] = None ,
395
+ config : Optional [types .OptimizeConfigOrDict ] = None ,
396
+ ) -> Iterator [types .OptimizeResponse ]:
397
+ """Optimize a single prompt."""
390
398
391
399
parameter_model = types ._OptimizeRequestParameters (
400
+ content = content ,
392
401
config = config ,
393
402
)
394
403
@@ -399,9 +408,9 @@ def _optimize_dummy(
399
408
request_dict = _OptimizeRequestParameters_to_vertex (parameter_model )
400
409
request_url_dict = request_dict .get ("_url" )
401
410
if request_url_dict :
402
- path = ":optimize " .format_map (request_url_dict )
411
+ path = "tuningJobs:optimizePrompt " .format_map (request_url_dict )
403
412
else :
404
- path = ":optimize "
413
+ path = "tuningJobs:optimizePrompt "
405
414
406
415
query_params = request_dict .get ("_query" )
407
416
if query_params :
@@ -419,19 +428,27 @@ def _optimize_dummy(
419
428
request_dict = _common .convert_to_dict (request_dict )
420
429
request_dict = _common .encode_unserializable_types (request_dict )
421
430
422
- response = self ._api_client .request ("post" , path , request_dict , http_options )
431
+ if config is not None and getattr (config , "should_return_http_response" , None ):
432
+ raise ValueError (
433
+ "Accessing the raw HTTP response is not supported in streaming"
434
+ " methods."
435
+ )
423
436
424
- response_dict = "" if not response .body else json .loads (response .body )
437
+ for response in self ._api_client .request_streamed (
438
+ "post" , path , request_dict , http_options
439
+ ):
425
440
426
- if self ._api_client .vertexai :
427
- response_dict = _OptimizeResponse_from_vertex (response_dict )
441
+ response_dict = "" if not response .body else json .loads (response .body )
428
442
429
- return_value = types .OptimizeResponse ._from_response (
430
- response = response_dict , kwargs = parameter_model .model_dump ()
431
- )
443
+ if self ._api_client .vertexai :
444
+ response_dict = _OptimizeResponse_from_vertex (response_dict )
432
445
433
- self ._api_client ._verify_response (return_value )
434
- return return_value
446
+ return_value = types .OptimizeResponse ._from_response (
447
+ response = response_dict , kwargs = parameter_model .model_dump ()
448
+ )
449
+
450
+ self ._api_client ._verify_response (return_value )
451
+ yield return_value
435
452
436
453
def _create_custom_job_resource (
437
454
self ,
@@ -660,16 +677,45 @@ def optimize(
660
677
job = self ._wait_for_completion (job_id )
661
678
return job
662
679
680
+ def optimize_prompt (
681
+ self , * , prompt : str , config : Optional [types .OptimizeConfig ] = None
682
+ ) -> Iterator [types .OptimizeResponse ]:
683
+ """Makes an API request to _optimize_prompt and yields the optimized prompt in chunks."""
684
+ if config is not None :
685
+ raise ValueError (
686
+ "Currently, config is not supported for a single prompt"
687
+ " optimization."
688
+ )
689
+
690
+ prompt = types .Content (parts = [types .Part (text = prompt )], role = "user" )
691
+ # response = self._optimize_prompt(content=prompt)
692
+ # logger.info(type(response))
693
+ # logger.info(response)
694
+
695
+ # for chunk in response:
696
+ # yield chunk
697
+ for chunk in self ._optimize_prompt (content = prompt ):
698
+ # logger.info(chunk)
699
+ # if chunk.content and chunk.content.parts[0].text:
700
+ # logger.info('chunk has content text %s', chunk.content.parts[0].text)
701
+ # if chunk.parts[0]:
702
+ # logger.info('chunk has parts %s', chunk.parts[0])
703
+ yield chunk
704
+
663
705
664
706
class AsyncPromptOptimizer (_api_module .BaseModule ):
665
707
"""Prompt Optimizer"""
666
708
667
- async def _optimize_dummy (
668
- self , * , config : Optional [types .OptimizeConfigOrDict ] = None
669
- ) -> types .OptimizeResponse :
670
- """Optimize multiple prompts."""
709
+ async def _optimize_prompt (
710
+ self ,
711
+ * ,
712
+ content : Optional [types .ContentOrDict ] = None ,
713
+ config : Optional [types .OptimizeConfigOrDict ] = None ,
714
+ ) -> Awaitable [AsyncIterator [types .OptimizeResponse ]]:
715
+ """Optimize a single prompt."""
671
716
672
717
parameter_model = types ._OptimizeRequestParameters (
718
+ content = content ,
673
719
config = config ,
674
720
)
675
721
@@ -680,9 +726,9 @@ async def _optimize_dummy(
680
726
request_dict = _OptimizeRequestParameters_to_vertex (parameter_model )
681
727
request_url_dict = request_dict .get ("_url" )
682
728
if request_url_dict :
683
- path = ":optimize " .format_map (request_url_dict )
729
+ path = "tuningJobs:optimizePrompt " .format_map (request_url_dict )
684
730
else :
685
- path = ":optimize "
731
+ path = "tuningJobs:optimizePrompt "
686
732
687
733
query_params = request_dict .get ("_query" )
688
734
if query_params :
@@ -700,21 +746,32 @@ async def _optimize_dummy(
700
746
request_dict = _common .convert_to_dict (request_dict )
701
747
request_dict = _common .encode_unserializable_types (request_dict )
702
748
703
- response = await self ._api_client .async_request (
749
+ if config is not None and getattr (config , "should_return_http_response" , None ):
750
+ raise ValueError (
751
+ "Accessing the raw HTTP response is not supported in streaming"
752
+ " methods."
753
+ )
754
+
755
+ response_stream = await self ._api_client .async_request_streamed (
704
756
"post" , path , request_dict , http_options
705
757
)
706
758
707
- response_dict = "" if not response .body else json .loads (response .body )
759
+ async def async_generator (): # type: ignore[no-untyped-def]
760
+ async for response in response_stream :
708
761
709
- if self ._api_client .vertexai :
710
- response_dict = _OptimizeResponse_from_vertex (response_dict )
762
+ response_dict = "" if not response .body else json .loads (response .body )
711
763
712
- return_value = types .OptimizeResponse ._from_response (
713
- response = response_dict , kwargs = parameter_model .model_dump ()
714
- )
764
+ if self ._api_client .vertexai :
765
+ response_dict = _OptimizeResponse_from_vertex (response_dict )
715
766
716
- self ._api_client ._verify_response (return_value )
717
- return return_value
767
+ return_value = types .OptimizeResponse ._from_response (
768
+ response = response_dict , kwargs = parameter_model .model_dump ()
769
+ )
770
+
771
+ self ._api_client ._verify_response (return_value )
772
+ yield return_value
773
+
774
+ return async_generator () # type: ignore[no-untyped-call, no-any-return]
718
775
719
776
async def _create_custom_job_resource (
720
777
self ,
0 commit comments