Skip to content

Commit 126f566

Browse files
authored
feat: add ml.metrics.pairwise.cosine_similarity function (#374)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕
1 parent 90caf86 commit 126f566

File tree

9 files changed

+260
-50
lines changed

9 files changed

+260
-50
lines changed

bigframes/ml/core.py

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from __future__ import annotations
1818

1919
import datetime
20-
from typing import Callable, cast, Iterable, Mapping, Optional, Union
20+
from typing import Callable, cast, Iterable, Literal, Mapping, Optional, Union
2121
import uuid
2222

2323
from google.cloud import bigquery
@@ -28,34 +28,12 @@
2828
import bigframes.pandas as bpd
2929

3030

31-
class BqmlModel:
32-
"""Represents an existing BQML model in BigQuery.
33-
34-
Wraps the BQML API and SQL interface to expose the functionality needed for
35-
BigQuery DataFrames ML.
36-
"""
31+
class BaseBqml:
32+
"""Base class for BQML functionalities."""
3733

38-
def __init__(self, session: bigframes.Session, model: bigquery.Model):
34+
def __init__(self, session: bigframes.Session):
3935
self._session = session
40-
self._model = model
41-
self._model_manipulation_sql_generator = ml_sql.ModelManipulationSqlGenerator(
42-
self.model_name
43-
)
44-
45-
@property
46-
def session(self) -> bigframes.Session:
47-
"""Get the BigQuery DataFrames session that this BQML model wrapper is tied to"""
48-
return self._session
49-
50-
@property
51-
def model_name(self) -> str:
52-
"""Get the fully qualified name of the model, i.e. project_id.dataset_id.model_id"""
53-
return f"{self._model.project}.{self._model.dataset_id}.{self._model.model_id}"
54-
55-
@property
56-
def model(self) -> bigquery.Model:
57-
"""Get the BQML model associated with this wrapper"""
58-
return self._model
36+
self._base_sql_generator = ml_sql.BaseSqlGenerator()
5937

6038
def _apply_sql(
6139
self,
@@ -84,6 +62,71 @@ def _apply_sql(
8462

8563
return df
8664

65+
def distance(
66+
self,
67+
x: bpd.DataFrame,
68+
y: bpd.DataFrame,
69+
type: Literal["EUCLIDEAN", "MANHATTAN", "COSINE"],
70+
name: str,
71+
) -> bpd.DataFrame:
72+
"""Calculate ML.DISTANCE from DataFrame inputs.
73+
74+
Args:
75+
x:
76+
input DataFrame
77+
y:
78+
input DataFrame
79+
type:
80+
Distance types, accept values are "EUCLIDEAN", "MANHATTAN", "COSINE".
81+
name:
82+
name of the output result column
83+
"""
84+
assert len(x.columns) == 1 and len(y.columns) == 1
85+
86+
input_data = x._cached().join(y._cached(), how="outer")
87+
x_column_id, y_column_id = x._block.value_columns[0], y._block.value_columns[0]
88+
89+
return self._apply_sql(
90+
input_data,
91+
lambda source_df: self._base_sql_generator.ml_distance(
92+
x_column_id,
93+
y_column_id,
94+
type=type,
95+
source_df=source_df,
96+
name=name,
97+
),
98+
)
99+
100+
101+
class BqmlModel(BaseBqml):
102+
"""Represents an existing BQML model in BigQuery.
103+
104+
Wraps the BQML API and SQL interface to expose the functionality needed for
105+
BigQuery DataFrames ML.
106+
"""
107+
108+
def __init__(self, session: bigframes.Session, model: bigquery.Model):
109+
self._session = session
110+
self._model = model
111+
self._model_manipulation_sql_generator = ml_sql.ModelManipulationSqlGenerator(
112+
self.model_name
113+
)
114+
115+
@property
116+
def session(self) -> bigframes.Session:
117+
"""Get the BigQuery DataFrames session that this BQML model wrapper is tied to"""
118+
return self._session
119+
120+
@property
121+
def model_name(self) -> str:
122+
"""Get the fully qualified name of the model, i.e. project_id.dataset_id.model_id"""
123+
return f"{self._model.project}.{self._model.dataset_id}.{self._model.model_id}"
124+
125+
@property
126+
def model(self) -> bigquery.Model:
127+
"""Get the BQML model associated with this wrapper"""
128+
return self._model
129+
87130
def predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
88131
# TODO: validate input data schema
89132
return self._apply_sql(

bigframes/ml/metrics/__init__.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from bigframes.ml.metrics import pairwise
16+
from bigframes.ml.metrics._metrics import (
17+
accuracy_score,
18+
auc,
19+
confusion_matrix,
20+
f1_score,
21+
precision_score,
22+
r2_score,
23+
recall_score,
24+
roc_auc_score,
25+
roc_curve,
26+
)
27+
28+
__all__ = [
29+
"r2_score",
30+
"recall_score",
31+
"accuracy_score",
32+
"roc_curve",
33+
"roc_auc_score",
34+
"auc",
35+
"confusion_matrix",
36+
"precision_score",
37+
"f1_score",
38+
"pairwise",
39+
]
File renamed without changes.

bigframes/ml/metrics/pairwise.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import inspect
16+
from typing import Union
17+
18+
from bigframes.ml import core, utils
19+
import bigframes.pandas as bpd
20+
import third_party.bigframes_vendored.sklearn.metrics.pairwise as vendored_metrics_pairwise
21+
22+
23+
def cosine_similarity(
24+
X: Union[bpd.DataFrame, bpd.Series], Y: Union[bpd.DataFrame, bpd.Series]
25+
) -> bpd.DataFrame:
26+
X, Y = utils.convert_to_dataframe(X, Y)
27+
if len(X.columns) != 1 or len(Y.columns) != 1:
28+
raise ValueError("Inputs X and Y can only contain 1 column.")
29+
30+
base_bqml = core.BaseBqml(session=X._session)
31+
return base_bqml.distance(X, Y, type="COSINE", name="cosine_similarity")
32+
33+
34+
cosine_similarity.__doc__ = inspect.getdoc(vendored_metrics_pairwise.cosine_similarity)

bigframes/ml/sql.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
Generates SQL queries needed for BigQuery DataFrames ML
1717
"""
1818

19-
from typing import Iterable, Mapping, Optional, Union
19+
from typing import Iterable, Literal, Mapping, Optional, Union
2020

2121
import google.cloud.bigquery
2222

@@ -133,6 +133,19 @@ def ml_label_encoder(
133133
https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-label-encoder for params."""
134134
return f"""ML.LABEL_ENCODER({numeric_expr_sql}, {top_k}, {frequency_threshold}) OVER() AS {name}"""
135135

136+
def ml_distance(
137+
self,
138+
col_x: str,
139+
col_y: str,
140+
type: Literal["EUCLIDEAN", "MANHATTAN", "COSINE"],
141+
source_df: bpd.DataFrame,
142+
name: str,
143+
) -> str:
144+
"""Encode ML.DISTANCE for BQML.
145+
https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-distance"""
146+
source_sql, _, _ = source_df._to_sql_query(include_index=True)
147+
return f"""SELECT *, ML.DISTANCE({col_x}, {col_y}, '{type}') AS {name} FROM ({source_sql})"""
148+
136149

137150
class ModelCreationSqlGenerator(BaseSqlGenerator):
138151
"""Sql generator for creating a model entity. Model id is the standalone id without project id and dataset id."""

docs/templates/toc.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@
116116
- name: metrics
117117
uid: bigframes.ml.metrics
118118
name: metrics
119+
- items:
120+
- name: metrics.pairwise
121+
uid: bigframes.ml.metrics.pairwise
122+
name: metrics.pairwise
119123
- items:
120124
- name: model_selection
121125
uid: bigframes.ml.model_selection
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import pandas as pd
17+
18+
from bigframes.ml import metrics
19+
import bigframes.pandas as bpd
20+
21+
22+
def test_cosine_similarity():
23+
x_col = [np.array([4.1, 0.5, 1.0])]
24+
y_col = [np.array([3.0, 0.0, 2.5])]
25+
X = bpd.read_pandas(pd.DataFrame({"X": x_col}))
26+
Y = bpd.read_pandas(pd.DataFrame({"Y": y_col}))
27+
28+
result = metrics.pairwise.cosine_similarity(X, Y)
29+
expected_pd_df = pd.DataFrame(
30+
{"X": x_col, "Y": y_col, "cosine_similarity": [0.108199]}
31+
)
32+
33+
pd.testing.assert_frame_equal(
34+
result.to_pandas(), expected_pd_df, check_dtype=False, check_index_type=False
35+
)

0 commit comments

Comments
 (0)