Skip to content

Commit 73a5aa1

Browse files
committed
feat: add bigquery.ml.generate_text function
1 parent 1f9ee37 commit 73a5aa1

File tree

12 files changed

+267
-17
lines changed

12 files changed

+267
-17
lines changed

GEMINI.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
## Testing
44

5-
We use `nox` to instrument our tests.
5+
We use `pytest` to instrument our tests.
66

7-
- To test your changes, run unit tests with `nox`:
7+
- To test your changes, run unit tests with `pytest`:
88

99
```bash
1010
pytest tests/unit

bigframes/bigquery/_operations/ml.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import cast, Mapping, Optional, Union
17+
from typing import Any, cast, List, Mapping, Optional, Union
1818

1919
import bigframes_vendored.constants
2020
import google.cloud.bigquery
@@ -431,3 +431,102 @@ def transform(
431431
return bpd.read_gbq_query(sql)
432432
else:
433433
return session.read_gbq_query(sql)
434+
435+
436+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
437+
def generate_text(
438+
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
439+
input_: Union[pd.DataFrame, dataframe.DataFrame, str],
440+
*,
441+
temperature: Optional[float] = None,
442+
max_output_tokens: Optional[int] = None,
443+
top_k: Optional[int] = None,
444+
top_p: Optional[float] = None,
445+
flatten_json_output: Optional[bool] = None,
446+
safety_settings: Optional[Mapping[str, str]] = None,
447+
stop_sequences: Optional[List[str]] = None,
448+
ground_with_google_search: Optional[bool] = None,
449+
model_params: Optional[Mapping[str, Any]] = None,
450+
request_type: Optional[str] = None,
451+
) -> dataframe.DataFrame:
452+
"""
453+
Generates text using a BigQuery ML model.
454+
455+
See the `BigQuery ML GENERATE_TEXT function syntax
456+
<https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-text>`_
457+
for additional reference.
458+
459+
Args:
460+
model (bigframes.ml.base.BaseEstimator, str, or pandas.Series):
461+
The model to use for text generation.
462+
input_ (Union[bigframes.pandas.DataFrame, str]):
463+
The DataFrame or query to use for text generation.
464+
temperature (float, optional):
465+
A FLOAT64 value that is used for sampling. The value
466+
must be in the range ``[0.0, 1.0]``. A lower temperature works well
467+
for prompts that expect a more deterministic and less open-ended
468+
or creative response, while a higher temperature can lead to more
469+
diverse or creative results. A temperature of ``0`` is
470+
deterministic, meaning that the highest probability response is
471+
always selected.
472+
max_output_tokens (int, optional):
473+
An INT64 value that sets the maximum number of tokens in the
474+
generated text.
475+
top_k (int, optional):
476+
An INT64 value that changes how the model selects tokens for
477+
output. A ``top_k`` of ``1`` means the next selected token is the
478+
most probable among all tokens in the model's vocabulary. A
479+
``top_k`` of ``3`` means that the next token is selected from
480+
among the three most probable tokens by using temperature. The
481+
default value is ``40``.
482+
top_p (float, optional):
483+
A FLOAT64 value that changes how the model selects tokens for
484+
output. Tokens are selected from most probable to least probable
485+
until the sum of their probabilities equals the ``top_p`` value.
486+
For example, if tokens A, B, and C have a probability of 0.3, 0.2,
487+
and 0.1 and the ``top_p`` value is ``0.5``, then the model will
488+
select either A or B as the next token by using temperature. The
489+
default value is ``0.95``.
490+
flatten_json_output (bool, optional):
491+
A BOOL value that determines the content of the generated JSON column.
492+
safety_settings (Mapping[str, str], optional):
493+
A STRUCT value that contains the safety settings for the model.
494+
The STRUCT must have a ``category`` field of type STRING and a
495+
``threshold`` field of type STRING.
496+
stop_sequences (List[str], optional):
497+
An ARRAY<STRING> value that contains the stop sequences for the model.
498+
ground_with_google_search (bool, optional):
499+
A BOOL value that determines whether to ground the model with Google Search.
500+
model_params (Mapping[str, Any], optional):
501+
A JSON value that contains the parameters for the model.
502+
request_type (str, optional):
503+
A STRING value that contains the request type for the model.
504+
505+
Returns:
506+
bigframes.pandas.DataFrame:
507+
The generated text.
508+
"""
509+
import bigframes.pandas as bpd
510+
511+
model_name, session = _get_model_name_and_session(model, input_)
512+
table_sql = _to_sql(input_)
513+
514+
sql = bigframes.core.sql.ml.generate_text(
515+
model_name=model_name,
516+
table=table_sql,
517+
temperature=temperature,
518+
max_output_tokens=max_output_tokens,
519+
top_k=top_k,
520+
top_p=top_p,
521+
flatten_json_output=flatten_json_output,
522+
safety_settings=safety_settings,
523+
stop_sequences=stop_sequences,
524+
ground_with_google_search=ground_with_google_search,
525+
model_params=model_params,
526+
request_type=request_type,
527+
)
528+
529+
if session is None:
530+
return bpd.read_gbq_query(sql)
531+
else:
532+
return session.read_gbq_query(sql)

bigframes/bigquery/ml.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
create_model,
2424
evaluate,
2525
explain_predict,
26+
generate_text,
2627
global_explain,
2728
predict,
2829
transform,
@@ -35,4 +36,5 @@
3536
"explain_predict",
3637
"global_explain",
3738
"transform",
39+
"generate_text",
3840
]

bigframes/core/sql/ml.py

Lines changed: 87 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import Dict, Mapping, Optional, Union
17+
import collections.abc
18+
import json
19+
from typing import Any, Dict, List, Mapping, Optional, Union
1820

1921
import bigframes.core.compile.googlesql as googlesql
2022
import bigframes.core.sql
@@ -100,14 +102,41 @@ def create_model_ddl(
100102

101103

102104
def _build_struct_sql(
103-
struct_options: Mapping[str, Union[str, int, float, bool]]
105+
struct_options: Mapping[
106+
str,
107+
Union[str, int, float, bool, Mapping[str, str], List[str], Mapping[str, Any]],
108+
]
104109
) -> str:
105110
if not struct_options:
106111
return ""
107112

108113
rendered_options = []
109114
for option_name, option_value in struct_options.items():
110-
rendered_val = bigframes.core.sql.simple_literal(option_value)
115+
if option_name == "model_params":
116+
json_str = json.dumps(option_value)
117+
# Escape backslashes and single quotes for the SQL string literal
118+
sql_json_str = json_str.replace("\\", "\\\\").replace("'", "\\'")
119+
rendered_val = f"JSON'{sql_json_str}'"
120+
elif isinstance(option_value, collections.abc.Mapping):
121+
struct_body = ", ".join(
122+
[
123+
f"{bigframes.core.sql.simple_literal(v)} AS {k}"
124+
for k, v in option_value.items()
125+
]
126+
)
127+
rendered_val = f"STRUCT({struct_body})"
128+
elif isinstance(option_value, list):
129+
rendered_val = (
130+
"["
131+
+ ", ".join(
132+
[bigframes.core.sql.simple_literal(v) for v in option_value]
133+
)
134+
+ "]"
135+
)
136+
elif isinstance(option_value, bool):
137+
rendered_val = str(option_value).lower()
138+
else:
139+
rendered_val = bigframes.core.sql.simple_literal(option_value)
111140
rendered_options.append(f"{rendered_val} AS {option_name}")
112141
return f", STRUCT({', '.join(rendered_options)})"
113142

@@ -151,7 +180,7 @@ def predict(
151180
"""Encode the ML.PREDICT statement.
152181
See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict for reference.
153182
"""
154-
struct_options = {}
183+
struct_options: Dict[str, Union[str, int, float, bool]] = {}
155184
if threshold is not None:
156185
struct_options["threshold"] = threshold
157186
if keep_original_columns is not None:
@@ -160,10 +189,10 @@ def predict(
160189
struct_options["trial_id"] = trial_id
161190

162191
sql = (
163-
f"SELECT * FROM ML.PREDICT(MODEL {googlesql.identifier(model_name)}, ({table})"
192+
f"SELECT * FROM ML.PREDICT(MODEL {googlesql.identifier(model_name)}, ({table}))"
164193
)
165194
sql += _build_struct_sql(struct_options)
166-
sql += ")\n"
195+
sql += "\n"
167196
return sql
168197

169198

@@ -205,13 +234,13 @@ def global_explain(
205234
"""Encode the ML.GLOBAL_EXPLAIN statement.
206235
See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-global-explain for reference.
207236
"""
208-
struct_options = {}
237+
struct_options: Dict[str, Union[str, int, float, bool]] = {}
209238
if class_level_explain is not None:
210239
struct_options["class_level_explain"] = class_level_explain
211240

212-
sql = f"SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL {googlesql.identifier(model_name)}"
241+
sql = f"SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL {googlesql.identifier(model_name)})"
213242
sql += _build_struct_sql(struct_options)
214-
sql += ")\n"
243+
sql += "\n"
215244
return sql
216245

217246

@@ -224,3 +253,52 @@ def transform(
224253
"""
225254
sql = f"SELECT * FROM ML.TRANSFORM(MODEL {googlesql.identifier(model_name)}, ({table}))\n"
226255
return sql
256+
257+
258+
def generate_text(
259+
model_name: str,
260+
table: str,
261+
*,
262+
temperature: Optional[float] = None,
263+
max_output_tokens: Optional[int] = None,
264+
top_k: Optional[int] = None,
265+
top_p: Optional[float] = None,
266+
flatten_json_output: Optional[bool] = None,
267+
safety_settings: Optional[Mapping[str, str]] = None,
268+
stop_sequences: Optional[List[str]] = None,
269+
ground_with_google_search: Optional[bool] = None,
270+
model_params: Optional[Mapping[str, Any]] = None,
271+
request_type: Optional[str] = None,
272+
) -> str:
273+
"""Encode the ML.GENERATE_TEXT statement.
274+
See https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-text for reference.
275+
"""
276+
struct_options: Dict[
277+
str,
278+
Union[str, int, float, bool, Mapping[str, str], List[str], Mapping[str, Any]],
279+
] = {}
280+
if temperature is not None:
281+
struct_options["temperature"] = temperature
282+
if max_output_tokens is not None:
283+
struct_options["max_output_tokens"] = max_output_tokens
284+
if top_k is not None:
285+
struct_options["top_k"] = top_k
286+
if top_p is not None:
287+
struct_options["top_p"] = top_p
288+
if flatten_json_output is not None:
289+
struct_options["flatten_json_output"] = flatten_json_output
290+
if safety_settings is not None:
291+
struct_options["safety_settings"] = safety_settings
292+
if stop_sequences is not None:
293+
struct_options["stop_sequences"] = stop_sequences
294+
if ground_with_google_search is not None:
295+
struct_options["ground_with_google_search"] = ground_with_google_search
296+
if model_params is not None:
297+
struct_options["model_params"] = model_params
298+
if request_type is not None:
299+
struct_options["request_type"] = request_type
300+
301+
sql = f"SELECT * FROM ML.GENERATE_TEXT(MODEL {googlesql.identifier(model_name)}, ({table})"
302+
sql += _build_struct_sql(struct_options)
303+
sql += ")\n"
304+
return sql

notebooks/ml/bq_dataframes_ml_cross_validation.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -991,7 +991,7 @@
991991
],
992992
"metadata": {
993993
"kernelspec": {
994-
"display_name": "venv",
994+
"display_name": "venv (3.10.14)",
995995
"language": "python",
996996
"name": "python3"
997997
},
@@ -1005,7 +1005,7 @@
10051005
"name": "python",
10061006
"nbconvert_exporter": "python",
10071007
"pygments_lexer": "ipython3",
1008-
"version": "3.10.15"
1008+
"version": "3.10.14"
10091009
}
10101010
},
10111011
"nbformat": 4,

tests/unit/bigquery/test_ml.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,46 @@ def test_transform_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mock):
163163
assert "ML.TRANSFORM" in generated_sql
164164
assert f"MODEL `{MODEL_NAME}`" in generated_sql
165165
assert "(SELECT * FROM `pandas_df`)" in generated_sql
166+
167+
168+
@mock.patch("bigframes.pandas.read_gbq_query")
169+
@mock.patch("bigframes.pandas.read_pandas")
170+
def test_generate_text_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mock):
171+
df = pd.DataFrame({"col1": [1, 2, 3]})
172+
read_pandas_mock.return_value._to_sql_query.return_value = (
173+
"SELECT * FROM `pandas_df`",
174+
[],
175+
[],
176+
)
177+
ml_ops.generate_text(
178+
MODEL_SERIES,
179+
input_=df,
180+
temperature=0.5,
181+
max_output_tokens=128,
182+
top_k=20,
183+
top_p=0.9,
184+
flatten_json_output=True,
185+
safety_settings={"hate_speech": "BLOCK_ONLY_HIGH"},
186+
stop_sequences=["a", "b"],
187+
ground_with_google_search=True,
188+
model_params={"param1": "value1"},
189+
request_type="TYPE",
190+
)
191+
read_pandas_mock.assert_called_once()
192+
read_gbq_query_mock.assert_called_once()
193+
generated_sql = read_gbq_query_mock.call_args[0][0]
194+
assert "ML.GENERATE_TEXT" in generated_sql
195+
assert f"MODEL `{MODEL_NAME}`" in generated_sql
196+
assert "(SELECT * FROM `pandas_df`)" in generated_sql
197+
assert "STRUCT(0.5 AS temperature" in generated_sql
198+
assert "128 AS max_output_tokens" in generated_sql
199+
assert "20 AS top_k" in generated_sql
200+
assert "0.9 AS top_p" in generated_sql
201+
assert "true AS flatten_json_output" in generated_sql
202+
assert (
203+
"STRUCT('BLOCK_ONLY_HIGH' AS hate_speech) AS safety_settings" in generated_sql
204+
)
205+
assert "['a', 'b'] AS stop_sequences" in generated_sql
206+
assert "true AS ground_with_google_search" in generated_sql
207+
assert "JSON'{\"param1\": \"value1\"}'" in generated_sql
208+
assert "'TYPE' AS request_type" in generated_sql
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
SELECT * FROM ML.EVALUATE(MODEL `my_model`, STRUCT(False AS perform_aggregation, 10 AS horizon, 0.95 AS confidence_level))
1+
SELECT * FROM ML.EVALUATE(MODEL `my_model`, STRUCT(false AS perform_aggregation, 10 AS horizon, 0.95 AS confidence_level))
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
SELECT * FROM ML.GENERATE_TEXT(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data))
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
SELECT * FROM ML.GENERATE_TEXT(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data), STRUCT(0.5 AS temperature, 128 AS max_output_tokens, 20 AS top_k, 0.9 AS top_p, true AS flatten_json_output, STRUCT('BLOCK_ONLY_HIGH' AS hate_speech) AS safety_settings, ['a', 'b'] AS stop_sequences, true AS ground_with_google_search, JSON'{"param1": "value1"}' AS model_params, 'TYPE' AS request_type))
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`, STRUCT(True AS class_level_explain))
1+
SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`, STRUCT(true AS class_level_explain))

0 commit comments

Comments
 (0)