Skip to content

Commit e804e13

Browse files
authored
refactor: push down SQL generate logic in core.BqmlModelFactory (#62)
1 parent a6e32aa commit e804e13

File tree

3 files changed

+64
-46
lines changed

3 files changed

+64
-46
lines changed

bigframes/ml/core.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -251,17 +251,10 @@ def create_model(
251251

252252
session = X_train._session
253253

254-
source_sql = input_data.sql
255-
options_sql = self._model_creation_sql_generator.options(**options)
256-
transform_sql = (
257-
self._model_creation_sql_generator.transform(*transforms)
258-
if transforms is not None
259-
else None
260-
)
261254
sql = self._model_creation_sql_generator.create_model(
262-
source_sql=source_sql,
263-
transform_sql=transform_sql,
264-
options_sql=options_sql,
255+
source=input_data,
256+
transforms=transforms,
257+
options=options,
265258
)
266259

267260
return self._create_model_with_sql(session=session, sql=sql)
@@ -287,18 +280,10 @@ def create_time_series_model(
287280

288281
session = X_train._session
289282

290-
source_sql = input_data.sql
291-
options_sql = self._model_creation_sql_generator.options(**options)
292-
293-
transform_sql = (
294-
self._model_creation_sql_generator.transform(*transforms)
295-
if transforms is not None
296-
else None
297-
)
298283
sql = self._model_creation_sql_generator.create_model(
299-
source_sql=source_sql,
300-
transform_sql=transform_sql,
301-
options_sql=options_sql,
284+
source=input_data,
285+
transforms=transforms,
286+
options=options,
302287
)
303288

304289
return self._create_model_with_sql(session=session, sql=sql)
@@ -320,10 +305,9 @@ def create_remote_model(
320305
Returns:
321306
BqmlModel: a BqmlModel wrapping a trained model in BigQuery
322307
"""
323-
options_sql = self._model_creation_sql_generator.options(**options)
324308
sql = self._model_creation_sql_generator.create_remote_model(
325309
connection_name=connection_name,
326-
options_sql=options_sql,
310+
options=options,
327311
)
328312

329313
return self._create_model_with_sql(session=session, sql=sql)
@@ -341,9 +325,8 @@ def create_imported_model(
341325
342326
Returns: a BqmlModel, wrapping a trained model in BigQuery
343327
"""
344-
options_sql = self._model_creation_sql_generator.options(**options)
345328
sql = self._model_creation_sql_generator.create_imported_model(
346-
options_sql=options_sql,
329+
options=options,
347330
)
348331

349332
return self._create_model_with_sql(session=session, sql=sql)

bigframes/ml/sql.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
Generates SQL queries needed for BigQuery DataFrames ML
1717
"""
1818

19-
from typing import Iterable, Optional, Union
19+
from typing import Iterable, Mapping, Optional, Union
2020

2121
import bigframes.constants as constants
22+
import bigframes.pandas as bpd
2223

2324

2425
class BaseSqlGenerator:
@@ -113,11 +114,15 @@ def __init__(self, model_id: str):
113114
# Model create and alter
114115
def create_model(
115116
self,
116-
source_sql: str,
117-
transform_sql: Optional[str] = None,
118-
options_sql: Optional[str] = None,
117+
source: bpd.DataFrame,
118+
options: Mapping[str, Union[str, int, float, Iterable[str]]] = {},
119+
transforms: Optional[Iterable[str]] = None,
119120
) -> str:
120121
"""Encode the CREATE TEMP MODEL statement for BQML"""
122+
source_sql = source.sql
123+
transform_sql = self.transform(*transforms) if transforms is not None else None
124+
options_sql = self.options(**options)
125+
121126
parts = [f"CREATE TEMP MODEL `{self._model_id}`"]
122127
if transform_sql:
123128
parts.append(transform_sql)
@@ -129,9 +134,11 @@ def create_model(
129134
def create_remote_model(
130135
self,
131136
connection_name: str,
132-
options_sql: Optional[str] = None,
137+
options: Mapping[str, Union[str, int, float, Iterable[str]]] = {},
133138
) -> str:
134139
"""Encode the CREATE TEMP MODEL statement for BQML remote model."""
140+
options_sql = self.options(**options)
141+
135142
parts = [f"CREATE TEMP MODEL `{self._model_id}`"]
136143
parts.append(self.connection(connection_name))
137144
if options_sql:
@@ -140,17 +147,19 @@ def create_remote_model(
140147

141148
def create_imported_model(
142149
self,
143-
options_sql: Optional[str] = None,
150+
options: Mapping[str, Union[str, int, float, Iterable[str]]] = {},
144151
) -> str:
145152
"""Encode the CREATE TEMP MODEL statement for BQML imported model."""
153+
options_sql = self.options(**options)
154+
146155
parts = [f"CREATE TEMP MODEL `{self._model_id}`"]
147156
if options_sql:
148157
parts.append(options_sql)
149158
return "\n".join(parts)
150159

151160

152161
class ModelManipulationSqlGenerator(BaseSqlGenerator):
153-
"""Sql generator for manipulating a model entity. Model name is the fully model path of project_id.dataset_id.model_id."""
162+
"""Sql generator for manipulating a model entity. Model name is the full model path of project_id.dataset_id.model_id."""
154163

155164
def __init__(self, model_name: str):
156165
self._model_name = model_name

tests/unit/ml/test_sql.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from unittest import mock
16+
1517
import pytest
1618

1719
import bigframes.ml.sql as ml_sql
20+
import bigframes.pandas as bpd
1821

1922

2023
@pytest.fixture(scope="session")
@@ -34,6 +37,14 @@ def model_manipulation_sql_generator() -> ml_sql.ModelManipulationSqlGenerator:
3437
)
3538

3639

40+
@pytest.fixture(scope="session")
41+
def mock_df():
42+
mock_df = mock.create_autospec(spec=bpd.DataFrame)
43+
mock_df.sql = "input_X_y_sql"
44+
45+
return mock_df
46+
47+
3748
def test_options_produces_correct_sql(base_sql_generator: ml_sql.BaseSqlGenerator):
3849
sql = base_sql_generator.options(
3950
model_type="lin_reg", input_label_cols=["col_a"], l1_reg=0.6
@@ -96,33 +107,44 @@ def test_label_encoder_produces_correct_sql(
96107

97108
def test_create_model_produces_correct_sql(
98109
model_creation_sql_generator: ml_sql.ModelCreationSqlGenerator,
110+
mock_df: bpd.DataFrame,
99111
):
100112
sql = model_creation_sql_generator.create_model(
101-
source_sql="my_source_sql",
102-
options_sql="my_options_sql",
113+
source=mock_df,
114+
options={"option_key1": "option_value1", "option_key2": 2},
103115
)
104116
assert (
105117
sql
106118
== """CREATE TEMP MODEL `my_model_id`
107-
my_options_sql
108-
AS my_source_sql"""
119+
OPTIONS(
120+
option_key1="option_value1",
121+
option_key2=2)
122+
AS input_X_y_sql"""
109123
)
110124

111125

112126
def test_create_model_transform_produces_correct_sql(
113127
model_creation_sql_generator: ml_sql.ModelCreationSqlGenerator,
128+
mock_df: bpd.DataFrame,
114129
):
115130
sql = model_creation_sql_generator.create_model(
116-
source_sql="my_source_sql",
117-
options_sql="my_options_sql",
118-
transform_sql="my_transform_sql",
131+
source=mock_df,
132+
options={"option_key1": "option_value1", "option_key2": 2},
133+
transforms=[
134+
"ML.STANDARD_SCALER(col_a) OVER(col_a) AS scaled_col_a",
135+
"ML.ONE_HOT_ENCODER(col_b) OVER(col_b) AS encoded_col_b",
136+
],
119137
)
120138
assert (
121139
sql
122140
== """CREATE TEMP MODEL `my_model_id`
123-
my_transform_sql
124-
my_options_sql
125-
AS my_source_sql"""
141+
TRANSFORM(
142+
ML.STANDARD_SCALER(col_a) OVER(col_a) AS scaled_col_a,
143+
ML.ONE_HOT_ENCODER(col_b) OVER(col_b) AS encoded_col_b)
144+
OPTIONS(
145+
option_key1="option_value1",
146+
option_key2=2)
147+
AS input_X_y_sql"""
126148
)
127149

128150

@@ -131,26 +153,30 @@ def test_create_remote_model_produces_correct_sql(
131153
):
132154
sql = model_creation_sql_generator.create_remote_model(
133155
connection_name="my_project.us.my_connection",
134-
options_sql="my_options_sql",
156+
options={"option_key1": "option_value1", "option_key2": 2},
135157
)
136158
assert (
137159
sql
138160
== """CREATE TEMP MODEL `my_model_id`
139161
REMOTE WITH CONNECTION `my_project.us.my_connection`
140-
my_options_sql"""
162+
OPTIONS(
163+
option_key1="option_value1",
164+
option_key2=2)"""
141165
)
142166

143167

144168
def test_create_imported_model_produces_correct_sql(
145169
model_creation_sql_generator: ml_sql.ModelCreationSqlGenerator,
146170
):
147171
sql = model_creation_sql_generator.create_imported_model(
148-
options_sql="my_options_sql",
172+
options={"option_key1": "option_value1", "option_key2": 2},
149173
)
150174
assert (
151175
sql
152176
== """CREATE TEMP MODEL `my_model_id`
153-
my_options_sql"""
177+
OPTIONS(
178+
option_key1="option_value1",
179+
option_key2=2)"""
154180
)
155181

156182

0 commit comments

Comments
 (0)