Skip to content

Commit 416171a

Browse files
authored
feat!: model.predict returns all the columns (#204)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:

- [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea.
- [ ] Ensure the tests and linter pass.
- [ ] Code coverage does not decrease (if any source code was changed).
- [ ] Appropriate docs were updated (if necessary).

Fixes #<issue_number_goes_here> 🦕
1 parent e8532b1 commit 416171a

21 files changed

+2737
-2275
lines changed

bigframes/ml/cluster.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from __future__ import annotations
1919

20-
from typing import cast, Dict, List, Optional, Union
20+
from typing import Dict, List, Optional, Union
2121

2222
from google.cloud import bigquery
2323

@@ -92,7 +92,7 @@ def predict(
9292

9393
(X,) = utils.convert_to_dataframe(X)
9494

95-
return cast(bpd.DataFrame, self._bqml_model.predict(X)[["CENTROID_ID"]])
95+
return self._bqml_model.predict(X)
9696

9797
def to_gbq(self, model_name: str, replace: bool = False) -> KMeans:
9898
"""Save the model to BigQuery.

bigframes/ml/decomposition.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from __future__ import annotations
1919

20-
from typing import cast, List, Optional, Union
20+
from typing import List, Optional, Union
2121

2222
from google.cloud import bigquery
2323

@@ -106,12 +106,7 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
106106

107107
(X,) = utils.convert_to_dataframe(X)
108108

109-
return cast(
110-
bpd.DataFrame,
111-
self._bqml_model.predict(X)[
112-
["principal_component_" + str(i + 1) for i in range(self.n_components)]
113-
],
114-
)
109+
return self._bqml_model.predict(X)
115110

116111
def to_gbq(self, model_name: str, replace: bool = False) -> PCA:
117112
"""Save the model to BigQuery.

bigframes/ml/ensemble.py

Lines changed: 5 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from __future__ import annotations
1919

20-
from typing import cast, Dict, List, Literal, Optional, Union
20+
from typing import Dict, List, Literal, Optional, Union
2121

2222
from google.cloud import bigquery
2323

@@ -168,16 +168,7 @@ def predict(
168168
raise RuntimeError("A model must be fitted before predict")
169169
(X,) = utils.convert_to_dataframe(X)
170170

171-
df = self._bqml_model.predict(X)
172-
return cast(
173-
bpd.DataFrame,
174-
df[
175-
[
176-
cast(str, field.name)
177-
for field in self._bqml_model.model.label_columns
178-
]
179-
],
180-
)
171+
return self._bqml_model.predict(X)
181172

182173
def score(
183174
self,
@@ -328,19 +319,9 @@ def _fit(
328319
def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
329320
if not self._bqml_model:
330321
raise RuntimeError("A model must be fitted before predict")
331-
332322
(X,) = utils.convert_to_dataframe(X)
333323

334-
df = self._bqml_model.predict(X)
335-
return cast(
336-
bpd.DataFrame,
337-
df[
338-
[
339-
cast(str, field.name)
340-
for field in self._bqml_model.model.label_columns
341-
]
342-
],
343-
)
324+
return self._bqml_model.predict(X)
344325

345326
def score(
346327
self,
@@ -486,19 +467,9 @@ def predict(
486467
) -> bpd.DataFrame:
487468
if not self._bqml_model:
488469
raise RuntimeError("A model must be fitted before predict")
489-
490470
(X,) = utils.convert_to_dataframe(X)
491471

492-
df = self._bqml_model.predict(X)
493-
return cast(
494-
bpd.DataFrame,
495-
df[
496-
[
497-
cast(str, field.name)
498-
for field in self._bqml_model.model.label_columns
499-
]
500-
],
501-
)
472+
return self._bqml_model.predict(X)
502473

503474
def score(
504475
self,
@@ -661,19 +632,9 @@ def predict(
661632
) -> bpd.DataFrame:
662633
if not self._bqml_model:
663634
raise RuntimeError("A model must be fitted before predict")
664-
665635
(X,) = utils.convert_to_dataframe(X)
666636

667-
df = self._bqml_model.predict(X)
668-
return cast(
669-
bpd.DataFrame,
670-
df[
671-
[
672-
cast(str, field.name)
673-
for field in self._bqml_model.model.label_columns
674-
]
675-
],
676-
)
637+
return self._bqml_model.predict(X)
677638

678639
def score(
679640
self,

bigframes/ml/forecasting.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,14 @@
1616

1717
from __future__ import annotations
1818

19-
from typing import cast, Dict, List, Optional, Union
19+
from typing import Dict, List, Optional, Union
2020

2121
from google.cloud import bigquery
2222

2323
import bigframes
2424
from bigframes.ml import base, core, globals, utils
2525
import bigframes.pandas as bpd
2626

27-
_PREDICT_OUTPUT_COLUMNS = ["forecast_timestamp", "forecast_value"]
28-
2927

3028
class ARIMAPlus(base.SupervisedTrainablePredictor):
3129
"""Time Series ARIMA Plus model."""
@@ -100,10 +98,7 @@ def predict(self, X=None) -> bpd.DataFrame:
10098
if not self._bqml_model:
10199
raise RuntimeError("A model must be fitted before predict")
102100

103-
return cast(
104-
bpd.DataFrame,
105-
self._bqml_model.forecast()[_PREDICT_OUTPUT_COLUMNS],
106-
)
101+
return self._bqml_model.forecast()
107102

108103
def score(
109104
self,

bigframes/ml/imported.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,7 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
7878

7979
(X,) = utils.convert_to_dataframe(X)
8080

81-
df = self._bqml_model.predict(X)
82-
return cast(
83-
bpd.DataFrame,
84-
df[
85-
[
86-
cast(str, field.name)
87-
for field in self._bqml_model.model.label_columns
88-
]
89-
],
90-
)
81+
return self._bqml_model.predict(X)
9182

9283
def to_gbq(self, model_name: str, replace: bool = False) -> TensorFlowModel:
9384
"""Save the model to BigQuery.
@@ -161,16 +152,7 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
161152

162153
(X,) = utils.convert_to_dataframe(X)
163154

164-
df = self._bqml_model.predict(X)
165-
return cast(
166-
bpd.DataFrame,
167-
df[
168-
[
169-
cast(str, field.name)
170-
for field in self._bqml_model.model.label_columns
171-
]
172-
],
173-
)
155+
return self._bqml_model.predict(X)
174156

175157
def to_gbq(self, model_name: str, replace: bool = False) -> ONNXModel:
176158
"""Save the model to BigQuery.

bigframes/ml/linear_model.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from __future__ import annotations
1919

20-
from typing import cast, Dict, List, Literal, Optional, Union
20+
from typing import Dict, List, Literal, Optional, Union
2121

2222
from google.cloud import bigquery
2323

@@ -145,16 +145,7 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
145145

146146
(X,) = utils.convert_to_dataframe(X)
147147

148-
df = self._bqml_model.predict(X)
149-
return cast(
150-
bpd.DataFrame,
151-
df[
152-
[
153-
cast(str, field.name)
154-
for field in self._bqml_model.model.label_columns
155-
]
156-
],
157-
)
148+
return self._bqml_model.predict(X)
158149

159150
def score(
160151
self,
@@ -267,16 +258,7 @@ def predict(
267258

268259
(X,) = utils.convert_to_dataframe(X)
269260

270-
df = self._bqml_model.predict(X)
271-
return cast(
272-
bpd.DataFrame,
273-
df[
274-
[
275-
cast(str, field.name)
276-
for field in self._bqml_model.model.label_columns
277-
]
278-
],
279-
)
261+
return self._bqml_model.predict(X)
280262

281263
def score(
282264
self,

bigframes/ml/llm.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ def predict(
149149
150150
151151
Returns:
152-
bigframes.dataframe.DataFrame: Output DataFrame with only 1 column as the output text results."""
152+
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
153+
"""
153154

154155
# Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
155156
if temperature < 0.0 or temperature > 1.0:
@@ -181,11 +182,7 @@ def predict(
181182
"top_p": top_p,
182183
"flatten_json_output": True,
183184
}
184-
df = self._bqml_model.generate_text(X, options)
185-
return cast(
186-
bpd.DataFrame,
187-
df[[_TEXT_GENERATE_RESULT_COLUMN]],
188-
)
185+
return self._bqml_model.generate_text(X, options)
189186

190187

191188
class PaLM2TextEmbeddingGenerator(base.Predictor):
@@ -269,7 +266,7 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
269266
Input DataFrame, which needs to contain a column with name "content". Only the column will be used as input. Content can include preamble, questions, suggestions, instructions, or examples.
270267
271268
Returns:
272-
bigframes.dataframe.DataFrame: Output DataFrame with only 1 column as the output embedding results
269+
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
273270
"""
274271

275272
# Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
@@ -287,8 +284,4 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
287284
options = {
288285
"flatten_json_output": True,
289286
}
290-
df = self._bqml_model.generate_text_embedding(X, options)
291-
return cast(
292-
bpd.DataFrame,
293-
df[[_EMBED_TEXT_RESULT_COLUMN]],
294-
)
287+
return self._bqml_model.generate_text_embedding(X, options)

0 commit comments

Comments
 (0)