Skip to content

Commit

Permalink
Adds validation on group by fields
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmedetefy committed Dec 9, 2021
1 parent cbbf1ca commit efab254
Show file tree
Hide file tree
Showing 9 changed files with 234 additions and 50 deletions.
27 changes: 23 additions & 4 deletions snuba/query/conditions.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from typing import Any, Mapping, Optional, Sequence, Set
from typing import Any, Mapping, Optional, Sequence, Set, Union

from snuba.query.dsl import literals_tuple
from snuba.query.expressions import Expression, FunctionCall, Literal
from snuba.query.expressions import (
Expression,
FunctionCall,
Literal,
OptionalScalarType,
)
from snuba.query.matchers import Any as AnyPattern
from snuba.query.matchers import AnyExpression, AnyOptionalString
from snuba.query.matchers import Column as ColumnPattern
from snuba.query.matchers import FunctionCall as FunctionCallPattern
from snuba.query.matchers import Integer
from snuba.query.matchers import Literal as LiteralPattern
from snuba.query.matchers import Or, Param, Pattern, String
from snuba.query.matchers import SubscriptableReference as SubscriptableReferencePattern


class ConditionFunctions:
Expand Down Expand Up @@ -294,12 +300,25 @@ def is_condition(exp: Expression) -> bool:


def build_match(
col: str, ops: Sequence[str], param_type: Any, alias: Optional[str] = None
col: str,
ops: Sequence[str],
param_type: Any,
alias: Optional[str] = None,
key: OptionalScalarType = None,
) -> Or[Expression]:
# The IN condition has to be checked separately since each parameter
# has to be checked individually.
alias_match = AnyOptionalString() if alias is None else String(alias)
column_match = Param("column", ColumnPattern(alias_match, String(col)))
pattern: Union[ColumnPattern, SubscriptableReferencePattern]
if key is not None:
pattern = SubscriptableReferencePattern(
table_name=alias_match, column_name=String(col), key=AnyPattern(str)
)
else:
pattern = ColumnPattern(table_name=alias_match, column_name=String(col))

column_match = Param("column", pattern)

