Skip to content

Commit 7e5b6a8

Browse files
authored
feat: add Linear_Regression.global_explain() (#1446)
* feat: add Linear_Regression.global_explain() * remove class_level_explain param * working global_explain() * begin adding tests * update snippet * complete snippet * failing, near complete linear model test * passing system test * Update core.py - set index to have sorted by feature * Update test_linear_model.py - remove set/set index * Update linear_model.py - fix doc section * Update conftest.py - rename penguins w global explain * Update linear_model.py - complete doc * lint * passing test and fixed expected results
1 parent ac59173 commit 7e5b6a8

File tree

6 files changed

+106
-0
lines changed

6 files changed

+106
-0
lines changed

bigframes/ml/core.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,16 @@ def explain_predict(
134134
),
135135
)
136136

137+
def global_explain(self, options: Mapping[str, bool]) -> bpd.DataFrame:
    """Run the ML.GLOBAL_EXPLAIN query for this model and load the result.

    Args:
        options: Passed to the SQL generator as the ``struct_options`` of
            the generated query.

    Returns:
        The query result, sorted by the "attribution" column in descending
        order and indexed by the "feature" column.
    """
    query = self._model_manipulation_sql_generator.ml_global_explain(
        struct_options=options
    )
    explanation = self._session.read_gbq(query)
    # Largest attribution first, then key the rows by feature name.
    ranked = explanation.sort_values(by="attribution", ascending=False)
    return ranked.set_index("feature")
