1919from __future__ import annotations
2020
2121import json
22- from typing import Any , List , Literal , Mapping , Tuple
22+ from typing import Any , List , Literal , Mapping , Tuple , Union
2323
24- from bigframes import clients , dtypes , series
25- from bigframes .core import log_adapter
24+ import pandas as pd
25+
26+ from bigframes import clients , dtypes , series , session
27+ from bigframes .core import convert , log_adapter
2628from bigframes .operations import ai_ops
2729
# One element of a composite prompt: a string literal, a BigFrames Series,
# or a pandas Series.
_PROMPT_ITEM = Union[str, series.Series, pd.Series]

# Type accepted for the ``prompt`` argument: either a lone Series
# (BigFrames or pandas), or a list/tuple of prompt elements.
PROMPT_TYPE = Union[
    series.Series,
    pd.Series,
    List[_PROMPT_ITEM],
    Tuple[_PROMPT_ITEM, ...],
]
36+
2837
2938@log_adapter .method_logger (custom_base_name = "bigquery_ai" )
3039def generate_bool (
31- prompt : series . Series | List [ str | series . Series ] | Tuple [ str | series . Series , ...] ,
40+ prompt : PROMPT_TYPE ,
3241 * ,
3342 connection_id : str | None = None ,
3443 endpoint : str | None = None ,
@@ -51,7 +60,7 @@ def generate_bool(
5160 0 {'result': True, 'full_response': '{"candidate...
5261 1 {'result': True, 'full_response': '{"candidate...
5362 2 {'result': False, 'full_response': '{"candidat...
54- dtype: struct<result: bool, full_response: string , status: string>[pyarrow]
63+ dtype: struct<result: bool, full_response: extension<dbjson<JSONArrowType>> , status: string>[pyarrow]
5564
5665 >>> bbq.ai.generate_bool((df["col_1"], " is a ", df["col_2"])).struct.field("result")
5766 0 True
@@ -60,8 +69,9 @@ def generate_bool(
6069 Name: result, dtype: boolean
6170
6271 Args:
63- prompt (series.Series | List[str|series.Series] | Tuple[str|series.Series, ...]):
64- A mixture of Series and string literals that specifies the prompt to send to the model.
72+ prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
73+ A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
74+ or pandas Series.
6575 connection_id (str, optional):
6676 Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
6777 If not provided, the connection from the current session will be used.
@@ -84,7 +94,7 @@ def generate_bool(
8494 Returns:
8595 bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
8696 * "result": a BOOL value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI.
87- * "full_response": a STRING value containing the JSON response from the projects.locations.endpoints.generateContent call to the model.
97+ * "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model.
8898 The generated text is in the text element.
8999 * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
90100 """
@@ -104,7 +114,7 @@ def generate_bool(
104114
105115
106116def _separate_context_and_series (
107- prompt : series . Series | List [ str | series . Series ] | Tuple [ str | series . Series , ...] ,
117+ prompt : PROMPT_TYPE ,
108118) -> Tuple [List [str | None ], List [series .Series ]]:
109119 """
110120 Returns the two values. The first value is the prompt with all series replaced by None. The second value is all the series
@@ -123,18 +133,19 @@ def _separate_context_and_series(
123133 return [None ], [prompt ]
124134
125135 prompt_context : List [str | None ] = []
126- series_list : List [series .Series ] = []
136+ series_list : List [series .Series | pd . Series ] = []
127137
138+ session = None
128139 for item in prompt :
129140 if isinstance (item , str ):
130141 prompt_context .append (item )
131142
132- elif isinstance (item , series .Series ):
143+ elif isinstance (item , ( series .Series , pd . Series ) ):
133144 prompt_context .append (None )
134145
135- if item . dtype == dtypes . OBJ_REF_DTYPE :
136- # Multi-model support
137- item = item .blob . read_url ()
146+ if isinstance ( item , series . Series ) and session is None :
147+ # Use the first available BF session if there's any.
148+ session = item ._session
138149 series_list .append (item )
139150
140151 else :
@@ -143,7 +154,20 @@ def _separate_context_and_series(
143154 if not series_list :
144155 raise ValueError ("Please provide at least one Series in the prompt" )
145156
146- return prompt_context , series_list
157+ converted_list = [_convert_series (s , session ) for s in series_list ]
158+
159+ return prompt_context , converted_list
160+
161+
def _convert_series(
    s: series.Series | pd.Series, session: session.Session | None
) -> series.Series:
    """Convert one prompt Series into a BigFrames Series.

    Args:
        s: A BigFrames or pandas Series.
        session: Session forwarded to ``convert.to_bf_series``; may be None.

    Returns:
        The Series produced by ``convert.to_bf_series``. If that Series has
        dtype ``dtypes.OBJ_REF_DTYPE``, the output of its ``blob.read_url()``
        is returned in its place.
    """
    bf_series = convert.to_bf_series(s, default_index=None, session=session)

    # Multimodal support: object-reference Series are swapped for the output
    # of ``blob.read_url()``; every other dtype is returned unchanged.
    return (
        bf_series.blob.read_url()
        if bf_series.dtype == dtypes.OBJ_REF_DTYPE
        else bf_series
    )
147171
148172
149173def _resolve_connection_id (series : series .Series , connection_id : str | None ):
0 commit comments