
Commit d54dc4a

Support string column identifiers for sort/aggregate/window and stricter Expr validation (apache#1221)
* refactor: improve DataFrame expression handling, type checking, and docs

  - Refactor expression handling and `_simplify_expression` for stronger type checking and clearer error handling
  - Improve type annotations for `file_sort_order` and `order_by` to support string inputs
  - Refactor DataFrame `filter` method to better validate predicates
  - Replace internal error message variable with public constant
  - Clarify usage of `col()` and `column()` in DataFrame examples

* refactor: unify expression and sorting logic; improve docs and error handling

  - Update `order_by` handling in Window class for better type support
  - Improve type checking in DataFrame expression handling
  - Replace `Expr`/`SortExpr` with `SortKey` in `file_sort_order` and related functions
  - Simplify `file_sort_order` handling in SessionContext
  - Rename `_EXPR_TYPE_ERROR` → `EXPR_TYPE_ERROR` for consistency
  - Clarify usage of `col()` vs `column()` in DataFrame examples
  - Enhance documentation for `file_sort_order` in SessionContext

* feat: add ensure_expr helper for validation; refine expression handling, sorting, and docs

  - Introduce `ensure_expr` helper and improve internal expression validation
  - Update error messages and tests to consistently use `EXPR_TYPE_ERROR`
  - Refactor expression handling with `_to_raw_expr`, `_ensure_expr`, and `SortKey`
  - Improve type safety and consistency in sort key definitions and file sort order
  - Add parameterized Parquet sorting tests
  - Enhance DataFrame docstrings with clearer guidance and usage examples
  - Fix minor typos and improve error message clarity

* Refactor and enhance expression handling, test coverage, and documentation

  - Introduce `ensure_expr_list` to validate and flatten nested expressions, treating strings as atomic
  - Update expression utilities to improve consistency across aggregation and window functions
  - Consolidate and expand parameterized tests for string equivalence in ranking and window functions
  - Expose `EXPR_TYPE_ERROR` for consistent error messaging across modules and tests
  - Improve internal sort logic using `expr_internal.SortExpr`
  - Clarify expectations for `join_on` expressions in documentation
  - Standardize imports and improve test clarity for maintainability

* refactor: update docstring for sort_or_default function to clarify its purpose
* fix Ruff errors
* refactor: update type hints to use typing.Union for better clarity and consistency
* fix Ruff errors
* refactor: simplify type hints by removing unnecessary imports for type checking
* refactor: update type hints for rex_type and types methods to improve clarity
* refactor: remove unnecessary type ignore comments from rex_type and types methods
* docs: update section title for clarity on DataFrame method arguments
* docs: clarify description of DataFrame methods accepting column names
* docs: add note to clarify function documentation reference for DataFrame methods
* docs: remove outdated information about predicate acceptance in DataFrame class
* refactor: simplify type hint for expr_list parameter in expr_list_to_raw_expr_list function
* docs: clarify usage of datafusion.col and datafusion.lit in DataFrame method documentation
* docs: clarify usage of col() and lit() in DataFrame filter examples
* Fix ruff errors
1 parent 9e97636 commit d54dc4a
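In practical terms, the headline change is that plain strings can now stand in for column expressions in sorting, aggregation, and window/ranking calls. A minimal sketch of the intended usage (the data is invented, and the keyword form of f.rank(order_by=...) is an assumption based on the string-equivalence tests this commit touches):

    from datafusion import SessionContext, col
    from datafusion import functions as f

    ctx = SessionContext()
    df = ctx.from_pydict({"id": [2, 1, 3], "value": [30, 10, 20]})

    # Sorting: a plain string now works where an Expr/SortExpr was required.
    df.sort("id")  # equivalent to df.sort(col("id"))

    # Aggregation: grouping columns may be given by name.
    df.aggregate("id", [f.count(col("value"))])

    # Ranking/window functions: order_by accepts strings as well (assumed
    # keyword form, per the parameterized tests added in this commit).
    df.select(f.rank(order_by=["value"]).alias("rnk"))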

File tree

7 files changed: +633, -123 lines changed


docs/source/user-guide/dataframe/index.rst

Lines changed: 50 additions & 0 deletions

@@ -126,6 +126,56 @@ DataFusion's DataFrame API offers a wide range of operations:
     # Drop columns
     df = df.drop("temporary_column")
 
+Column Names as Function Arguments
+----------------------------------
+
+Some ``DataFrame`` methods accept column names when an argument refers to an
+existing column. These include:
+
+* :py:meth:`~datafusion.DataFrame.select`
+* :py:meth:`~datafusion.DataFrame.sort`
+* :py:meth:`~datafusion.DataFrame.drop`
+* :py:meth:`~datafusion.DataFrame.join` (``on`` argument)
+* :py:meth:`~datafusion.DataFrame.aggregate` (grouping columns)
+
+See the full function documentation for details on any specific function.
+
+Note that :py:meth:`~datafusion.DataFrame.join_on` expects ``col()``/``column()`` expressions rather than plain strings.
+
+For such methods, you can pass column names directly:
+
+.. code-block:: python
+
+    from datafusion import col, functions as f
+
+    df.sort('id')
+    df.aggregate('id', [f.count(col('value'))])
+
+The same operation can also be written with explicit column expressions, using either ``col()`` or ``column()``:
+
+.. code-block:: python
+
+    from datafusion import col, column, functions as f
+
+    df.sort(col('id'))
+    df.aggregate(column('id'), [f.count(col('value'))])
+
+Note that ``column()`` is an alias of ``col()``, so you can use either name; the example above shows both in action.
+
+Whenever an argument represents an expression—such as in
+:py:meth:`~datafusion.DataFrame.filter` or
+:py:meth:`~datafusion.DataFrame.with_column`—use ``col()`` to reference
+columns. The comparison and arithmetic operators on ``Expr`` will automatically
+convert any non-``Expr`` value into a literal expression, so writing
+
+.. code-block:: python
+
+    from datafusion import col
+    df.filter(col("age") > 21)
+
+is equivalent to using ``lit(21)`` explicitly. Use ``lit()`` (also available
+as ``literal()``) when you need to construct a literal expression directly.
+
 Terminal Operations
 -------------------

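The join_on caveat in the section above is easiest to see side by side. A minimal sketch with invented table names and data: join_on takes boolean expressions, while the methods listed earlier accept plain column names.

    from datafusion import SessionContext, col

    ctx = SessionContext()
    users = ctx.from_pydict({"user_id": [1, 2], "name": ["alice", "bob"]})
    orders = ctx.from_pydict({"uid": [1, 1, 2], "total": [10.0, 5.0, 20.0]})

    # join_on requires expressions built with col()/column(), not strings.
    joined = users.join_on(orders, col("user_id") == col("uid"))

    # String names are fine for the methods listed above, e.g. sort and drop.
    result = joined.sort("user_id").drop("uid")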
python/datafusion/context.py

Lines changed: 39 additions & 24 deletions

@@ -22,29 +22,31 @@
 import warnings
 from typing import TYPE_CHECKING, Any, Protocol
 
-import pyarrow as pa
-
 try:
     from warnings import deprecated  # Python 3.13+
 except ImportError:
     from typing_extensions import deprecated  # Python 3.12
 
+import pyarrow as pa
+
 from datafusion.catalog import Catalog, CatalogProvider, Table
 from datafusion.dataframe import DataFrame
-from datafusion.expr import Expr, SortExpr, sort_list_to_raw_sort_list
+from datafusion.expr import SortKey, sort_list_to_raw_sort_list
 from datafusion.record_batch import RecordBatchStream
 from datafusion.user_defined import AggregateUDF, ScalarUDF, TableFunction, WindowUDF
 
 from ._internal import RuntimeEnvBuilder as RuntimeEnvBuilderInternal
 from ._internal import SessionConfig as SessionConfigInternal
 from ._internal import SessionContext as SessionContextInternal
 from ._internal import SQLOptions as SQLOptionsInternal
+from ._internal import expr as expr_internal
 
 if TYPE_CHECKING:
     import pathlib
+    from collections.abc import Sequence
 
     import pandas as pd
-    import polars as pl
+    import polars as pl  # type: ignore[import]
 
     from datafusion.plan import ExecutionPlan, LogicalPlan
 
@@ -553,7 +555,7 @@ def register_listing_table(
         table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_extension: str = ".parquet",
         schema: pa.Schema | None = None,
-        file_sort_order: list[list[Expr | SortExpr]] | None = None,
+        file_sort_order: Sequence[Sequence[SortKey]] | None = None,
     ) -> None:
         """Register multiple files as a single table.
 
@@ -567,23 +569,20 @@ def register_listing_table(
             table_partition_cols: Partition columns.
             file_extension: File extension of the provided table.
             schema: The data source schema.
-            file_sort_order: Sort order for the file.
+            file_sort_order: Sort order for the file. Each sort key can be
+                specified as a column name (``str``), an expression
+                (``Expr``), or a ``SortExpr``.
         """
         if table_partition_cols is None:
             table_partition_cols = []
         table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
-        file_sort_order_raw = (
-            [sort_list_to_raw_sort_list(f) for f in file_sort_order]
-            if file_sort_order is not None
-            else None
-        )
         self.ctx.register_listing_table(
             name,
             str(path),
             table_partition_cols,
             file_extension,
             schema,
-            file_sort_order_raw,
+            self._convert_file_sort_order(file_sort_order),
         )
 
     def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
@@ -808,7 +807,7 @@ def register_parquet(
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
         schema: pa.Schema | None = None,
-        file_sort_order: list[list[SortExpr]] | None = None,
+        file_sort_order: Sequence[Sequence[SortKey]] | None = None,
     ) -> None:
         """Register a Parquet file as a table.
 
@@ -827,7 +826,9 @@ def register_parquet(
                 that may be in the file schema. This can help avoid schema
                 conflicts due to metadata.
             schema: The data source schema.
-            file_sort_order: Sort order for the file.
+            file_sort_order: Sort order for the file. Each sort key can be
+                specified as a column name (``str``), an expression
+                (``Expr``), or a ``SortExpr``.
         """
         if table_partition_cols is None:
             table_partition_cols = []
@@ -840,9 +841,7 @@ def register_parquet(
             file_extension,
             skip_metadata,
             schema,
-            [sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order]
-            if file_sort_order is not None
-            else None,
+            self._convert_file_sort_order(file_sort_order),
         )
 
     def register_csv(
@@ -1099,7 +1098,7 @@ def read_parquet(
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
         schema: pa.Schema | None = None,
-        file_sort_order: list[list[Expr | SortExpr]] | None = None,
+        file_sort_order: Sequence[Sequence[SortKey]] | None = None,
    ) -> DataFrame:
         """Read a Parquet source into a :py:class:`~datafusion.dataframe.Dataframe`.
 
@@ -1116,19 +1115,17 @@ def read_parquet(
             schema: An optional schema representing the parquet files. If None,
                 the parquet reader will try to infer it based on data in the
                 file.
-            file_sort_order: Sort order for the file.
+            file_sort_order: Sort order for the file. Each sort key can be
+                specified as a column name (``str``), an expression
+                (``Expr``), or a ``SortExpr``.
 
         Returns:
             DataFrame representation of the read Parquet files
         """
         if table_partition_cols is None:
             table_partition_cols = []
         table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
-        file_sort_order = (
-            [sort_list_to_raw_sort_list(f) for f in file_sort_order]
-            if file_sort_order is not None
-            else None
-        )
+        file_sort_order = self._convert_file_sort_order(file_sort_order)
         return DataFrame(
             self.ctx.read_parquet(
                 str(path),
@@ -1179,6 +1176,24 @@ def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
         """Execute the ``plan`` and return the results."""
         return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions))
 
+    @staticmethod
+    def _convert_file_sort_order(
+        file_sort_order: Sequence[Sequence[SortKey]] | None,
+    ) -> list[list[expr_internal.SortExpr]] | None:
+        """Convert nested ``SortKey`` sequences into raw sort expressions.
+
+        Each ``SortKey`` can be a column name string, an ``Expr``, or a
+        ``SortExpr`` and will be converted using
+        :func:`datafusion.expr.sort_list_to_raw_sort_list`.
+        """
+        # Convert each ``SortKey`` in the provided sort order to the low-level
+        # representation expected by the Rust bindings.
+        return (
+            [sort_list_to_raw_sort_list(f) for f in file_sort_order]
+            if file_sort_order is not None
+            else None
+        )
+
     @staticmethod
     def _convert_table_partition_cols(
         table_partition_cols: list[tuple[str, str | pa.DataType]],

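Taken together, the _convert_file_sort_order helper means every SortKey form is accepted wherever file_sort_order appears. A minimal sketch, assuming a hypothetical pre-sorted Parquet file at data/events.parquet:

    from datafusion import SessionContext, col

    ctx = SessionContext()

    # A plain column name as the sort key.
    df = ctx.read_parquet(
        "data/events.parquet",  # hypothetical path
        file_sort_order=[["event_time"]],
    )

    # The same declaration with an explicit SortExpr (a bare Expr such as
    # col("event_time") would also be accepted).
    ctx.register_parquet(
        "events",
        "data/events.parquet",
        file_sort_order=[[col("event_time").sort(ascending=True)]],
    )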