import json
import logging
import time
- from typing import Any, Optional, Union
+ from typing import Any, AsyncIterator, Awaitable, Iterator, Optional, Union
from urllib.parse import urlencode

from google.genai import _api_module
@@ -32,13 +32,17 @@


logger = logging.getLogger("vertexai_genai.promptoptimizer")
+ logging.basicConfig(encoding="utf-8", level=logging.INFO, force=True)


def _OptimizeRequestParameters_to_vertex(
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
  to_object: dict[str, Any] = {}
+   if getv(from_object, ["content"]) is not None:
+     setv(to_object, ["content"], getv(from_object, ["content"]))
+
  if getv(from_object, ["config"]) is not None:
    setv(to_object, ["config"], getv(from_object, ["config"]))

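Reviewer note: the converters above follow the module's `getv`/`setv` path-based pattern. For readers new to this generated style, here is a rough sketch of the semantics these helpers are assumed to have — an illustrative approximation, not the SDK's actual implementation:

```python
# Assumed semantics of the path-based helpers used by the converters.
# Illustrative only; the real getv/setv live in the SDK's common module.
from typing import Any, Optional, Union


def getv(obj: Union[dict[str, Any], object], path: list[str]) -> Optional[Any]:
  """Walks `path` through nested dicts/attributes; returns None when absent."""
  current: Any = obj
  for key in path:
    if current is None:
      return None
    if isinstance(current, dict):
      current = current.get(key)
    else:
      current = getattr(current, key, None)
  return current


def setv(obj: dict[str, Any], path: list[str], value: Any) -> None:
  """Creates intermediate dicts along `path` and stores `value` at the leaf."""
  for key in path[:-1]:
    obj = obj.setdefault(key, {})
  obj[path[-1]] = value
```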
@@ -229,6 +233,8 @@ def _OptimizeResponse_from_vertex(
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
  to_object: dict[str, Any] = {}
+   if getv(from_object, ["content"]) is not None:
+     setv(to_object, ["content"], getv(from_object, ["content"]))

  return to_object

@@ -383,12 +389,16 @@ def _CustomJob_from_vertex(
class PromptOptimizer(_api_module.BaseModule):
  """Prompt Optimizer"""

-   def _optimize_dummy(
-       self, *, config: Optional[types.OptimizeConfigOrDict] = None
-   ) -> types.OptimizeResponse:
-     """Optimize multiple prompts."""
+   def _optimize_prompt(
+       self,
+       *,
+       content: Optional[types.ContentOrDict] = None,
+       config: Optional[types.OptimizeConfigOrDict] = None,
+   ) -> Iterator[types.OptimizeResponse]:
+     """Optimize a single prompt."""

    parameter_model = types._OptimizeRequestParameters(
+         content=content,
        config=config,
    )

@@ -399,9 +409,9 @@ def _optimize_dummy(
    request_dict = _OptimizeRequestParameters_to_vertex(parameter_model)
    request_url_dict = request_dict.get("_url")
    if request_url_dict:
-       path = ":optimize".format_map(request_url_dict)
+       path = "tuningJobs:optimizePrompt".format_map(request_url_dict)
    else:
-       path = ":optimize"
+       path = "tuningJobs:optimizePrompt"

    query_params = request_dict.get("_query")
    if query_params:
@@ -419,19 +429,32 @@ def _optimize_dummy(
    request_dict = _common.convert_to_dict(request_dict)
    request_dict = _common.encode_unserializable_types(request_dict)

-     response = self._api_client.request("post", path, request_dict, http_options)
-
-     response_dict = "" if not response.body else json.loads(response.body)
-
-     if self._api_client.vertexai:
-       response_dict = _OptimizeResponse_from_vertex(response_dict)
-
-     return_value = types.OptimizeResponse._from_response(
-         response=response_dict, kwargs=parameter_model.model_dump()
-     )
+     if config is not None and getattr(config, "should_return_http_response", None):
+       raise ValueError(
+           "Accessing the raw HTTP response is not supported in streaming"
+           " methods."
+       )

-     self._api_client._verify_response(return_value)
-     return return_value
+     for response in self._api_client.request_streamed(
+         "post", path, request_dict, http_options
+     ):
+       logger.info("response: %s", response)
+       response_dict = "" if not response.body else json.loads(response.body)
+       logger.info("response_dict: %s", response_dict)
+       if self._api_client.vertexai:
+         response_dict = _OptimizeResponse_from_vertex(response_dict)
+         logger.info("response_dict vertexai: %s", response_dict)
+
+       return_value = types.OptimizeResponse._from_response(
+           response=response_dict, kwargs=parameter_model.model_dump()
+       )
+       logger.info("return_value: %s", return_value)
+       self._api_client._verify_response(return_value)
+       yield return_value

  def _create_custom_job_resource(
      self,
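Reviewer note: a minimal consumption sketch for the new streaming surface. The client wiring is an assumption (a configured `vertexai.Client` exposing this module as `client.prompt_optimizer`); only `optimize_prompt` itself comes from this diff:

```python
# Hypothetical usage; project/location values and the client attribute
# name are assumptions, not confirmed by this diff.
import vertexai

client = vertexai.Client(project="my-project", location="us-central1")

for chunk in client.prompt_optimizer.optimize_prompt(
    prompt="Tell me about the solar system."
):
  # Each chunk is a types.OptimizeResponse whose content arrives incrementally.
  if chunk.content and chunk.content.parts:
    print(chunk.content.parts[0].text, end="")
```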
@@ -660,16 +683,45 @@ def optimize(
    job = self._wait_for_completion(job_id)
    return job

+   def optimize_prompt(
+       self, *, prompt: str, config: Optional[types.OptimizeConfig] = None
+   ) -> Iterator[types.OptimizeResponse]:
+     """Makes an API request to _optimize_prompt and yields the optimized prompt in chunks."""
+     if config is not None:
+       raise ValueError(
+           "Currently, config is not supported for a single prompt"
+           " optimization."
+       )
+
+     content = types.Content(parts=[types.Part(text=prompt)], role="user")
+     for chunk in self._optimize_prompt(content=content):
+       yield chunk
+

class AsyncPromptOptimizer(_api_module.BaseModule):
  """Prompt Optimizer"""

-   async def _optimize_dummy(
-       self, *, config: Optional[types.OptimizeConfigOrDict] = None
-   ) -> types.OptimizeResponse:
-     """Optimize multiple prompts."""
+   async def _optimize_prompt(
+       self,
+       *,
+       content: Optional[types.ContentOrDict] = None,
+       config: Optional[types.OptimizeConfigOrDict] = None,
+   ) -> Awaitable[AsyncIterator[types.OptimizeResponse]]:
+     """Optimize a single prompt."""

    parameter_model = types._OptimizeRequestParameters(
+         content=content,
        config=config,
    )

@@ -680,9 +732,9 @@ async def _optimize_dummy(
    request_dict = _OptimizeRequestParameters_to_vertex(parameter_model)
    request_url_dict = request_dict.get("_url")
    if request_url_dict:
-       path = ":optimize".format_map(request_url_dict)
+       path = "tuningJobs:optimizePrompt".format_map(request_url_dict)
    else:
-       path = ":optimize"
+       path = "tuningJobs:optimizePrompt"

    query_params = request_dict.get("_query")
    if query_params:
@@ -700,21 +752,32 @@ async def _optimize_dummy(
    request_dict = _common.convert_to_dict(request_dict)
    request_dict = _common.encode_unserializable_types(request_dict)

-     response = await self._api_client.async_request(
+     if config is not None and getattr(config, "should_return_http_response", None):
+       raise ValueError(
+           "Accessing the raw HTTP response is not supported in streaming"
+           " methods."
+       )
+
+     response_stream = await self._api_client.async_request_streamed(
        "post", path, request_dict, http_options
    )

-     response_dict = "" if not response.body else json.loads(response.body)
+     async def async_generator():  # type: ignore[no-untyped-def]
+       async for response in response_stream:

-     if self._api_client.vertexai:
-       response_dict = _OptimizeResponse_from_vertex(response_dict)
+         response_dict = "" if not response.body else json.loads(response.body)

-     return_value = types.OptimizeResponse._from_response(
-         response=response_dict, kwargs=parameter_model.model_dump()
-     )
+         if self._api_client.vertexai:
+           response_dict = _OptimizeResponse_from_vertex(response_dict)

-     self._api_client._verify_response(return_value)
-     return return_value
+         return_value = types.OptimizeResponse._from_response(
+             response=response_dict, kwargs=parameter_model.model_dump()
+         )
+
+         self._api_client._verify_response(return_value)
+         yield return_value
+
+     return async_generator()  # type: ignore[no-untyped-call, no-any-return]

  async def _create_custom_job_resource(
      self,
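Reviewer note: the async variant is an `async def` that returns an async generator, so a caller awaits once to obtain the iterator and then consumes it with `async for`. A sketch under the same client-wiring assumptions as above; the diff adds no public async wrapper yet, so this drives the private method purely for illustration:

```python
# Hypothetical async usage; `client.aio.prompt_optimizer` and the types
# import path are assumptions, not confirmed by this diff.
import asyncio

import vertexai
from vertexai import types


async def main() -> None:
  client = vertexai.Client(project="my-project", location="us-central1")
  content = types.Content(
      parts=[types.Part(text="Tell me about the solar system.")], role="user"
  )
  # One await to obtain the async iterator, then `async for` to stream it.
  stream = await client.aio.prompt_optimizer._optimize_prompt(content=content)
  async for chunk in stream:
    print(chunk)


asyncio.run(main())
```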