Skip to content

Commit

Permalink
fix: Support negation and subqueries in whereclauseextractor (#25816)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
timgl and github-actions[bot] authored Oct 25, 2024
1 parent 74fbc2d commit ff712be
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 44 deletions.
12 changes: 12 additions & 0 deletions posthog/hogql/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,18 @@ class CompareOperationOp(StrEnum):
NotIRegex = "!~*"


NEGATED_COMPARE_OPS: list[CompareOperationOp] = [
CompareOperationOp.NotEq,
CompareOperationOp.NotLike,
CompareOperationOp.NotILike,
CompareOperationOp.NotIn,
CompareOperationOp.GlobalNotIn,
CompareOperationOp.NotInCohort,
CompareOperationOp.NotRegex,
CompareOperationOp.NotIRegex,
]


@dataclass(kw_only=True)
class CompareOperation(Expr):
left: Expr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ def test_person_properties_andor_11(self):
expected = _expr("properties.email = 'jimmy@posthog.com'")
assert actual == expected

def test_person_array(self):
actual = self.get_clause("SELECT * FROM events WHERE person.properties.email IN ['jimmy@posthog.com']")
expected = _expr("properties.email IN ['jimmy@posthog.com']")
assert actual == expected

def test_person_properties_function_calls(self):
actual = self.get_clause(
"SELECT * FROM events WHERE properties.email = 'bla@posthog.com' and toString(person.properties.email) = 'jimmy@posthog.com'"
Expand All @@ -170,6 +175,16 @@ def test_person_properties_function_call_args_complex(self):
)
assert actual is None

def test_left_join_with_negation(self):
actual = self.get_clause("SELECT * FROM events WHERE person.properties.email != 'jimmy@posthog.com'")
assert actual is None

def test_subquery(self):
actual = self.print_query(
"SELECT * FROM events WHERE person.id IN (select person_id from person_distinct_ids where distinct_id = '1')"
)
assert "in(id, (SELECT person_distinct_ids.person_id" in actual

def test_boolean(self):
PropertyDefinition.objects.get_or_create(
team=self.team,
Expand All @@ -179,6 +194,6 @@ def test_boolean(self):
)
actual = self.print_query("SELECT * FROM events WHERE person.properties.person_boolean = false")
assert (
f"FROM person WHERE and(equals(person.team_id, {self.team.id}), ifNull(equals(transform(toString(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(person.properties"
f"ifNull(equals(transform(toString(replaceRegexpAll(nullIf(nullIf(JSONExtractRaw(person.properties"
in actual
)
77 changes: 37 additions & 40 deletions posthog/hogql/database/schema/util/where_clause_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from posthog.hogql.ast import CompareOperationOp, ArithmeticOperationOp
from posthog.hogql.context import HogQLContext
from posthog.hogql.database.models import DatabaseField, LazyJoinToAdd, LazyTableToAdd
from posthog.hogql.errors import NotImplementedError
from posthog.hogql.errors import NotImplementedError, QueryError
from posthog.hogql.functions.mapping import HOGQL_COMPARISON_MAPPING

from posthog.hogql.visitor import clone_expr, CloningVisitor, Visitor, TraversingVisitor

Expand Down Expand Up @@ -43,6 +44,7 @@ class WhereClauseExtractor(CloningVisitor):
clear_types: bool = False
clear_locations: bool = False
capture_timestamp_comparisons: bool = False # implement handle_timestamp_comparison if setting this to True
is_join: bool = False
tracked_tables: list[ast.LazyTable | ast.LazyJoin]
tombstone_string: str

Expand Down Expand Up @@ -76,6 +78,9 @@ def get_inner_where(self, select_query: ast.SelectQuery) -> Optional[ast.Expr]:
if not select_query.where and not select_query.prewhere:
return None

if select_query.select_from and select_query.select_from.next_join:
self.is_join = True

# visit the where clause
wheres = []
if select_query.where:
Expand Down Expand Up @@ -110,25 +115,34 @@ def visit_compare_operation(self, node: ast.CompareOperation) -> ast.Expr:
if result:
return result

# if it's a join, and if the comparison is negative, we don't want to filter down as the outer join might end up doing other comparisons that clash
if self.is_join and node.op in ast.NEGATED_COMPARE_OPS:
return ast.Constant(value=True)

# Check if any of the fields are a field on our requested table
if len(self.tracked_tables) > 0:
left = self.visit(node.left)
right = self.visit(node.right)

if isinstance(node.right, ast.SelectQuery):
right = clone_expr(
node.right, clear_types=False, clear_locations=False, inline_subquery_field_names=True
)
else:
right = self.visit(node.right)

if has_tombstone(left, self.tombstone_string) or has_tombstone(right, self.tombstone_string):
return ast.Constant(value=self.tombstone_string)
return ast.CompareOperation(op=node.op, left=left, right=right)

return ast.Constant(value=True)

def visit_select_query(self, node: ast.SelectQuery) -> ast.Expr:
# going too deep, bail
return ast.Constant(value=True)

def visit_arithmetic_operation(self, node: ast.ArithmeticOperation) -> ast.Expr:
# don't even try to handle complex logic
return ast.Constant(value=True)

def visit_not(self, node: ast.Not) -> ast.Expr:
if self.is_join:
return ast.Constant(value=True)
response = self.visit(node.expr)
if has_tombstone(response, self.tombstone_string):
return ast.Constant(value=self.tombstone_string)
Expand All @@ -139,41 +153,21 @@ def visit_call(self, node: ast.Call) -> ast.Expr:
return self.visit_and(ast.And(exprs=node.args))
elif node.name == "or":
return self.visit_or(ast.Or(exprs=node.args))
elif node.name == "greaterOrEquals":
return self.visit_compare_operation(
ast.CompareOperation(op=CompareOperationOp.GtEq, left=node.args[0], right=node.args[1])
)
elif node.name == "greater":
return self.visit_compare_operation(
ast.CompareOperation(op=CompareOperationOp.Gt, left=node.args[0], right=node.args[1])
)
elif node.name == "lessOrEquals":
return self.visit_compare_operation(
ast.CompareOperation(op=CompareOperationOp.LtEq, left=node.args[0], right=node.args[1])
)
elif node.name == "less":
elif node.name == "not":
if self.is_join:
return ast.Constant(value=True)

elif node.name in HOGQL_COMPARISON_MAPPING:
op = HOGQL_COMPARISON_MAPPING[node.name]
if len(node.args) != 2:
raise QueryError(f"Comparison '{node.name}' requires exactly two arguments")
# We do "cleverer" logic with nullable types in visit_compare_operation
return self.visit_compare_operation(
ast.CompareOperation(op=CompareOperationOp.Lt, left=node.args[0], right=node.args[1])
)
elif node.name == "equals":
return self.visit_compare_operation(
ast.CompareOperation(op=CompareOperationOp.Eq, left=node.args[0], right=node.args[1])
)
elif node.name == "like":
return self.visit_compare_operation(
ast.CompareOperation(op=CompareOperationOp.Like, left=node.args[0], right=node.args[1])
)
elif node.name == "notLike":
return self.visit_compare_operation(
ast.CompareOperation(op=CompareOperationOp.NotLike, left=node.args[0], right=node.args[1])
)
elif node.name == "ilike":
return self.visit_compare_operation(
ast.CompareOperation(op=CompareOperationOp.ILike, left=node.args[0], right=node.args[1])
)
elif node.name == "notIlike":
return self.visit_compare_operation(
ast.CompareOperation(op=CompareOperationOp.NotILike, left=node.args[0], right=node.args[1])
ast.CompareOperation(
left=node.args[0],
right=node.args[1],
op=op,
)
)
args = [self.visit(arg) for arg in node.args]
if any(has_tombstone(arg, self.tombstone_string) for arg in args):
Expand Down Expand Up @@ -459,6 +453,9 @@ def visit_alias(self, node: ast.Alias) -> bool:
def visit_tuple(self, node: ast.Tuple) -> bool:
return all(self.visit(arg) for arg in node.exprs)

def visit_array(self, node: ast.Tuple) -> bool:
return all(self.visit(arg) for arg in node.exprs)


def is_simple_timestamp_field_expression(expr: ast.Expr, context: HogQLContext, tombstone_string: str) -> bool:
result = IsSimpleTimestampFieldExpressionVisitor(context, tombstone_string).visit(expr)
Expand Down
20 changes: 17 additions & 3 deletions posthog/hogql/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
from posthog.hogql.errors import BaseHogQLError


def clone_expr(expr: Expr, clear_types=False, clear_locations=False) -> Expr:
def clone_expr(expr: Expr, clear_types=False, clear_locations=False, inline_subquery_field_names=False) -> Expr:
"""Clone an expression node."""
return CloningVisitor(clear_types=clear_types, clear_locations=clear_locations).visit(expr)
return CloningVisitor(
clear_types=clear_types,
clear_locations=clear_locations,
inline_subquery_field_names=inline_subquery_field_names,
).visit(expr)


def clear_locations(expr: Expr) -> Expr:
Expand Down Expand Up @@ -350,9 +354,11 @@ def __init__(
self,
clear_types: Optional[bool] = True,
clear_locations: Optional[bool] = False,
inline_subquery_field_names: Optional[bool] = False,
):
self.clear_types = clear_types
self.clear_locations = clear_locations
self.inline_subquery_field_names = inline_subquery_field_names

def visit_cte(self, node: ast.CTE):
return ast.CTE(
Expand Down Expand Up @@ -489,12 +495,20 @@ def visit_constant(self, node: ast.Constant):
)

def visit_field(self, node: ast.Field):
return ast.Field(
field = ast.Field(
start=None if self.clear_locations else node.start,
end=None if self.clear_locations else node.end,
type=None if self.clear_types else node.type,
chain=node.chain.copy(),
)
if (
self.inline_subquery_field_names
and isinstance(node.type, ast.PropertyType)
and node.type.joined_subquery is not None
and node.type.joined_subquery_field_name is not None
):
field.chain = [node.type.joined_subquery_field_name]
return field

def visit_placeholder(self, node: ast.Placeholder):
return ast.Placeholder(
Expand Down

0 comments on commit ff712be

Please sign in to comment.