feat: add support for creating a Matrix Factorization model #1330

Merged
merged 153 commits into from
Apr 1, 2025
153 commits
6783a0a
docs: update title of pypi notebook example to reflect use of the PyP…
tswast Sep 4, 2024
1d39560
feat: add support for creating a Matrix Factorization model
rey-esp Jan 27, 2025
e19c262
feat: add support for creating a Matrix Factorization model
rey-esp Jan 27, 2025
1bef4a2
feat: add support for creating a Matrix Factorization model
rey-esp Jan 27, 2025
d157cd7
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Jan 28, 2025
e336bde
Update bigframes/ml/decomposition.py
rey-esp Jan 28, 2025
d5f713a
Update bigframes/ml/decomposition.py
rey-esp Jan 28, 2025
5e3e443
Update bigframes/ml/decomposition.py
rey-esp Jan 28, 2025
34a60bc
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Jan 28, 2025
c116e8a
rating_col
rey-esp Jan 28, 2025
dedef39
(nearly) complete class
rey-esp Jan 28, 2025
e5165a9
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Jan 28, 2025
05eb854
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Jan 28, 2025
2787178
removem print()
rey-esp Jan 28, 2025
8c66e07
removem print()
rey-esp Jan 28, 2025
086b4dd
adding recommend
rey-esp Jan 29, 2025
8ed3ccd
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Jan 29, 2025
1b4eef9
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Jan 29, 2025
7c371ac
remove hyper parameter runing references
rey-esp Jan 30, 2025
7498c8c
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Jan 30, 2025
55ef06a
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Jan 30, 2025
29805b5
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Feb 4, 2025
8de384a
swap predict in _mf for recommend
rey-esp Feb 4, 2025
647532b
recommend -> predict
rey-esp Feb 4, 2025
b340c4f
update predict doc string
rey-esp Feb 4, 2025
580de41
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Feb 4, 2025
29ee357
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Feb 5, 2025
bac2ece
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Feb 6, 2025
3f22c23
Merge branch 'b338873783-matrix-factorization' of github.com:googleap…
rey-esp Feb 6, 2025
213f11d
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Feb 6, 2025
aaf0d1f
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Feb 10, 2025
4c90c1d
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Feb 10, 2025
792bd64
Merge branch 'b338873783-matrix-factorization' of github.com:googleap…
rey-esp Feb 10, 2025
ed279be
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Feb 11, 2025
ba5beb3
preparing test files
rey-esp Feb 12, 2025
86fb956
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Feb 12, 2025
a29bbcf
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Feb 13, 2025
8577833
add test data
rey-esp Feb 13, 2025
a92007c
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Feb 19, 2025
a808429
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Feb 20, 2025
4b7b4db
new error: to_gbq column names need to be changed?
rey-esp Feb 21, 2025
8d55eac
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Feb 24, 2025
9195658
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Feb 24, 2025
faa4d6b
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Feb 24, 2025
76a9934
Merge branch 'b338873783-matrix-factorization' of github.com:googleap…
rey-esp Feb 24, 2025
bef7808
Delete demo.ipynb
rey-esp Feb 24, 2025
f18104d
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Feb 25, 2025
9b39a99
Merge branch 'b338873783-matrix-factorization' of github.com:googleap…
rey-esp Feb 25, 2025
0dd033d
passing system test
rey-esp Feb 25, 2025
60faed1
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Feb 26, 2025
1f85b75
preparing to add unit tests
rey-esp Feb 26, 2025
7efc63d
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Feb 27, 2025
a457639
2 out of 3 (so far) passing unit tests
rey-esp Feb 27, 2025
89790ac
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Feb 28, 2025
a057a8f
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 3, 2025
512332e
attempted mocking
rey-esp Mar 3, 2025
741e749
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 4, 2025
f902131
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 4, 2025
310257d
Merge branch 'b338873783-matrix-factorization' of github.com:googleap…
rey-esp Mar 4, 2025
408e807
fix tests
rey-esp Mar 4, 2025
19e423b
new test file for model creation unit tests
rey-esp Mar 4, 2025
2c107df
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 5, 2025
c7c8eea
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 5, 2025
5f1a19a
add unit tests for num_factors, user_col, and item_col
rey-esp Mar 6, 2025
68e308b
Merge branch 'b338873783-matrix-factorization' of github.com:googleap…
rey-esp Mar 6, 2025
33f3069
Update tests/unit/ml/test_matrix_factorization.py
rey-esp Mar 7, 2025
1ff6aaa
Update tests/unit/ml/test_matrix_factorization.py
rey-esp Mar 7, 2025
446712b
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 7, 2025
c84dd7e
uncomment one test
rey-esp Mar 7, 2025
3473037
uncomment test
rey-esp Mar 7, 2025
b3809e5
uncomment test
rey-esp Mar 7, 2025
7e8a5b6
uncomment test
rey-esp Mar 7, 2025
eba88d9
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 10, 2025
8599d88
nearly all tests
rey-esp Mar 10, 2025
8ab8818
tests complete and passing
rey-esp Mar 10, 2025
b4d3578
seeing if test causes kokoro failure
rey-esp Mar 10, 2025
a63cb90
uncomment test-kokoro still failing
rey-esp Mar 10, 2025
3695f80
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 10, 2025
336bffd
Merge branch 'tswast-patch-1' into b338873783-matrix-factorization
rey-esp Mar 10, 2025
bb6130a
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 11, 2025
e69438d
remove comment
rey-esp Mar 11, 2025
05da834
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 11, 2025
087953f
fix test
rey-esp Mar 11, 2025
bfe9140
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 11, 2025
8d3599e
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 11, 2025
248a3b1
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 12, 2025
157daea
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 12, 2025
8912663
test kokoro
rey-esp Mar 12, 2025
35a8c18
test_decomposition.py failing and now feedback_type attr does not exist
rey-esp Mar 12, 2025
ac182be
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 12, 2025
ff58ff5
passing tests
rey-esp Mar 12, 2025
f0a6ba2
Update bigframes/ml/decomposition.py
rey-esp Mar 12, 2025
aaad5f5
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 12, 2025
b586c5c
Update tests/system/large/ml/test_decomposition.py
rey-esp Mar 12, 2025
04ddd5e
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 12, 2025
8e875ae
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 12, 2025
565138a
doc attempt - _mf.py example
rey-esp Mar 12, 2025
b39661f
Merge branch 'b338873783-matrix-factorization' of github.com:googleap…
rey-esp Mar 12, 2025
c0ef08f
feedback_type case ignore
rey-esp Mar 13, 2025
4b53b04
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 17, 2025
342cbd1
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 18, 2025
8812f33
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 18, 2025
24b8e0c
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 18, 2025
664de04
Update _mf.py - remove global_explain()
rey-esp Mar 18, 2025
63e8e9c
fit
rey-esp Mar 18, 2025
3e52cd4
pull?
rey-esp Mar 18, 2025
c2e9a5f
W
rey-esp Mar 18, 2025
28c4602
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 19, 2025
1240eeb
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 24, 2025
46f1ea6
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 24, 2025
193b9c8
fix docs (maybe)
rey-esp Mar 24, 2025
5a547f8
Update test_matrix_factorization.py with updated error messages
rey-esp Mar 24, 2025
23d8fc8
ilnt
rey-esp Mar 24, 2025
ed99ad7
Update test_matrix_factorization.py - add 'f'
rey-esp Mar 24, 2025
e305950
improve errors and update tests
rey-esp Mar 24, 2025
411fe1a
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 25, 2025
4273a99
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 25, 2025
b9f6a52
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 25, 2025
b92ed1f
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 25, 2025
aaf34eb
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 25, 2025
46601c4
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 25, 2025
0823db2
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 27, 2025
32917e5
Update tests/system/large/ml/test_decomposition.py
rey-esp Mar 27, 2025
e485d3b
Update bigframes/ml/decomposition.py - num_factors error messsage
rey-esp Mar 27, 2025
6a27083
Update bigframes/ml/decomposition.py - user_col error message
rey-esp Mar 27, 2025
6e2d902
Update bigframes/ml/decomposition.py - rating_col error message
rey-esp Mar 27, 2025
b65c637
Update bigframes/ml/decomposition.py - l2_reg error msg
rey-esp Mar 27, 2025
93ac0fa
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 27, 2025
74ebe27
fix tests to match updated error messages
rey-esp Mar 27, 2025
b2ebcf7
Merge branch 'b338873783-matrix-factorization' of github.com:googleap…
rey-esp Mar 27, 2025
3f40763
Update third_party/bigframes_vendored/sklearn/decomposition/_mf.py - …
rey-esp Mar 27, 2025
2cbc2e3
Update third_party/bigframes_vendored/sklearn/decomposition/_mf.py - …
rey-esp Mar 27, 2025
0a5aefb
Update third_party/bigframes_vendored/sklearn/decomposition/_mf.py - …
rey-esp Mar 27, 2025
d484f77
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 27, 2025
366e0ab
Update third_party/bigframes_vendored/sklearn/decomposition/_mf.py
tswast Mar 28, 2025
1eaa708
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 31, 2025
56ee623
remove errors and tests
rey-esp Mar 31, 2025
c942418
Update bigframes/ml/decomposition.py
rey-esp Mar 31, 2025
e0ef53e
Update bigframes/ml/decomposition.py
rey-esp Mar 31, 2025
5018182
Update bigframes/ml/decomposition.py
rey-esp Mar 31, 2025
c088a76
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 31, 2025
f9397f1
passing system test
rey-esp Mar 31, 2025
b439120
E AssertionError: expected call not found.
rey-esp Mar 31, 2025
ffe0f33
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 31, 2025
b2698ef
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Mar 31, 2025
69c8fba
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Apr 1, 2025
8a614c5
same # of elements in each
rey-esp Apr 1, 2025
9d71c86
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Apr 1, 2025
c2b4795
attempt
rey-esp Apr 1, 2025
cd20ffc
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Apr 1, 2025
cf6e5be
doc fix
rey-esp Apr 1, 2025
da230b4
doc fix
rey-esp Apr 1, 2025
8927072
Merge branch 'main' into b338873783-matrix-factorization
rey-esp Apr 1, 2025
6 changes: 6 additions & 0 deletions bigframes/ml/core.py
@@ -117,6 +117,12 @@ def model(self) -> bigquery.Model:
        """Get the BQML model associated with this wrapper"""
        return self._model

    def recommend(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
        return self._apply_ml_tvf(
            input_data,
            self._model_manipulation_sql_generator.ml_recommend,
        )

    def predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
        return self._apply_ml_tvf(
            input_data,
167 changes: 166 additions & 1 deletion bigframes/ml/decomposition.py
@@ -19,6 +19,7 @@

from typing import List, Literal, Optional, Union

import bigframes_vendored.sklearn.decomposition._mf
import bigframes_vendored.sklearn.decomposition._pca
from google.cloud import bigquery

@@ -27,7 +28,15 @@
import bigframes.pandas as bpd
import bigframes.session

-_BQML_PARAMS_MAPPING = {"svd_solver": "pcaSolver"}
+_BQML_PARAMS_MAPPING = {
+    "svd_solver": "pcaSolver",
+    "feedback_type": "feedbackType",
+    "num_factors": "numFactors",
+    "user_col": "userColumn",
+    "item_col": "itemColumn",
+    "_input_label_columns": "inputLabelColumns",
+    "l2_reg": "l2Regularization",
+}


@log_adapter.class_logger
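
The new `_BQML_PARAMS_MAPPING` entries pair each constructor argument with a camelCase key, and `_from_bq` passes the mapping to `utils.retrieve_params_from_bq_model`. That helper's body is not part of this diff, so the following is only a minimal sketch of how such a mapping could be applied. It assumes the saved training options arrive as a flat dict keyed by the camelCase names. `kwargs_from_training_options`, `ARG_NAMES`, and the sample option values are all hypothetical.

```python
from typing import Any, Dict, List

# Copied from the diff: constructor argument -> BQML option key.
BQML_PARAMS_MAPPING = {
    "svd_solver": "pcaSolver",
    "feedback_type": "feedbackType",
    "num_factors": "numFactors",
    "user_col": "userColumn",
    "item_col": "itemColumn",
    "_input_label_columns": "inputLabelColumns",
    "l2_reg": "l2Regularization",
}

# Hypothetical list of constructor arguments to look up.
ARG_NAMES = ["feedback_type", "num_factors", "user_col", "item_col", "l2_reg"]


def kwargs_from_training_options(
    arg_names: List[str], options: Dict[str, Any], mapping: Dict[str, str]
) -> Dict[str, Any]:
    """Hypothetical helper: collect the constructor arguments present in `options`."""
    kwargs: Dict[str, Any] = {}
    for arg in arg_names:
        bqml_key = mapping.get(arg)
        if bqml_key is not None and bqml_key in options:
            kwargs[arg] = options[bqml_key]
    return kwargs


# Made-up training options, for illustration only.
sample_options = {
    "feedbackType": "EXPLICIT",
    "numFactors": 6,
    "l2Regularization": 9.83,
    "unrelatedOption": True,
}
kwargs = kwargs_from_training_options(ARG_NAMES, sample_options, BQML_PARAMS_MAPPING)
print(kwargs)
```

Because the constructor in this PR lower-cases `feedback_type` before validating it, an upper-case value such as the made-up `"EXPLICIT"` would still be accepted when the resulting keyword arguments are passed to `cls(**kwargs)`.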
@@ -197,3 +206,159 @@ def score(

        # TODO(b/291973741): X param is ignored. Update BQML supports input in ML.EVALUATE.
        return self._bqml_model.evaluate()


@log_adapter.class_logger
class MatrixFactorization(
    base.UnsupervisedTrainablePredictor,
    bigframes_vendored.sklearn.decomposition._mf.MatrixFactorization,
):
    __doc__ = bigframes_vendored.sklearn.decomposition._mf.MatrixFactorization.__doc__

    def __init__(
        self,
        *,
        feedback_type: Literal["explicit", "implicit"] = "explicit",
        num_factors: int,
        user_col: str,
        item_col: str,
        rating_col: str = "rating",
Comment on lines +223 to +225
Collaborator

@GarrettWu @shuoweil I see in #1282 you ended up passing in "id_col" as a separate argument to fit() instead of the class constructor. Is this a pattern you would recommend here?

Note: MatrixFactorization differs somewhat from that application in that normally in scikit-learn one would have a "sparse matrix" data type (e.g. https://docs.scipy.org/doc/scipy/reference/sparse.html) where rows/cols/values would all be bundled up in one object, similar to how we are using the bigframes DataFrame for this purpose.

        # TODO: Add support for hyperparameter tuning.
        l2_reg: float = 1.0,
    ):

        feedback_type = feedback_type.lower()  # type: ignore
        if feedback_type not in ("explicit", "implicit"):
            raise ValueError("Expected feedback_type to be `explicit` or `implicit`.")

        self.feedback_type = feedback_type

        if not isinstance(num_factors, int):
            raise TypeError(
                f"Expected num_factors to be an int, but got {type(num_factors)}."
            )

        if num_factors <= 0:
            raise ValueError(
                f"Expected num_factors to be a positive integer, but got {num_factors}."
            )

        self.num_factors = num_factors

        if not isinstance(user_col, str):
            raise TypeError(f"Expected user_col to be a str, but got {type(user_col)}.")

        self.user_col = user_col

        if not isinstance(item_col, str):
            raise TypeError(f"Expected item_col to be a str, but got {type(item_col)}.")

        self.item_col = item_col

        if not isinstance(rating_col, str):
            raise TypeError(
                f"Expected rating_col to be a str, but got {type(rating_col)}."
            )

        self._input_label_columns = [rating_col]

        if not isinstance(l2_reg, (float, int)):
            raise TypeError(
                f"Expected l2_reg to be a float or int, but got {type(l2_reg)}."
            )

        self.l2_reg = l2_reg
        self._bqml_model: Optional[core.BqmlModel] = None
        self._bqml_model_factory = globals.bqml_model_factory()

    @property
    def rating_col(self) -> str:
        """str: The rating column name. Defaults to 'rating'."""
        return self._input_label_columns[0]

    @classmethod
    def _from_bq(
        cls, session: bigframes.session.Session, bq_model: bigquery.Model
    ) -> MatrixFactorization:
        assert bq_model.model_type == "MATRIX_FACTORIZATION"

        kwargs = utils.retrieve_params_from_bq_model(
            cls, bq_model, _BQML_PARAMS_MAPPING
        )

        model = cls(**kwargs)
        model._bqml_model = core.BqmlModel(session, bq_model)
        return model

    @property
    def _bqml_options(self) -> dict:
        """The model options as they will be set for BQML"""
        options: dict = {
            "model_type": "matrix_factorization",
            "feedback_type": self.feedback_type,
            "user_col": self.user_col,
            "item_col": self.item_col,
            "rating_col": self.rating_col,
            "l2_reg": self.l2_reg,
        }

        if self.num_factors is not None:
            options["num_factors"] = self.num_factors

        return options

    def _fit(
        self,
        X: utils.ArrayType,
        y=None,
        transforms: Optional[List[str]] = None,
    ) -> MatrixFactorization:
        if y is not None:
            raise ValueError(
                "Label column not supported for Matrix Factorization model but y was not `None`"
            )

        (X,) = utils.batch_convert_to_dataframe(X)

        self._bqml_model = self._bqml_model_factory.create_model(
            X_train=X,
            transforms=transforms,
            options=self._bqml_options,
        )
        return self

    def predict(self, X: utils.ArrayType) -> bpd.DataFrame:
        if not self._bqml_model:
            raise RuntimeError("A model must be fitted before predict")

        (X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session)

        return self._bqml_model.recommend(X)

    def to_gbq(self, model_name: str, replace: bool = False) -> MatrixFactorization:
        """Save the model to BigQuery.

        Args:
            model_name (str):
                The name of the model.
            replace (bool, default False):
                Whether to replace the model if it already exists. Defaults to False.

        Returns:
            MatrixFactorization: Saved model."""
        if not self._bqml_model:
            raise RuntimeError("A model must be fitted before it can be saved")

        new_model = self._bqml_model.copy(model_name, replace)
        return new_model.session.read_gbq_model(model_name)

    def score(
        self,
        X=None,
        y=None,
    ) -> bpd.DataFrame:
        if not self._bqml_model:
            raise RuntimeError("A model must be fitted before score")

        # TODO(b/291973741): X param is ignored. Update once BQML supports input in ML.EVALUATE.
        return self._bqml_model.evaluate()
2 changes: 2 additions & 0 deletions bigframes/ml/loader.py
@@ -42,6 +42,7 @@
"LINEAR_REGRESSION": linear_model.LinearRegression,
"LOGISTIC_REGRESSION": linear_model.LogisticRegression,
"KMEANS": cluster.KMeans,
"MATRIX_FACTORIZATION": decomposition.MatrixFactorization,
"PCA": decomposition.PCA,
"BOOSTED_TREE_REGRESSOR": ensemble.XGBRegressor,
"BOOSTED_TREE_CLASSIFIER": ensemble.XGBClassifier,
@@ -80,6 +81,7 @@
def from_bq(
    session: bigframes.session.Session, bq_model: bigquery.Model
) -> Union[
    decomposition.MatrixFactorization,
    decomposition.PCA,
    cluster.KMeans,
    linear_model.LinearRegression,
5 changes: 5 additions & 0 deletions bigframes/ml/sql.py
@@ -299,6 +299,11 @@ def alter_model(
        return "\n".join(parts)

    # ML prediction TVFs
    def ml_recommend(self, source_sql: str) -> str:
        """Encode ML.RECOMMEND for BQML"""
        return f"""SELECT * FROM ML.RECOMMEND(MODEL {self._model_ref_sql()},
({source_sql}))"""

    def ml_predict(self, source_sql: str) -> str:
        """Encode ML.PREDICT for BQML"""
        return f"""SELECT * FROM ML.PREDICT(MODEL {self._model_ref_sql()},
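
To make the shape of the generated query concrete, here is a standalone sketch that builds the same `ML.RECOMMEND` string as `ml_recommend` does. It is not the library's code path: `ml_recommend_sql` is a hypothetical free function, the backticked model reference stands in for whatever `_model_ref_sql()` actually returns, and the table names are invented.

```python
def ml_recommend_sql(model_ref_sql: str, source_sql: str) -> str:
    """Build an ML.RECOMMEND query with the same template as the diff's `ml_recommend`."""
    return f"""SELECT * FROM ML.RECOMMEND(MODEL {model_ref_sql},
({source_sql}))"""


# Hypothetical model reference and source query, for illustration only.
query = ml_recommend_sql(
    "`my_project.my_dataset.my_mf_model`",
    "SELECT user_id, item_id FROM `my_project.my_dataset.new_ratings`",
)
print(query)
```

In the PR, `BqmlModel.recommend` hands the bound `ml_recommend` method to `_apply_ml_tvf`, which is the same way `predict` hands over `ml_predict`. Adding a new table-valued function therefore only needs one SQL template and one thin wrapper.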
20 changes: 20 additions & 0 deletions tests/data/ratings.jsonl
@@ -0,0 +1,20 @@
{"user_id": 1, "item_id": 2, "rating": 4.0}
{"user_id": 1, "item_id": 5, "rating": 3.0}
{"user_id": 2, "item_id": 1, "rating": 5.0}
{"user_id": 2, "item_id": 3, "rating": 2.0}
{"user_id": 3, "item_id": 4, "rating": 4.5}
{"user_id": 3, "item_id": 7, "rating": 3.5}
{"user_id": 4, "item_id": 2, "rating": 1.0}
{"user_id": 4, "item_id": 8, "rating": 5.0}
{"user_id": 5, "item_id": 3, "rating": 4.0}
{"user_id": 5, "item_id": 9, "rating": 2.5}
{"user_id": 6, "item_id": 1, "rating": 3.0}
{"user_id": 6, "item_id": 6, "rating": 4.5}
{"user_id": 7, "item_id": 5, "rating": 5.0}
{"user_id": 7, "item_id": 10, "rating": 1.5}
{"user_id": 8, "item_id": 4, "rating": 2.0}
{"user_id": 8, "item_id": 7, "rating": 4.0}
{"user_id": 9, "item_id": 2, "rating": 3.5}
{"user_id": 9, "item_id": 9, "rating": 5.0}
{"user_id": 10, "item_id": 3, "rating": 4.5}
{"user_id": 10, "item_id": 8, "rating": 2.5}
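
The review comment on the constructor arguments contrasts this long-format table (one row per user, item, and rating) with the sparse-matrix input that scikit-learn-style estimators usually take. As a rough illustration of that correspondence, the sketch below turns the first four rows of `ratings.jsonl` into coordinate lists, which is the information a COO-style sparse matrix stores, and then into a small dense grid. It uses only the standard library. `to_coordinate_lists` and `to_dense` are hypothetical helpers and are not part of bigframes.

```python
import json
from typing import List, Optional, Tuple

# First four rows of tests/data/ratings.jsonl from this PR.
JSONL = """\
{"user_id": 1, "item_id": 2, "rating": 4.0}
{"user_id": 1, "item_id": 5, "rating": 3.0}
{"user_id": 2, "item_id": 1, "rating": 5.0}
{"user_id": 2, "item_id": 3, "rating": 2.0}
"""


def to_coordinate_lists(jsonl: str) -> Tuple[List[int], List[int], List[float]]:
    """Split long-format records into parallel row, column, and value lists."""
    rows: List[int] = []
    cols: List[int] = []
    values: List[float] = []
    for line in jsonl.splitlines():
        if not line.strip():
            continue
        record = json.loads(line)
        rows.append(record["user_id"])
        cols.append(record["item_id"])
        values.append(record["rating"])
    return rows, cols, values


def to_dense(
    rows: List[int], cols: List[int], values: List[float]
) -> List[List[Optional[float]]]:
    """Dense user-by-item grid; None marks a pair with no observed rating."""
    users = sorted(set(rows))
    items = sorted(set(cols))
    grid: List[List[Optional[float]]] = [[None] * len(items) for _ in users]
    for user, item, value in zip(rows, cols, values):
        grid[users.index(user)][items.index(item)] = value
    return grid


rows, cols, values = to_coordinate_lists(JSONL)
dense = to_dense(rows, cols, values)
for user_ratings in dense:
    print(user_ratings)
```

Most cells in the grid are empty. That is why the model takes `user_col`, `item_col`, and `rating_col` names over a long-format DataFrame instead of a dense matrix.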
17 changes: 17 additions & 0 deletions tests/data/ratings_schema.json
@@ -0,0 +1,17 @@
[
  {
    "mode": "NULLABLE",
    "name": "user_id",
    "type": "STRING"
  },
  {
    "mode": "NULLABLE",
    "name": "item_id",
    "type": "INT64"
  },
  {
    "mode": "NULLABLE",
    "name": "rating",
    "type": "FLOAT"
  }
]
14 changes: 14 additions & 0 deletions tests/system/conftest.py
@@ -320,6 +320,7 @@ def load_test_data_tables(
("repeated", "repeated_schema.json", "repeated.jsonl"),
("json", "json_schema.json", "json.jsonl"),
("penguins", "penguins_schema.json", "penguins.jsonl"),
("ratings", "ratings_schema.json", "ratings.jsonl"),
("time_series", "time_series_schema.json", "time_series.jsonl"),
("hockey_players", "hockey_players.json", "hockey_players.jsonl"),
("matrix_2by3", "matrix_2by3.json", "matrix_2by3.jsonl"),
@@ -416,6 +417,11 @@ def penguins_table_id(test_data_tables) -> str:
    return test_data_tables["penguins"]


@pytest.fixture(scope="session")
def ratings_table_id(test_data_tables) -> str:
    return test_data_tables["ratings"]


@pytest.fixture(scope="session")
def urban_areas_table_id(test_data_tables) -> str:
    return test_data_tables["urban_areas"]
@@ -769,6 +775,14 @@ def penguins_df_null_index(
    return unordered_session.read_gbq(penguins_table_id)


@pytest.fixture(scope="session")
def ratings_df_default_index(
    ratings_table_id: str, session: bigframes.Session
) -> bigframes.dataframe.DataFrame:
    """DataFrame pointing at test data."""
    return session.read_gbq(ratings_table_id)


@pytest.fixture(scope="session")
def time_series_df_default_index(
    time_series_table_id: str, session: bigframes.Session
46 changes: 46 additions & 0 deletions tests/system/large/ml/test_decomposition.py
@@ -163,3 +163,49 @@ def test_decomposition_configure_fit_load_none_component(
        in reloaded_model._bqml_model.model_name
    )
    assert reloaded_model.n_components == 7


def test_decomposition_mf_configure_fit_load(
    session, ratings_df_default_index, dataset_id
):
    model = decomposition.MatrixFactorization(
        num_factors=6,
        feedback_type="explicit",
        user_col="user_id",
        item_col="item_id",
        rating_col="rating",
        l2_reg=9.83,
    )

    model.fit(ratings_df_default_index)

    reloaded_model = model.to_gbq(
        f"{dataset_id}.temp_configured_mf_model", replace=True
    )

    new_ratings = session.read_pandas(
        pd.DataFrame(
            {
                "user_id": ["11", "12", "13"],
                "item_id": [1, 2, 3],
                "rating": [1.0, 2.0, 3.0],
            }
        )
    )

    reloaded_model.score(new_ratings)

    result = reloaded_model.predict(new_ratings).to_pandas()

    assert reloaded_model._bqml_model is not None
    assert (
        f"{dataset_id}.temp_configured_mf_model"
        in reloaded_model._bqml_model.model_name
    )
    assert result is not None
    assert reloaded_model.feedback_type == "explicit"
    assert reloaded_model.num_factors == 6
    assert reloaded_model.user_col == "user_id"
    assert reloaded_model.item_col == "item_id"
    assert reloaded_model.rating_col == "rating"
    assert reloaded_model.l2_reg == 9.83