Skip to content

Commit 8af26d0

Browse files
authored
feat: implement ai.classify() (#1781)
* feat: implement ai.classify() * check label duplicity
1 parent 7269512 commit 8af26d0

File tree

3 files changed

+182
-1
lines changed

3 files changed

+182
-1
lines changed

bigframes/operations/ai.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import re
1818
import typing
19-
from typing import Dict, List, Optional
19+
from typing import Dict, List, Optional, Sequence
2020
import warnings
2121

2222
import numpy as np
@@ -258,6 +258,101 @@ def extract_logprob(s: bigframes.series.Series) -> bigframes.series.Series:
258258

259259
return concat([self._df, *attach_columns], axis=1)
260260

261+
def classify(
262+
self,
263+
instruction: str,
264+
model,
265+
labels: Sequence[str],
266+
output_column: str = "result",
267+
ground_with_google_search: bool = False,
268+
attach_logprobs=False,
269+
):
270+
"""
271+
Classifies the rows of dataframes based on user instruction into the provided labels.
272+
273+
**Examples:**
274+
275+
>>> import bigframes.pandas as bpd
276+
>>> bpd.options.display.progress_bar = None
277+
>>> bpd.options.experiments.ai_operators = True
278+
>>> bpd.options.compute.ai_ops_confirmation_threshold = 25
279+
280+
>>> import bigframes.ml.llm as llm
281+
>>> model = llm.GeminiTextGenerator(model_name="gemini-2.0-flash-001")
282+
283+
>>> df = bpd.DataFrame({
284+
... "feedback_text": [
285+
... "The product is amazing, but the shipping was slow.",
286+
... "I had an issue with my recent bill.",
287+
... "The user interface is very intuitive."
288+
... ],
289+
... })
290+
>>> df.ai.classify("{feedback_text}", model=model, labels=["Shipping", "Billing", "UI"])
291+
feedback_text result
292+
0 The product is amazing, but the shipping was s... Shipping
293+
1 I had an issue with my recent bill. Billing
294+
2 The user interface is very intuitive. UI
295+
<BLANKLINE>
296+
[3 rows x 2 columns]
297+
298+
Args:
299+
instruction (str):
300+
An instruction on how to classify the data. This value must contain
301+
column references by name, which should be wrapped in a pair of braces.
302+
For example, if you have a column "feedback", you can refer to this column
303+
with"{food}".
304+
305+
model (bigframes.ml.llm.GeminiTextGenerator):
306+
A GeminiTextGenerator provided by Bigframes ML package.
307+
308+
labels (Sequence[str]):
309+
A collection of labels (categories). It must contain at least two and at most 20 elements.
310+
Labels are case sensitive. Duplicated labels are not allowed.
311+
312+
output_column (str, default "result"):
313+
The name of column for the output.
314+
315+
ground_with_google_search (bool, default False):
316+
Enables Grounding with Google Search for the GeminiTextGenerator model.
317+
When set to True, the model incorporates relevant information from Google
318+
Search results into its responses, enhancing their accuracy and factualness.
319+
Note: Using this feature may impact billing costs. Refer to the pricing
320+
page for details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models
321+
The default is `False`.
322+
323+
attach_logprobs (bool, default False):
324+
Controls whether to attach an additional "logprob" column for each result. Logprobs are float-point values reflecting the confidence level
325+
of the LLM for their responses. Higher values indicate more confidence. The value is in the range between negative infinite and 0.
326+
327+
328+
Returns:
329+
bigframes.pandas.DataFrame: DataFrame with classification result.
330+
331+
Raises:
332+
NotImplementedError: when the AI operator experiment is off.
333+
ValueError: when the instruction refers to a non-existing column, when no
334+
columns are referred to, or when the count of labels does not meet the
335+
requirement.
336+
"""
337+
338+
if len(labels) < 2 or len(labels) > 20:
339+
raise ValueError(
340+
f"The number of labels should be between 2 and 20 (inclusive), but {len(labels)} labels are provided."
341+
)
342+
343+
if len(set(labels)) != len(labels):
344+
raise ValueError("There are duplicate labels.")
345+
346+
updated_instruction = f"Based on the user instruction {instruction}, you must provide an answer that must exist in the following list of labels: {labels}"
347+
348+
return self.map(
349+
updated_instruction,
350+
model,
351+
output_schema={output_column: "string"},
352+
ground_with_google_search=ground_with_google_search,
353+
attach_logprobs=attach_logprobs,
354+
)
355+
261356
def join(
262357
self,
263358
other,

tests/system/large/operations/test_ai.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,33 @@ def test_map_invalid_model_raise_error():
398398
)
399399

400400

401+
def test_classify(gemini_flash_model, session):
402+
df = dataframe.DataFrame(data={"creature": ["dog", "rose"]}, session=session)
403+
404+
with bigframes.option_context(
405+
AI_OP_EXP_OPTION,
406+
True,
407+
THRESHOLD_OPTION,
408+
10,
409+
):
410+
actual_result = df.ai.classify(
411+
"{creature}",
412+
gemini_flash_model,
413+
labels=["animal", "plant"],
414+
output_column="result",
415+
).to_pandas()
416+
417+
expected_result = pd.DataFrame(
418+
{
419+
"creature": ["dog", "rose"],
420+
"result": ["animal", "plant"],
421+
}
422+
)
423+
pandas.testing.assert_frame_equal(
424+
actual_result, expected_result, check_index_type=False, check_dtype=False
425+
)
426+
427+
401428
@pytest.mark.parametrize(
402429
"instruction",
403430
[

tests/system/small/operations/test_ai.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,65 @@ def test_map(session):
108108
)
109109

110110

111+
def test_classify(session):
112+
df = dataframe.DataFrame({"col": ["A", "B"]}, session=session)
113+
model = FakeGeminiTextGenerator(
114+
dataframe.DataFrame(
115+
{
116+
"result": ["A", "B"],
117+
"full_response": _create_dummy_full_response(2),
118+
},
119+
session=session,
120+
),
121+
)
122+
123+
with bigframes.option_context(
124+
AI_OP_EXP_OPTION,
125+
True,
126+
THRESHOLD_OPTION,
127+
50,
128+
):
129+
result = df.ai.classify(
130+
"classify {col}", model=model, labels=["A", "B"]
131+
).to_pandas()
132+
133+
pandas.testing.assert_frame_equal(
134+
result,
135+
pd.DataFrame(
136+
{"col": ["A", "B"], "result": ["A", "B"]}, dtype=dtypes.STRING_DTYPE
137+
),
138+
check_index_type=False,
139+
)
140+
141+
142+
@pytest.mark.parametrize(
143+
"labels",
144+
[
145+
pytest.param([], id="empty-label"),
146+
pytest.param(["A", "A", "B"], id="duplicate-labels"),
147+
],
148+
)
149+
def test_classify_invalid_labels_raise_error(session, labels):
150+
df = dataframe.DataFrame({"col": ["A", "B"]}, session=session)
151+
model = FakeGeminiTextGenerator(
152+
dataframe.DataFrame(
153+
{
154+
"result": ["A", "B"],
155+
"full_response": _create_dummy_full_response(2),
156+
},
157+
session=session,
158+
),
159+
)
160+
161+
with bigframes.option_context(
162+
AI_OP_EXP_OPTION,
163+
True,
164+
THRESHOLD_OPTION,
165+
50,
166+
), pytest.raises(ValueError):
167+
df.ai.classify("classify {col}", model=model, labels=labels)
168+
169+
111170
def test_join(session):
112171
left_df = dataframe.DataFrame({"col_A": ["A"]}, session=session)
113172
right_df = dataframe.DataFrame({"col_B": ["B"]}, session=session)

0 commit comments

Comments
 (0)