Skip to content

Commit

Permalink
Merge pull request kayak#376 from FranGoitia/add-filter-to-window-fun…
Browse files Browse the repository at this point in the history
…ctions

add filter clause for window functions
  • Loading branch information
twheys authored Mar 10, 2020
2 parents 9a2e71a + 863010f commit 0393aa9
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
25 changes: 20 additions & 5 deletions pypika/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1194,10 +1194,17 @@ class AnalyticFunction(Function):

def __init__(self, name, *args, **kwargs):
super(AnalyticFunction, self).__init__(name, *args, **kwargs)
self._filters = []
self._partition = []
self._orderbys = []
self._include_filter = False
self._include_over = False

@builder
def filter(self, *filters):
self._include_filter = True
self._filters += filters

@builder
def over(self, *terms):
self._include_over = True
Expand All @@ -1216,6 +1223,12 @@ def _orderby_field(self, field, orient, **kwargs):
field=field.get_sql(**kwargs), orient=orient.value,
)

def get_filter_sql(self):
if self._include_filter:
return "WHERE {criterions}".format(
criterions=Criterion.all(self._filters).get_sql()
)

def get_partition_sql(self, **kwargs):
terms = []
if self._partition:
Expand All @@ -1242,14 +1255,16 @@ def get_partition_sql(self, **kwargs):

def get_function_sql(self, **kwargs):
function_sql = super(AnalyticFunction, self).get_function_sql(**kwargs)
filter_sql = self.get_filter_sql()
partition_sql = self.get_partition_sql(**kwargs)

if not self._include_over:
return function_sql
sql = function_sql
if self._include_filter:
sql += " FILTER({filter_sql})".format(filter_sql=filter_sql)
if self._include_over:
sql += " OVER({partition_sql})".format(partition_sql=partition_sql)

return "{function_sql} OVER({partition_sql})".format(
function_sql=function_sql, partition_sql=partition_sql
)
return sql


class WindowFrameAnalyticFunction(AnalyticFunction):
Expand Down
20 changes: 20 additions & 0 deletions pypika/tests/test_analytic_queries.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest

from pypika import (
Criterion,
JoinType,
Order,
Query,
Expand Down Expand Up @@ -220,6 +221,25 @@ def test_last_value(self):
str(q),
)

def test_filter(self):
expr = (
an.LastValue(self.table_abc.fizz)
.filter(Criterion.all([self.table_abc.bar == True]))
.over(self.table_abc.foo)
.orderby(self.table_abc.date)
)

q = Query.from_(self.table_abc).select(expr)

self.assertEqual(
"SELECT "
'LAST_VALUE("fizz") '
'FILTER(WHERE "bar"=true) '
'OVER(PARTITION BY "foo" ORDER BY "date") '
'FROM "abc"',
str(q),
)

def test_orderby_asc(self):
expr = (
an.LastValue(self.table_abc.fizz)
Expand Down

0 comments on commit 0393aa9

Please sign in to comment.