-
Notifications
You must be signed in to change notification settings - Fork 49
feat: add support for creating a Matrix Factorization model #1330
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6783a0a
1d39560
e19c262
1bef4a2
d157cd7
e336bde
d5f713a
5e3e443
34a60bc
c116e8a
dedef39
e5165a9
05eb854
2787178
8c66e07
086b4dd
8ed3ccd
1b4eef9
7c371ac
7498c8c
55ef06a
29805b5
8de384a
647532b
b340c4f
580de41
29ee357
bac2ece
3f22c23
213f11d
aaf0d1f
4c90c1d
792bd64
ed279be
ba5beb3
86fb956
a29bbcf
8577833
a92007c
a808429
4b7b4db
8d55eac
9195658
faa4d6b
76a9934
bef7808
f18104d
9b39a99
0dd033d
60faed1
1f85b75
7efc63d
a457639
89790ac
a057a8f
512332e
741e749
f902131
310257d
408e807
19e423b
2c107df
c7c8eea
5f1a19a
68e308b
33f3069
1ff6aaa
446712b
c84dd7e
3473037
b3809e5
7e8a5b6
eba88d9
8599d88
8ab8818
b4d3578
a63cb90
3695f80
336bffd
bb6130a
e69438d
05da834
087953f
bfe9140
8d3599e
248a3b1
157daea
8912663
35a8c18
ac182be
ff58ff5
f0a6ba2
aaad5f5
b586c5c
04ddd5e
8e875ae
565138a
b39661f
c0ef08f
4b53b04
342cbd1
8812f33
24b8e0c
664de04
63e8e9c
3e52cd4
c2e9a5f
28c4602
1240eeb
46f1ea6
193b9c8
5a547f8
23d8fc8
ed99ad7
e305950
411fe1a
4273a99
b9f6a52
b92ed1f
aaf34eb
46601c4
0823db2
32917e5
e485d3b
6a27083
6e2d902
b65c637
93ac0fa
74ebe27
b2ebcf7
3f40763
2cbc2e3
0a5aefb
d484f77
366e0ab
1eaa708
56ee623
c942418
e0ef53e
5018182
c088a76
f9397f1
b439120
ffe0f33
b2698ef
69c8fba
8a614c5
9d71c86
c2b4795
cd20ffc
cf6e5be
da230b4
8927072
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
|
||
from typing import List, Literal, Optional, Union | ||
|
||
import bigframes_vendored.sklearn.decomposition._mf | ||
import bigframes_vendored.sklearn.decomposition._pca | ||
from google.cloud import bigquery | ||
|
||
|
@@ -27,7 +28,15 @@ | |
import bigframes.pandas as bpd | ||
import bigframes.session | ||
|
||
_BQML_PARAMS_MAPPING = {"svd_solver": "pcaSolver"} | ||
_BQML_PARAMS_MAPPING = { | ||
"svd_solver": "pcaSolver", | ||
"feedback_type": "feedbackType", | ||
"num_factors": "numFactors", | ||
"user_col": "userColumn", | ||
"item_col": "itemColumn", | ||
"_input_label_columns": "inputLabelColumns", | ||
"l2_reg": "l2Regularization", | ||
} | ||
|
||
|
||
@log_adapter.class_logger | ||
|
@@ -197,3 +206,159 @@ def score( | |
|
||
# TODO(b/291973741): X param is ignored. Update BQML supports input in ML.EVALUATE. | ||
return self._bqml_model.evaluate() | ||
|
||
|
||
@log_adapter.class_logger | ||
class MatrixFactorization( | ||
base.UnsupervisedTrainablePredictor, | ||
bigframes_vendored.sklearn.decomposition._mf.MatrixFactorization, | ||
): | ||
__doc__ = bigframes_vendored.sklearn.decomposition._mf.MatrixFactorization.__doc__ | ||
|
||
def __init__( | ||
self, | ||
*, | ||
feedback_type: Literal["explicit", "implicit"] = "explicit", | ||
num_factors: int, | ||
rey-esp marked this conversation as resolved.
Show resolved
Hide resolved
|
||
user_col: str, | ||
item_col: str, | ||
rating_col: str = "rating", | ||
Comment on lines
+223
to
+225
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @GarrettWu @shuoweil I see in #1282 you ended up passing in "id_col" as a separate argument to Note: MatrixFactorization differs somewhat from that application in that normally in scikit-learn one would have a "sparse matrix" data type (e.g. https://docs.scipy.org/doc/scipy/reference/sparse.html) where rows/cols/values would all be bundled up in one object, similar to how we are using the bigframes DataFrame for this purpose. |
||
# TODO: Add support for hyperparameter tuning. | ||
l2_reg: float = 1.0, | ||
): | ||
|
||
feedback_type = feedback_type.lower() # type: ignore | ||
if feedback_type not in ("explicit", "implicit"): | ||
raise ValueError("Expected feedback_type to be `explicit` or `implicit`.") | ||
|
||
self.feedback_type = feedback_type | ||
|
||
if not isinstance(num_factors, int): | ||
raise TypeError( | ||
f"Expected num_factors to be an int, but got {type(num_factors)}." | ||
) | ||
|
||
if num_factors < 0: | ||
raise ValueError( | ||
f"Expected num_factors to be a positive integer, but got {num_factors}." | ||
) | ||
|
||
self.num_factors = num_factors | ||
|
||
if not isinstance(user_col, str): | ||
raise TypeError(f"Expected user_col to be a str, but got {type(user_col)}.") | ||
|
||
self.user_col = user_col | ||
|
||
if not isinstance(item_col, str): | ||
raise TypeError(f"Expected item_col to be STR, but got {type(item_col)}.") | ||
|
||
self.item_col = item_col | ||
|
||
if not isinstance(rating_col, str): | ||
raise TypeError( | ||
f"Expected rating_col to be a str, but got {type(rating_col)}." | ||
) | ||
|
||
self._input_label_columns = [rating_col] | ||
|
||
if not isinstance(l2_reg, (float, int)): | ||
raise TypeError( | ||
f"Expected l2_reg to be a float or int, but got {type(l2_reg)}." | ||
) | ||
|
||
self.l2_reg = l2_reg | ||
self._bqml_model: Optional[core.BqmlModel] = None | ||
self._bqml_model_factory = globals.bqml_model_factory() | ||
rey-esp marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@property | ||
def rating_col(self) -> str: | ||
"""str: The rating column name. Defaults to 'rating'.""" | ||
return self._input_label_columns[0] | ||
|
||
@classmethod | ||
def _from_bq( | ||
cls, session: bigframes.session.Session, bq_model: bigquery.Model | ||
) -> MatrixFactorization: | ||
assert bq_model.model_type == "MATRIX_FACTORIZATION" | ||
|
||
kwargs = utils.retrieve_params_from_bq_model( | ||
cls, bq_model, _BQML_PARAMS_MAPPING | ||
) | ||
|
||
model = cls(**kwargs) | ||
model._bqml_model = core.BqmlModel(session, bq_model) | ||
return model | ||
|
||
@property | ||
def _bqml_options(self) -> dict: | ||
"""The model options as they will be set for BQML""" | ||
options: dict = { | ||
"model_type": "matrix_factorization", | ||
"feedback_type": self.feedback_type, | ||
"user_col": self.user_col, | ||
"item_col": self.item_col, | ||
"rating_col": self.rating_col, | ||
"l2_reg": self.l2_reg, | ||
} | ||
|
||
if self.num_factors is not None: | ||
options["num_factors"] = self.num_factors | ||
|
||
return options | ||
|
||
def _fit( | ||
self, | ||
X: utils.ArrayType, | ||
y=None, | ||
transforms: Optional[List[str]] = None, | ||
) -> MatrixFactorization: | ||
if y is not None: | ||
raise ValueError( | ||
"Label column not supported for Matrix Factorization model but y was not `None`" | ||
) | ||
|
||
(X,) = utils.batch_convert_to_dataframe(X) | ||
|
||
self._bqml_model = self._bqml_model_factory.create_model( | ||
X_train=X, | ||
transforms=transforms, | ||
options=self._bqml_options, | ||
) | ||
return self | ||
|
||
def predict(self, X: utils.ArrayType) -> bpd.DataFrame: | ||
if not self._bqml_model: | ||
raise RuntimeError("A model must be fitted before recommend") | ||
|
||
(X,) = utils.batch_convert_to_dataframe(X, session=self._bqml_model.session) | ||
|
||
return self._bqml_model.recommend(X) | ||
|
||
def to_gbq(self, model_name: str, replace: bool = False) -> MatrixFactorization: | ||
"""Save the model to BigQuery. | ||
|
||
Args: | ||
model_name (str): | ||
The name of the model. | ||
replace (bool, default False): | ||
Determine whether to replace if the model already exists. Default to False. | ||
|
||
Returns: | ||
MatrixFactorization: Saved model.""" | ||
if not self._bqml_model: | ||
raise RuntimeError("A model must be fitted before it can be saved") | ||
|
||
new_model = self._bqml_model.copy(model_name, replace) | ||
return new_model.session.read_gbq_model(model_name) | ||
|
||
def score( | ||
self, | ||
X=None, | ||
y=None, | ||
) -> bpd.DataFrame: | ||
if not self._bqml_model: | ||
raise RuntimeError("A model must be fitted before score") | ||
|
||
# TODO(b/291973741): X param is ignored. Update BQML supports input in ML.EVALUATE. | ||
return self._bqml_model.evaluate() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
{"user_id": 1, "item_id": 2, "rating": 4.0} | ||
{"user_id": 1, "item_id": 5, "rating": 3.0} | ||
{"user_id": 2, "item_id": 1, "rating": 5.0} | ||
{"user_id": 2, "item_id": 3, "rating": 2.0} | ||
{"user_id": 3, "item_id": 4, "rating": 4.5} | ||
{"user_id": 3, "item_id": 7, "rating": 3.5} | ||
{"user_id": 4, "item_id": 2, "rating": 1.0} | ||
{"user_id": 4, "item_id": 8, "rating": 5.0} | ||
{"user_id": 5, "item_id": 3, "rating": 4.0} | ||
{"user_id": 5, "item_id": 9, "rating": 2.5} | ||
{"user_id": 6, "item_id": 1, "rating": 3.0} | ||
{"user_id": 6, "item_id": 6, "rating": 4.5} | ||
{"user_id": 7, "item_id": 5, "rating": 5.0} | ||
{"user_id": 7, "item_id": 10, "rating": 1.5} | ||
{"user_id": 8, "item_id": 4, "rating": 2.0} | ||
{"user_id": 8, "item_id": 7, "rating": 4.0} | ||
{"user_id": 9, "item_id": 2, "rating": 3.5} | ||
{"user_id": 9, "item_id": 9, "rating": 5.0} | ||
{"user_id": 10, "item_id": 3, "rating": 4.5} | ||
{"user_id": 10, "item_id": 8, "rating": 2.5} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
[ | ||
{ | ||
"mode": "NULLABLE", | ||
"name": "user_id", | ||
"type": "STRING" | ||
}, | ||
{ | ||
"mode": "NULLABLE", | ||
"name": "item_id", | ||
"type": "INT64" | ||
}, | ||
{ | ||
"mode": "NULLABLE", | ||
"name": "rating", | ||
"type": "FLOAT" | ||
} | ||
] |
Uh oh!
There was an error while loading. Please reload this page.