Skip to content

Commit c14cdbb

Browse files
committed
fix(duckdb): implement IGNORE NULLS for FirstValue and LastValue
1 parent d441860 commit c14cdbb

File tree

6 files changed

+87
-10
lines changed

6 files changed

+87
-10
lines changed

ibis/backends/sql/compilers/base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
import ibis.expr.datatypes as dt
1919
import ibis.expr.operations as ops
2020
from ibis.backends.sql.rewrites import (
21-
FirstValue,
22-
LastValue,
2321
add_one_to_nth_value_input,
2422
add_order_by_to_empty_ranking_window_functions,
2523
empty_in_values_right_side,
@@ -337,13 +335,11 @@ class SQLGlotCompiler(abc.ABC):
337335
ops.Degrees: "degrees",
338336
ops.DenseRank: "dense_rank",
339337
ops.Exp: "exp",
340-
FirstValue: "first_value",
341338
ops.GroupConcat: "group_concat",
342339
ops.IfElse: "if",
343340
ops.IsInf: "isinf",
344341
ops.IsNan: "isnan",
345342
ops.JSONGetItem: "json_extract",
346-
LastValue: "last_value",
347343
ops.Levenshtein: "levenshtein",
348344
ops.Ln: "ln",
349345
ops.Log10: "log",
@@ -1244,6 +1240,12 @@ def visit_RowID(self, op, *, table):
12441240
op.name, table=table.alias_or_name, quoted=self.quoted, copy=False
12451241
)
12461242

1243+
def visit_FirstLastValue(self, op, *, arg, include_null):
    """Compile a ``FirstValue``/``LastValue`` window op to a SQL function call.

    The target function is selected from the class name of ``op``:
    ``"FirstValue"`` maps to ``first_value``; any other op maps to
    ``last_value``.

    Parameters
    ----------
    op
        The operation node being compiled; only its class name is inspected.
    arg
        The already-compiled argument expression.
    include_null
        Accepted so the signature matches the op's fields, but not consulted
        by this implementation.
        NOTE(review): presumably backends that support IGNORE NULLS are
        expected to override this method -- confirm.

    Returns
    -------
    The result of calling ``self.f[<function name>]`` with ``arg``.
    """
    if type(op).__name__ == "FirstValue":
        sql_func = self.f["first_value"]
    else:
        sql_func = self.f["last_value"]
    return sql_func(arg)


# Both op types share the single implementation above.
visit_FirstValue = visit_LastValue = visit_FirstLastValue
1248+
12471249
# TODO(kszucs): this should be renamed to something UDF related
12481250
def __sql_name__(self, op: ops.ScalarUDF | ops.AggUDF) -> str:
12491251
# for builtin functions use the exact function name, otherwise use the

ibis/backends/sql/compilers/duckdb.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,5 +712,19 @@ def visit_TableUnnest(
712712
def visit_StringToTime(self, op, *, arg, format_str):
713713
return self.cast(self.f.str_to_time(arg, format_str), to=dt.time)
714714

715+
def visit_LastValue(self, op, *, arg, include_null):
    """Compile ``LastValue`` to ``last_value``, optionally ignoring NULLs.

    Parameters
    ----------
    op
        The operation node being compiled (unused here).
    arg
        The already-compiled argument expression.
    include_null
        When truthy, the plain ``last_value(arg)`` call is returned.
        Otherwise the call is wrapped in ``sge.IgnoreNulls``.

    Returns
    -------
    The ``last_value`` call, bare or wrapped in ``sge.IgnoreNulls``.
    """
    call = self.f.last_value(arg)
    if include_null:
        return call
    return sge.IgnoreNulls(this=call)
721+
722+
def visit_FirstValue(self, op, *, arg, include_null):
    """Compile ``FirstValue`` to ``first_value``, optionally ignoring NULLs.

    Parameters
    ----------
    op
        The operation node being compiled (unused here).
    arg
        The already-compiled argument expression.
    include_null
        When truthy, the plain ``first_value(arg)`` call is returned.
        Otherwise the call is wrapped in ``sge.IgnoreNulls``.

    Returns
    -------
    The ``first_value`` call, bare or wrapped in ``sge.IgnoreNulls``.
    """
    call = self.f.first_value(arg)
    if include_null:
        return call
    return sge.IgnoreNulls(this=call)
728+
715729

716730
compiler = DuckDBCompiler()

ibis/backends/sql/compilers/pyspark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,10 +243,10 @@ def visit_CountDistinctStar(self, op, *, arg, where):
243243
]
244244
return self.f.count(sge.Distinct(expressions=cols))
245245

246-
def visit_FirstValue(self, op, *, arg):
246+
def visit_FirstValue(self, op, *, arg, include_null):
    """Compile ``FirstValue`` to Spark's ``first``, honoring ``include_null``.

    Parameters
    ----------
    op
        The operation node being compiled (unused here).
    arg
        The already-compiled argument expression.
    include_null
        When truthy, the plain ``first(arg)`` call is returned so NULL rows
        take part in the result. Otherwise the call is wrapped in
        ``sge.IgnoreNulls``.

    Returns
    -------
    The ``first`` call, bare or wrapped in ``sge.IgnoreNulls``.
    """
    call = self.f.first(arg)
    if include_null:
        # Previously the flag was accepted but IGNORE NULLS was applied
        # unconditionally, so an explicit include_null=True had no effect.
        return call
    return sge.IgnoreNulls(this=call)
