1414
1515from __future__ import annotations
1616
17- from typing import Dict , Mapping , Optional , Union
17+ import collections .abc
18+ import json
19+ from typing import Any , Dict , List , Mapping , Optional , Union
1820
1921import bigframes .core .compile .googlesql as googlesql
2022import bigframes .core .sql
@@ -100,14 +102,41 @@ def create_model_ddl(
100102
101103
102104def _build_struct_sql (
103- struct_options : Mapping [str , Union [str , int , float , bool ]]
105+ struct_options : Mapping [
106+ str ,
107+ Union [str , int , float , bool , Mapping [str , str ], List [str ], Mapping [str , Any ]],
108+ ]
104109) -> str :
105110 if not struct_options :
106111 return ""
107112
108113 rendered_options = []
109114 for option_name , option_value in struct_options .items ():
110- rendered_val = bigframes .core .sql .simple_literal (option_value )
115+ if option_name == "model_params" :
116+ json_str = json .dumps (option_value )
117+ # Escape single quotes for SQL string literal
118+ sql_json_str = json_str .replace ("'" , "''" )
119+ rendered_val = f"JSON'{ sql_json_str } '"
120+ elif isinstance (option_value , collections .abc .Mapping ):
121+ struct_body = ", " .join (
122+ [
123+ f"{ bigframes .core .sql .simple_literal (v )} AS { k } "
124+ for k , v in option_value .items ()
125+ ]
126+ )
127+ rendered_val = f"STRUCT({ struct_body } )"
128+ elif isinstance (option_value , list ):
129+ rendered_val = (
130+ "["
131+ + ", " .join (
132+ [bigframes .core .sql .simple_literal (v ) for v in option_value ]
133+ )
134+ + "]"
135+ )
136+ elif isinstance (option_value , bool ):
137+ rendered_val = str (option_value ).lower ()
138+ else :
139+ rendered_val = bigframes .core .sql .simple_literal (option_value )
111140 rendered_options .append (f"{ rendered_val } AS { option_name } " )
112141 return f", STRUCT({ ', ' .join (rendered_options )} )"
113142
@@ -151,7 +180,7 @@ def predict(
151180 """Encode the ML.PREDICT statement.
152181 See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict for reference.
153182 """
154- struct_options = {}
183+ struct_options : Dict [ str , Union [ str , int , float , bool ]] = {}
155184 if threshold is not None :
156185 struct_options ["threshold" ] = threshold
157186 if keep_original_columns is not None :
@@ -160,10 +189,10 @@ def predict(
160189 struct_options ["trial_id" ] = trial_id
161190
162191 sql = (
163- f"SELECT * FROM ML.PREDICT(MODEL { googlesql .identifier (model_name )} , ({ table } )"
192+ f"SELECT * FROM ML.PREDICT(MODEL { googlesql .identifier (model_name )} , ({ table } )) "
164193 )
165194 sql += _build_struct_sql (struct_options )
166- sql += ") \n "
195+ sql += "\n "
167196 return sql
168197
169198
@@ -205,13 +234,13 @@ def global_explain(
205234 """Encode the ML.GLOBAL_EXPLAIN statement.
206235 See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-global-explain for reference.
207236 """
208- struct_options = {}
237+ struct_options : Dict [ str , Union [ str , int , float , bool ]] = {}
209238 if class_level_explain is not None :
210239 struct_options ["class_level_explain" ] = class_level_explain
211240
212- sql = f"SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL { googlesql .identifier (model_name )} "
241+ sql = f"SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL { googlesql .identifier (model_name )} ) "
213242 sql += _build_struct_sql (struct_options )
214- sql += ") \n "
243+ sql += "\n "
215244 return sql
216245
217246
@@ -224,3 +253,52 @@ def transform(
224253 """
225254 sql = f"SELECT * FROM ML.TRANSFORM(MODEL { googlesql .identifier (model_name )} , ({ table } ))\n "
226255 return sql
256+
257+
def generate_text(
    model_name: str,
    table: str,
    *,
    temperature: Optional[float] = None,
    max_output_tokens: Optional[int] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    flatten_json_output: Optional[bool] = None,
    safety_settings: Optional[Mapping[str, str]] = None,
    stop_sequences: Optional[List[str]] = None,
    ground_with_google_search: Optional[bool] = None,
    model_params: Optional[Mapping[str, Any]] = None,
    request_type: Optional[str] = None,
) -> str:
    """Encode the ML.GENERATE_TEXT statement.
    See https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-text for reference.

    Args:
        model_name: Model name; quoted via ``googlesql.identifier``.
        table: SQL text placed, parenthesized, as the input argument.
        temperature, max_output_tokens, top_k, top_p, flatten_json_output,
        safety_settings, stop_sequences, ground_with_google_search,
        model_params, request_type:
            Optional settings. Every one that is not ``None`` is forwarded,
            under its own name and in this order, to ``_build_struct_sql``.

    Returns:
        The SQL statement text, terminated by a newline.
    """
    # Insertion order matters: it determines the order of the STRUCT fields.
    candidate_options: Dict[str, Any] = {
        "temperature": temperature,
        "max_output_tokens": max_output_tokens,
        "top_k": top_k,
        "top_p": top_p,
        "flatten_json_output": flatten_json_output,
        "safety_settings": safety_settings,
        "stop_sequences": stop_sequences,
        "ground_with_google_search": ground_with_google_search,
        "model_params": model_params,
        "request_type": request_type,
    }
    # Only explicitly supplied settings are sent; falsy values such as 0 or
    # False are kept because the filter is on None, not truthiness.
    struct_options: Dict[
        str,
        Union[str, int, float, bool, Mapping[str, str], List[str], Mapping[str, Any]],
    ] = {
        name: value for name, value in candidate_options.items() if value is not None
    }

    # The STRUCT argument (if any) must sit inside the ML.GENERATE_TEXT(...)
    # call, so the closing parenthesis is appended after it.
    sql = f"SELECT * FROM ML.GENERATE_TEXT(MODEL {googlesql.identifier(model_name)}, ({table})"
    sql += _build_struct_sql(struct_options)
    sql += ")\n"
    return sql
0 commit comments