|
16 | 16 |
|
17 | 17 | import re
|
18 | 18 | import typing
|
19 |
| -from typing import Dict, List, Optional |
| 19 | +from typing import Dict, List, Optional, Sequence |
20 | 20 | import warnings
|
21 | 21 |
|
22 | 22 | import numpy as np
|
@@ -258,6 +258,101 @@ def extract_logprob(s: bigframes.series.Series) -> bigframes.series.Series:
|
258 | 258 |
|
259 | 259 | return concat([self._df, *attach_columns], axis=1)
|
260 | 260 |
|
| 261 | + def classify( |
| 262 | + self, |
| 263 | + instruction: str, |
| 264 | + model, |
| 265 | + labels: Sequence[str], |
| 266 | + output_column: str = "result", |
| 267 | + ground_with_google_search: bool = False, |
| 268 | + attach_logprobs=False, |
| 269 | + ): |
| 270 | + """ |
| 271 | + Classifies the rows of dataframes based on user instruction into the provided labels. |
| 272 | +
|
| 273 | + **Examples:** |
| 274 | +
|
| 275 | + >>> import bigframes.pandas as bpd |
| 276 | + >>> bpd.options.display.progress_bar = None |
| 277 | + >>> bpd.options.experiments.ai_operators = True |
| 278 | + >>> bpd.options.compute.ai_ops_confirmation_threshold = 25 |
| 279 | +
|
| 280 | + >>> import bigframes.ml.llm as llm |
| 281 | + >>> model = llm.GeminiTextGenerator(model_name="gemini-2.0-flash-001") |
| 282 | +
|
| 283 | + >>> df = bpd.DataFrame({ |
| 284 | + ... "feedback_text": [ |
| 285 | + ... "The product is amazing, but the shipping was slow.", |
| 286 | + ... "I had an issue with my recent bill.", |
| 287 | + ... "The user interface is very intuitive." |
| 288 | + ... ], |
| 289 | + ... }) |
| 290 | + >>> df.ai.classify("{feedback_text}", model=model, labels=["Shipping", "Billing", "UI"]) |
| 291 | + feedback_text result |
| 292 | + 0 The product is amazing, but the shipping was s... Shipping |
| 293 | + 1 I had an issue with my recent bill. Billing |
| 294 | + 2 The user interface is very intuitive. UI |
| 295 | + <BLANKLINE> |
| 296 | + [3 rows x 2 columns] |
| 297 | +
|
| 298 | + Args: |
| 299 | + instruction (str): |
| 300 | + An instruction on how to classify the data. This value must contain |
| 301 | + column references by name, which should be wrapped in a pair of braces. |
| 302 | + For example, if you have a column "feedback", you can refer to this column |
| 303 | + with"{food}". |
| 304 | +
|
| 305 | + model (bigframes.ml.llm.GeminiTextGenerator): |
| 306 | + A GeminiTextGenerator provided by Bigframes ML package. |
| 307 | +
|
| 308 | + labels (Sequence[str]): |
| 309 | + A collection of labels (categories). It must contain at least two and at most 20 elements. |
| 310 | + Labels are case sensitive. Duplicated labels are not allowed. |
| 311 | +
|
| 312 | + output_column (str, default "result"): |
| 313 | + The name of column for the output. |
| 314 | +
|
| 315 | + ground_with_google_search (bool, default False): |
| 316 | + Enables Grounding with Google Search for the GeminiTextGenerator model. |
| 317 | + When set to True, the model incorporates relevant information from Google |
| 318 | + Search results into its responses, enhancing their accuracy and factualness. |
| 319 | + Note: Using this feature may impact billing costs. Refer to the pricing |
| 320 | + page for details: https://cloud.google.com/vertex-ai/generative-ai/pricing#google_models |
| 321 | + The default is `False`. |
| 322 | +
|
| 323 | + attach_logprobs (bool, default False): |
| 324 | + Controls whether to attach an additional "logprob" column for each result. Logprobs are float-point values reflecting the confidence level |
| 325 | + of the LLM for their responses. Higher values indicate more confidence. The value is in the range between negative infinite and 0. |
| 326 | +
|
| 327 | +
|
| 328 | + Returns: |
| 329 | + bigframes.pandas.DataFrame: DataFrame with classification result. |
| 330 | +
|
| 331 | + Raises: |
| 332 | + NotImplementedError: when the AI operator experiment is off. |
| 333 | + ValueError: when the instruction refers to a non-existing column, when no |
| 334 | + columns are referred to, or when the count of labels does not meet the |
| 335 | + requirement. |
| 336 | + """ |
| 337 | + |
| 338 | + if len(labels) < 2 or len(labels) > 20: |
| 339 | + raise ValueError( |
| 340 | + f"The number of labels should be between 2 and 20 (inclusive), but {len(labels)} labels are provided." |
| 341 | + ) |
| 342 | + |
| 343 | + if len(set(labels)) != len(labels): |
| 344 | + raise ValueError("There are duplicate labels.") |
| 345 | + |
| 346 | + updated_instruction = f"Based on the user instruction {instruction}, you must provide an answer that must exist in the following list of labels: {labels}" |
| 347 | + |
| 348 | + return self.map( |
| 349 | + updated_instruction, |
| 350 | + model, |
| 351 | + output_schema={output_column: "string"}, |
| 352 | + ground_with_google_search=ground_with_google_search, |
| 353 | + attach_logprobs=attach_logprobs, |
| 354 | + ) |
| 355 | + |
261 | 356 | def join(
|
262 | 357 | self,
|
263 | 358 | other,
|
|
0 commit comments