146+
137147
def transform(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
138148
return self._apply_ml_tvf(
139149
input_data,

bigframes/ml/linear_model.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,26 @@ def predict_explain(
203203
X, options={"top_k_features": top_k_features}
204204
)
205205

206+
def global_explain(
    self,
) -> bpd.DataFrame:
    """
    Provide explanations for an entire linear regression model.

    .. note::
        Output matches that of the BigQuery ML.GLOBAL_EXPLAIN function.
        See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-global-explain

    Returns:
        bigframes.pandas.DataFrame:
            DataFrame containing feature importance values and corresponding attributions, designed to provide a global explanation of feature influence.

    Raises:
        RuntimeError: If no underlying BQML model is set, i.e. the model has not been fitted.
    """

    if not self._bqml_model:
        # Name the operation that was actually attempted, so the error is
        # not mistaken for a failure in predict().
        raise RuntimeError("A model must be fitted before global_explain")

    # An empty options mapping is forwarded to the underlying model.
    return self._bqml_model.global_explain({})
225+
206226
def score(
207227
self,
208228
X: utils.ArrayType,

bigframes/ml/sql.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,12 @@ def ml_explain_predict(
312312
return f"""SELECT * FROM ML.EXPLAIN_PREDICT(MODEL {self._model_ref_sql()},
313313
({source_sql}), {struct_options_sql})"""
314314

315+
def ml_global_explain(self, struct_options: Mapping[str, bool]) -> str:
    """Encode ML.GLOBAL_EXPLAIN for BQML.

    Args:
        struct_options: Option names mapped to their values; expanded as
            keyword arguments to ``self.struct_options`` to build the
            options part of the query.

    Returns:
        The ``SELECT * FROM ML.GLOBAL_EXPLAIN(...)`` query text.
    """
    struct_options_sql = self.struct_options(**struct_options)
    return f"""SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL {self._model_ref_sql()},
{struct_options_sql})"""
320+
315321
def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str:
316322
"""Encode ML.FORECAST for BQML"""
317323
struct_options_sql = self.struct_options(**struct_options)

samples/snippets/linear_regression_tutorial_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,31 @@ def test_linear_regression(random_model_id: str) -> None:
9292
# 3 5349.603734 [{'feature': 'island', 'attribution': 7348.877... -5320.222128 5349.603734 0.0 Gentoo penguin (Pygoscelis papua) Biscoe 46.4 15.6 221.0 5000.0 MALE
9393
# 4 4637.165037 [{'feature': 'island', 'attribution': 7348.877... -5320.222128 4637.165037 0.0 Gentoo penguin (Pygoscelis papua) Biscoe 46.1 13.2 211.0 4500.0 FEMALE
9494
# [END bigquery_dataframes_bqml_linear_predict_explain]
95+
# [START bigquery_dataframes_bqml_linear_global_explain]
96+
# To use the `global_explain()` function, the model must be recreated with `enable_global_explain` set to `True`.
97+
model = LinearRegression(enable_global_explain=True)
98+
99+
# The model must be fitted before it can be saved to BigQuery and then explained.
100+
training_data = bq_df.dropna(subset=["body_mass_g"])
101+
X = training_data.drop(columns=["body_mass_g"])
102+
y = training_data[["body_mass_g"]]
103+
model.fit(X, y)
104+
model.to_gbq("bqml_tutorial.penguins_model", replace=True)
105+
106+
# Explain the model
107+
explain_model = model.global_explain()
108+
109+
# Expected results:
110+
# attribution
111+
# feature
112+
# island 5737.315921
113+
# species 4073.280549
114+
# sex 622.070896
115+
# flipper_length_mm 193.612051
116+
# culmen_depth_mm 117.084944
117+
# culmen_length_mm 94.366793
118+
# [END bigquery_dataframes_bqml_linear_global_explain]
119+
assert explain_model is not None
95120
assert feature_columns is not None
96121
assert label_columns is not None
97122
assert model is not None

tests/system/small/ml/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,15 @@ def ephemera_penguins_linear_model(
8484
return bf_model
8585

8686

87+
@pytest.fixture(scope="function")
def penguins_linear_model_w_global_explain(
    penguins_bqml_linear_model: core.BqmlModel,
) -> linear_model.LinearRegression:
    """Return a LinearRegression created with enable_global_explain=True.

    The estimator's ``_bqml_model`` is set to the BQML model supplied by the
    ``penguins_bqml_linear_model`` fixture.
    """
    regressor = linear_model.LinearRegression(enable_global_explain=True)
    regressor._bqml_model = penguins_bqml_linear_model
    return regressor
94+
95+
8796
@pytest.fixture(scope="session")
8897
def penguins_logistic_model(
8998
session, penguins_logistic_model_name

tests/system/small/ml/test_linear_model.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,42 @@ def test_to_gbq_saved_linear_reg_model_scores(
228228
)
229229

230230

231+
def test_linear_reg_model_global_explain(
    penguins_linear_model_w_global_explain, new_penguins_df
):
    """Fit on the penguins data and check global_explain()'s shape, columns and feature index."""
    model = penguins_linear_model_w_global_explain

    # Rows without a label are dropped before fitting.
    labeled = new_penguins_df.dropna(subset=["body_mass_g"])
    model.fit(labeled.drop(columns=["body_mass_g"]), labeled[["body_mass_g"]])

    global_ex = model.global_explain()

    # Six rows and a single "attribution" column are expected.
    assert global_ex.shape == (6, 1)
    pandas.testing.assert_index_equal(
        global_ex.columns, pandas.Index(["attribution"])
    )

    # The attribution values themselves are not pinned: dropping the column
    # leaves only the feature index to compare, sorted so row order is
    # irrelevant.
    result = global_ex.to_pandas().drop(["attribution"], axis=1).sort_index()

    feature_names = [
        "island",
        "species",
        "sex",
        "flipper_length_mm",
        "culmen_depth_mm",
        "culmen_length_mm",
    ]
    expected_feature = pandas.DataFrame({"feature": feature_names})
    expected_feature = expected_feature.set_index("feature").sort_index()

    pandas.testing.assert_frame_equal(
        result,
        expected_feature,
        check_exact=False,
        check_index_type=False,
    )
265+
266+
231267
def test_to_gbq_replace(penguins_linear_model, table_id_unique):
232268
penguins_linear_model.to_gbq(table_id_unique, replace=True)
233269
with pytest.raises(google.api_core.exceptions.Conflict):

0 commit comments

Comments
 (0)