return Or(
[
FunctionCallPattern(
Expand Down
7 changes: 7 additions & 0 deletions snuba/query/matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ class SubscriptableReference(Pattern[SubscriptableReferenceExpr]):
If column_name and key arguments are provided, they have to match, otherwise they are ignored.
"""

table_name: Optional[Pattern[Optional[str]]] = None
column_name: Optional[Pattern[str]] = None
key: Optional[Pattern[str]] = None

Expand All @@ -366,6 +367,12 @@ def match(self, node: AnyType) -> Optional[MatchResult]:

result = MatchResult()

if self.table_name is not None:
partial_result = self.table_name.match(node.column.table_name)
if partial_result is None:
return None
result = result.merge(partial_result)

if self.column_name is not None:
partial_result = self.column_name.match(node.column.column_name)
if partial_result is None:
Expand Down
5 changes: 4 additions & 1 deletion snuba/query/snql/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,7 +1166,10 @@ def _align_max_days_date_align(

# If there is an = or IN condition on time, we don't need to do any of this
match = build_match(
entity.required_time_column, [ConditionFunctions.EQ], datetime, alias
col=entity.required_time_column,
ops=[ConditionFunctions.EQ],
param_type=datetime,
alias=alias,
)
if any(match.match(cond) for cond in old_top_level):
return old_top_level
Expand Down
56 changes: 52 additions & 4 deletions snuba/query/validation/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
get_first_level_and_conditions,
)
from snuba.query.exceptions import InvalidExpressionException, InvalidQueryException
from snuba.query.expressions import Column
from snuba.query.expressions import SubscriptableReference as SubscriptableReferenceExpr

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -60,7 +62,9 @@ def validate(self, query: Query, alias: Optional[str] = None) -> None:
missing = set()
if self.required_columns:
for col in self.required_columns:
match = build_match(col, [ConditionFunctions.EQ], int, alias)
match = build_match(
col=col, ops=[ConditionFunctions.EQ], param_type=int, alias=alias
)
found = any(match.match(cond) for cond in top_level)
if not found:
missing.add(col)
Expand Down Expand Up @@ -115,15 +119,15 @@ class NoTimeBasedConditionValidator(QueryValidator):
def __init__(self, required_time_column: str) -> None:
self.required_time_column = required_time_column
self.match = build_match(
required_time_column,
[
col=required_time_column,
ops=[
ConditionFunctions.EQ,
ConditionFunctions.LT,
ConditionFunctions.LTE,
ConditionFunctions.GT,
ConditionFunctions.GTE,
],
datetime,
param_type=datetime,
)

def validate(self, query: Query, alias: Optional[str] = None) -> None:
Expand All @@ -150,6 +154,47 @@ def __init__(
self.max_allowed_aggregations = max_allowed_aggregations
self.disallowed_aggregations = disallowed_aggregations

@staticmethod
def _validate_groupby_fields_have_matching_conditions(
    query: Query, alias: Optional[str] = None
) -> None:
    """
    Ensure that every field in the group by clause has a matching
    condition in the where clause. For example, if the group by clause
    were [project_id, tags[3]], the where clause would need to contain
    conditions such as `project_id = 3 AND tags[3] IN array(1,2,3)`.
    This is necessary because we want to avoid the case where an
    unspecified number of buckets is returned.

    Raises InvalidQueryException if a groupby expression is neither a
    plain column nor a subscriptable reference, or if it has no
    matching top-level AND condition.
    """
    condition = query.get_condition()
    # Only the first-level AND-ed conditions are inspected; a matching
    # condition nested under an OR would not be found here.
    top_level = get_first_level_and_conditions(condition) if condition else []

    for exp in query.get_groupby():
        key = None
        if isinstance(exp, SubscriptableReferenceExpr):
            # e.g. tags[3]: match both the underlying column name and the key.
            column_name = str(exp.column.column_name)
            key = exp.key.value
        elif isinstance(exp, Column):
            column_name = exp.column_name
        else:
            raise InvalidQueryException(
                "Unhandled column type in group by validation"
            )

        # NOTE(review): param_type=int means only integer-literal
        # comparisons are matched — presumably sufficient for the
        # subscription use case; confirm for string-valued tags.
        match = build_match(
            col=column_name,
            ops=[ConditionFunctions.EQ],
            param_type=int,
            alias=alias,
            key=key,
        )
        found = any(match.match(cond) for cond in top_level)

        if not found:
            raise InvalidQueryException(
                f"Every field in groupby must have a corresponding condition in "
                f"where clause. missing condition for field {exp}"
            )

def validate(self, query: Query, alias: Optional[str] = None,) -> None:
selected = query.get_selected_columns()
if len(selected) > self.max_allowed_aggregations:
Expand All @@ -168,6 +213,9 @@ def validate(self, query: Query, alias: Optional[str] = None,) -> None:
f"invalid clause {field} in subscription query"
)

if "groupby" not in self.disallowed_aggregations:
self._validate_groupby_fields_have_matching_conditions(query, alias)


class GranularityValidator(QueryValidator):
""" Verify that the given granularity is a multiple of the configured value """
Expand Down
53 changes: 48 additions & 5 deletions tests/datasets/validation/test_subscription_clauses_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,22 @@ class EntityKeySubscription(EntitySubscriptionValidation, EntitySubscription):
binary_condition(
BooleanFunctions.AND,
binary_condition(
"equals",
Column("_snuba_project_id", None, "project_id"),
Literal(None, 1),
BooleanFunctions.AND,
binary_condition(
"equals",
Column("_snuba_project_id", None, "project_id"),
Literal(None, 1),
),
binary_condition(
"equals",
Column("_snuba_org_id", None, "org_id"),
Literal(None, 1),
),
),
binary_condition(
"equals",
Column("_snuba_org_id", None, "org_id"),
Literal(None, 1),
Column("_snuba_tags[3]", None, "tags[3]"),
Literal(None, 2),
),
),
),
Expand Down Expand Up @@ -155,6 +163,41 @@ def test_subscription_clauses_validation(query: LogicalQuery) -> None:
),
id="no orderby clauses",
),
pytest.param(
LogicalQuery(
QueryEntity(
EntityKey.METRICS_COUNTERS,
get_entity(EntityKey.METRICS_COUNTERS).get_data_model(),
),
selected_columns=[SelectedExpression("value", Column(None, None, "value"))],
condition=binary_condition(
BooleanFunctions.AND,
binary_condition(
ConditionFunctions.EQ,
Column(None, None, "metric_id"),
Literal(None, 123),
),
binary_condition(
BooleanFunctions.AND,
binary_condition(
"equals",
Column("_snuba_project_id", None, "project_id"),
Literal(None, 1),
),
binary_condition(
"equals",
Column("_snuba_org_id", None, "org_id"),
Literal(None, 1),
),
),
),
groupby=[
Column("_snuba_project_id", None, "project_id"),
Column("_snuba_tags[3]", None, "tags[3]"),
],
),
id="tags[3] is in the group by clause but has no matching condition",
),
]


Expand Down
22 changes: 19 additions & 3 deletions tests/query/test_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@
),
(
"subscriptable match with column",
SubscriptableReference(String("stuff")),
SubscriptableReference(column_name=String("stuff")),
SubscriptableReferenceExpr(
None, ColumnExpr(None, None, "stuff"), LiteralExpr(None, "things")
),
Expand All @@ -318,20 +318,36 @@
),
(
"subscriptable match with column and key",
SubscriptableReference(String("stuff"), String("things")),
SubscriptableReference(column_name=String("stuff"), key=String("things")),
SubscriptableReferenceExpr(
None, ColumnExpr(None, None, "stuff"), LiteralExpr(None, "things")
),
MatchResult(),
),
(
"subscriptable match with wrong column and key",
SubscriptableReference(String("notstuff"), String("things")),
SubscriptableReference(column_name=String("notstuff"), key=String("things")),
SubscriptableReferenceExpr(
None, ColumnExpr(None, None, "stuff"), LiteralExpr(None, "things")
),
None,
),
(
"Matches a column with all fields",
SubscriptableReference(
table_name=Param("p_table_name", AnyOptionalString()),
column_name=Param("col_name", String("notstuff")),
key=Param("key", String("things")),
),
SubscriptableReferenceExpr(
alias=None,
column=ColumnExpr(None, "table_name", "notstuff"),
key=LiteralExpr(None, "things"),
),
MatchResult(
{"p_table_name": "table_name", "col_name": "notstuff", "key": "things"}
),
),
]


Expand Down
24 changes: 17 additions & 7 deletions tests/subscriptions/test_codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def test_sessions_subscription_task_result_encoder() -> None:
ScheduledSubscriptionTask(
timestamp,
SubscriptionWithTick(
EntityKey.EVENTS,
EntityKey.SESSIONS,
Subscription(
SubscriptionIdentifier(PartitionId(1), uuid.uuid1()),
subscription_data,
Expand Down Expand Up @@ -278,7 +278,7 @@ def test_metrics_subscription_task_result_encoder() -> None:
query=(
"""
MATCH (metrics_counters) SELECT sum(value) AS value BY project_id, tags[3]
WHERE org_id = 1 AND project_id IN array(1) AND metric_id = 7
WHERE org_id = 1 AND project_id IN array(1) AND metric_id = 7 AND tags[3] IN array(1,2)
"""
),
time_window=timedelta(minutes=1),
Expand All @@ -303,21 +303,31 @@ def test_metrics_subscription_task_result_encoder() -> None:
}

task_result = SubscriptionTaskResult(
ScheduledTask(
ScheduledSubscriptionTask(
timestamp,
Subscription(
SubscriptionIdentifier(PartitionId(1), uuid.uuid1()), subscription_data,
SubscriptionWithTick(
EntityKey.METRICS_COUNTERS,
Subscription(
SubscriptionIdentifier(PartitionId(1), uuid.uuid1()),
subscription_data,
),
Tick(
0,
Interval(1, 5),
Interval(datetime(1970, 1, 1), datetime(1970, 1, 2)),
),
),
),
(request, result),
)

message = codec.encode(task_result)
data = json.loads(message.value.decode("utf-8"))
assert data["version"] == 2
payload = data["payload"]

assert payload["subscription_id"] == str(task_result.task.task.identifier)
assert payload["subscription_id"] == str(
task_result.task.task.subscription.identifier
)
assert payload["request"] == request.body
assert payload["result"] == result
assert payload["timestamp"] == task_result.task.timestamp.isoformat()
Expand Down
Loading

0 comments on commit efab254

Please sign in to comment.