[SPARK-48504][PYTHON][CONNECT][FOLLOW-UP] Code clean up
### What changes were proposed in this pull request?
Code clean up

### Why are the changes needed?
Code clean up

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI

### Was this patch authored or co-authored using generative AI tooling?
No
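
For reference, the change is internal to `python/pyspark/sql/connect/window.py`: the duplicated column-validation loops in `partitionBy`/`orderBy` are folded into a small `_to_cols` helper, and redundant `__doc__` reassignments are dropped. A minimal sketch of the public call forms that keep working, with made-up data and column names, purely for illustration:

```python
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as sf

# Made-up example data, only to show that both call styles behave the same.
spark = SparkSession.builder.appName("window-cleanup-example").getOrCreate()
df = spark.createDataFrame(
    [("sales", 90), ("sales", 75), ("hr", 60)], ["dept", "salary"]
)

w1 = Window.partitionBy("dept").orderBy("salary")      # names as varargs
w2 = Window.partitionBy(["dept"]).orderBy(["salary"])  # a single list of names

# Same result either way; the cleanup only changes how the columns are
# normalized internally on the Spark Connect side.
df.withColumn("rn", sf.row_number().over(w1)).show()
df.withColumn("rn", sf.row_number().over(w2)).show()
```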

Closes apache#46898 from zhengruifeng/win_refactor.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
zhengruifeng committed Jun 6, 2024
1 parent ce1b08f commit edb9236
Showing 1 changed file with 14 additions and 92 deletions.
106 changes: 14 additions & 92 deletions in python/pyspark/sql/connect/window.py
@@ -18,41 +18,37 @@

check_dependencies(__name__)

-from typing import TYPE_CHECKING, Union, Sequence, List, Optional
+from typing import TYPE_CHECKING, Union, Sequence, List, Optional, Tuple, cast, Iterable

from pyspark.sql.column import Column
from pyspark.sql.window import (
    Window as ParentWindow,
    WindowSpec as ParentWindowSpec,
)
-from pyspark.sql.connect.expressions import (
-    ColumnReference,
-    Expression,
-    SortOrder,
-)
-from pyspark.sql.window import Window as PySparkWindow, WindowSpec as PySparkWindowSpec
-from pyspark.errors import PySparkTypeError
+from pyspark.sql.connect.expressions import Expression, SortOrder
+from pyspark.sql.connect.functions import builtin as F

if TYPE_CHECKING:
from pyspark.sql.connect._typing import ColumnOrName, ColumnOrName_

__all__ = ["Window", "WindowSpec"]


+def _to_cols(cols: Tuple[Union["ColumnOrName", List["ColumnOrName_"]], ...]) -> List[Column]:
+    if len(cols) == 1 and isinstance(cols[0], list):
+        cols = cols[0]  # type: ignore[assignment]
+    return [F._to_col(c) for c in cast(Iterable["ColumnOrName"], cols)]


class WindowFrame:
def __init__(self, isRowFrame: bool, start: int, end: int) -> None:
super().__init__()

assert isinstance(isRowFrame, bool)

assert isinstance(start, int)

assert isinstance(end, int)

self._isRowFrame = isRowFrame

self._start = start

self._end = end

def __repr__(self) -> str:
@@ -82,83 +78,23 @@ def __init__(
assert isinstance(partitionSpec, list) and all(
isinstance(p, Expression) for p in partitionSpec
)

assert isinstance(orderSpec, list) and all(isinstance(s, SortOrder) for s in orderSpec)

assert frame is None or isinstance(frame, WindowFrame)

self._partitionSpec = partitionSpec

self._orderSpec = orderSpec

self._frame = frame

    def partitionBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
-        _cols: List[ColumnOrName] = []
-        for col in cols:
-            if isinstance(col, (str, Column)):
-                _cols.append(col)
-            elif isinstance(col, list):
-                for c in col:
-                    if isinstance(c, (str, Column)):
-                        _cols.append(c)
-                    else:
-                        raise PySparkTypeError(
-                            error_class="NOT_COLUMN_OR_LIST_OR_STR",
-                            message_parameters={"arg_name": "cols", "arg_type": type(c).__name__},
-                        )
-            else:
-                raise PySparkTypeError(
-                    error_class="NOT_COLUMN_OR_LIST_OR_STR",
-                    message_parameters={"arg_name": "cols", "arg_type": type(col).__name__},
-                )
-
-        newPartitionSpec: List[Expression] = []
-        for c in _cols:  # type: ignore[assignment]
-            if isinstance(c, Column):
-                newPartitionSpec.append(c._expr)  # type: ignore[arg-type]
-            else:
-                newPartitionSpec.append(ColumnReference(c))  # type: ignore[arg-type]
-
        return WindowSpec(
-            partitionSpec=newPartitionSpec,
+            partitionSpec=[c._expr for c in _to_cols(cols)],  # type: ignore[misc]
            orderSpec=self._orderSpec,
            frame=self._frame,
        )

    def orderBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
-        _cols: List[ColumnOrName] = []
-        for col in cols:
-            if isinstance(col, (str, Column)):
-                _cols.append(col)
-            elif isinstance(col, list):
-                for c in col:
-                    if isinstance(c, (str, Column)):
-                        _cols.append(c)
-                    else:
-                        raise PySparkTypeError(
-                            error_class="NOT_COLUMN_OR_LIST_OR_STR",
-                            message_parameters={"arg_name": "cols", "arg_type": type(c).__name__},
-                        )
-            else:
-                raise PySparkTypeError(
-                    error_class="NOT_COLUMN_OR_LIST_OR_STR",
-                    message_parameters={"arg_name": "cols", "arg_type": type(col).__name__},
-                )
-
-        newOrderSpec: List[SortOrder] = []
-        for c in _cols:  # type: ignore[assignment]
-            if isinstance(c, Column):
-                if isinstance(c._expr, SortOrder):
-                    newOrderSpec.append(c._expr)
-                else:
-                    newOrderSpec.append(SortOrder(c._expr))  # type: ignore[arg-type]
-            else:
-                newOrderSpec.append(SortOrder(ColumnReference(c)))  # type: ignore[arg-type]
-
        return WindowSpec(
            partitionSpec=self._partitionSpec,
-            orderSpec=newOrderSpec,
+            orderSpec=[cast(SortOrder, F._sort_col(c)._expr) for c in _to_cols(cols)],
            frame=self._frame,
        )
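
Both rewritten methods above now route their arguments through the new `_to_cols` helper. Below is a standalone mirror of the helper's flattening rule, using plain strings instead of `Column`s so it runs without Spark; it is illustrative only and not the helper itself:

```python
from typing import Iterable, List, Tuple, Union, cast


def flatten(cols: Tuple[Union[str, List[str]], ...]) -> List[str]:
    # Same rule as _to_cols: a single list argument is unpacked,
    # otherwise the varargs tuple is used as-is.
    if len(cols) == 1 and isinstance(cols[0], list):
        cols = cols[0]  # type: ignore[assignment]
    return [c for c in cast(Iterable[str], cols)]


assert flatten(("a", "b")) == ["a", "b"]     # like Window.partitionBy("a", "b")
assert flatten((["a", "b"],)) == ["a", "b"]  # like Window.partitionBy(["a", "b"])
```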

@@ -199,13 +135,6 @@ def __repr__(self) -> str:
return "WindowSpec(" + ", ".join(strs) + ")"


-WindowSpec.rangeBetween.__doc__ = PySparkWindowSpec.rangeBetween.__doc__
-WindowSpec.rowsBetween.__doc__ = PySparkWindowSpec.rowsBetween.__doc__
-WindowSpec.orderBy.__doc__ = PySparkWindowSpec.orderBy.__doc__
-WindowSpec.partitionBy.__doc__ = PySparkWindowSpec.partitionBy.__doc__
-WindowSpec.__doc__ = PySparkWindowSpec.__doc__


class Window(ParentWindow):
_spec = WindowSpec(partitionSpec=[], orderSpec=[], frame=None)

@@ -226,29 +155,22 @@ def rangeBetween(start: int, end: int) -> ParentWindowSpec:
return Window._spec.rangeBetween(start, end)


-Window.orderBy.__doc__ = PySparkWindow.orderBy.__doc__
-Window.rowsBetween.__doc__ = PySparkWindow.rowsBetween.__doc__
-Window.rangeBetween.__doc__ = PySparkWindow.rangeBetween.__doc__
-Window.partitionBy.__doc__ = PySparkWindow.partitionBy.__doc__
-Window.__doc__ = PySparkWindow.__doc__


def _test() -> None:
    import os
    import sys
    import doctest
    from pyspark.sql import SparkSession as PySparkSession
-    import pyspark.sql.connect.window
+    import pyspark.sql.window

-    globs = pyspark.sql.connect.window.__dict__.copy()
+    globs = pyspark.sql.window.__dict__.copy()
    globs["spark"] = (
        PySparkSession.builder.appName("sql.connect.window tests")
        .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
        .getOrCreate()
    )

    (failure_count, test_count) = doctest.testmod(
-        pyspark.sql.connect.window,
+        pyspark.sql.window,
        globs=globs,
        optionflags=doctest.ELLIPSIS
        | doctest.NORMALIZE_WHITESPACE

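With the `__doc__` reassignments gone, `_test()` now exercises the doctests defined in `pyspark.sql.window` against a Spark Connect session. A hedged sketch of running the same check by hand, mirroring the `_test()` body shown in the diff (the `local[4]` fallback and the `SPARK_CONNECT_TESTING_REMOTE` variable come straight from it):

```python
import doctest
import os

import pyspark.sql.window
from pyspark.sql import SparkSession

# Mirrors what _test() does after this change; requires a PySpark build with
# Spark Connect support. Point SPARK_CONNECT_TESTING_REMOTE at a real
# Spark Connect endpoint to run against a remote server.
spark = (
    SparkSession.builder.appName("sql.connect.window tests")
    .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
    .getOrCreate()
)

globs = pyspark.sql.window.__dict__.copy()
globs["spark"] = spark

failure_count, test_count = doctest.testmod(
    pyspark.sql.window,
    globs=globs,
    optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
)
print(f"{failure_count} failures out of {test_count} doctests")
spark.stop()
```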