248248

249-
def visit_LastValue(self, op, *, arg):
249+
def visit_LastValue(self, op, *, arg, include_null):
    """Compile ``LastValue`` to Spark's ``last``, honoring ``include_null``.

    Parameters
    ----------
    op
        The operation node being compiled (unused here).
    arg
        The already-compiled argument expression.
    include_null
        When truthy, the plain ``last(arg)`` call is returned so NULL rows
        take part in the result. Otherwise the call is wrapped in
        ``sge.IgnoreNulls``.

    Returns
    -------
    The ``last`` call, bare or wrapped in ``sge.IgnoreNulls``.
    """
    call = self.f.last(arg)
    if include_null:
        # Previously the flag was accepted but IGNORE NULLS was applied
        # unconditionally, so an explicit include_null=True had no effect.
        return call
    return sge.IgnoreNulls(this=call)
251251

252252
def visit_First(self, op, *, arg, where, order_by, include_null):

ibis/backends/sql/rewrites.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import sys
77
from collections.abc import Mapping
88
from functools import reduce
9-
from typing import TYPE_CHECKING, Any
9+
from typing import TYPE_CHECKING, Any, Optional
1010

1111
import toolz
1212
from public import public
@@ -73,6 +73,7 @@ class FirstValue(ops.Analytic):
7373
"""Retrieve the first element."""
7474

7575
arg: ops.Column[dt.Any]
76+
include_null: Optional[bool] = False
7677

7778
@attribute
7879
def dtype(self):
@@ -84,6 +85,7 @@ class LastValue(ops.Analytic):
8485
"""Retrieve the last element."""
8586

8687
arg: ops.Column[dt.Any]
88+
include_null: bool = False
8789

8890
@attribute
8991
def dtype(self):
@@ -204,7 +206,7 @@ def first_to_firstvalue(_, **kwargs):
204206
"in a window function"
205207
)
206208
klass = FirstValue if isinstance(_.func, ops.First) else LastValue
207-
return _.copy(func=klass(_.func.arg))
209+
return _.copy(func=klass(_.func.arg, include_null=_.func.include_null))
208210

209211

210212
@replace(p.Alias)

ibis/backends/tests/snapshots/test_sql/test_union_aliasing/duckdb/out.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ WITH "t5" AS (
1515
"t2"."field_of_study",
1616
"t2"."years",
1717
"t2"."degrees",
18-
FIRST_VALUE("t2"."degrees") OVER (
18+
FIRST_VALUE("t2"."degrees" IGNORE NULLS) OVER (
1919
PARTITION BY "t2"."field_of_study"
2020
ORDER BY "t2"."years" ASC
2121
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
2222
) AS "earliest_degrees",
23-
LAST_VALUE("t2"."degrees") OVER (
23+
LAST_VALUE("t2"."degrees" IGNORE NULLS) OVER (
2424
PARTITION BY "t2"."field_of_study"
2525
ORDER BY "t2"."years" ASC
2626
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING

ibis/backends/tests/test_window.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,6 +1109,65 @@ def test_first_last(backend):
11091109
backend.assert_frame_equal(result, expected)
11101110

11111111

1112+
@pytest.mark.notimpl(
    [
        "risingwave",
        "clickhouse",
        "bigquery",
        "oracle",
        "snowflake",
        "databricks",
        "pyspark",
    ],
    raises=AssertionError,
)
@pytest.mark.notyet(["polars"], raises=com.OperationNotDefinedError)
@pytest.mark.notyet(["flink"], raises=NotImplementedError)
@pytest.mark.notyet(
    [
        "mysql",
        "sqlite",
        "postgres",
        "datafusion",
        "druid",
        "athena",
        "impala",
        "mssql",
        "trino",
        "exasol",
    ],
    raises=Exception,
)
def test_first_last_include_nulls(backend):
    """Check windowed ``first``/``last`` with ``include_null`` on and off."""
    # Partition a=1 holds b = 5 then NULL (c = 2, 3); partition a=2 holds
    # b = NULL then 3 (c = 0, 1). Each partition therefore has a NULL at
    # exactly one end, which separates the two include_null settings.
    data = {"a": (2, 2, 1, 1), "b": (None, 3, 5, None), "c": list(range(4))}
    t = ibis.memtable(data)
    w = ibis.window(group_by=t.a, order_by=t.c)

    projections = {
        "b_first_null": t.b.first(include_null=True).over(w),
        "b_last_null": t.b.last(include_null=True).over(w),
        "b_first": t.b.first(include_null=False).over(w),
        "b_last": t.b.last(include_null=False).over(w),
    }
    expr = t.select("a", **projections)

    raw = backend.connection.execute(expr)
    # Order the rows by "a", then discard "a" and renumber the index so the
    # frame lines up with the expected values below.
    ordered = raw.sort_values("a")
    result = ordered.set_index("a").reset_index(drop=True)

    expected = pd.DataFrame(
        {
            "b_first_null": [5, 5, None, None],
            "b_last_null": [None, None, 3, 3],
            "b_first": [5, 5, 3, 3],
            "b_last": [5, 5, 3, 3],
        }
    )
    backend.assert_frame_equal(result, expected, check_dtype=False)
1169+
1170+
11121171
@pytest.mark.notyet(
11131172
["bigquery"], raises=GoogleBadRequest, reason="not supported by BigQuery"
11141173
)

0 commit comments

Comments
 (0)