Skip to content

Commit

Permalink
refactor(Helpers): Simplify predicates logic
Browse files Browse the repository at this point in the history
  • Loading branch information
geido committed Oct 3, 2024
1 parent dbe34d8 commit 0ac0335
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 29 deletions.
26 changes: 10 additions & 16 deletions superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,23 +1321,17 @@ def get_rls_filters_for_column(
rls_filters_for_column = []

if rls_filters := self.get_sqla_row_level_filters(template_processor):
parsable_statement = "SELECT 1 WHERE"

for rls_filter in rls_filters:
parsable_statement += f" {str(rls_filter)} "
if rls_filter != rls_filters[-1]:
parsable_statement += "AND"

predicates = SQLStatement(
parsable_statement, engine=self.db_engine_spec.engine
).get_predicates()

for predicate in predicates:
column = predicate.find(exp.Column)
if column and column.output_name == column_name:
rls_filters_for_column.append(
for rls in rls_filters:
predicates = SQLStatement(
str(rls), engine=self.db_engine_spec.engine
).get_predicates()
rls_filters_for_column.extend(
[
TextClause(str(predicate.sql(comments=False)))
)
for predicate in predicates
if predicate.find(exp.Column).output_name == column_name
]
)

return rls_filters_for_column

Expand Down
35 changes: 22 additions & 13 deletions superset/sql/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,16 @@ def is_mutating(self) -> bool:
"""
raise NotImplementedError()

def get_predicates(self) -> list[exp.Predicate]:
"""
Return the predicates for a SQL statement.
>>> statement = SQLStatement("SELECT * FROM table WHERE column = 'value'")
>>> statement.get_predicates()
["COLUMN = 'value'"]
"""
raise NotImplementedError()

def __str__(self) -> str:
return self.format()

Expand Down Expand Up @@ -369,24 +379,13 @@ def get_settings(self) -> dict[str, str | bool]:

def get_predicates(self) -> list[exp.Predicate]:
"""
Return the predicates for the SQL statement.
Return the predicates for a SQL statement.
>>> statement = SQLStatement("SELECT * FROM table WHERE column = 'value'")
>>> statement.get_predicates()
["COLUMN = 'value'"]
"""
predicates = []
where_clauses = self._parsed.find_all(exp.Where)

if not where_clauses:
return []

for where_clause in where_clauses:
where_predicates = where_clause.find_all(exp.Predicate)
for pred in where_predicates:
predicates.append(pred)

return predicates
return self._parsed.find_all(exp.Predicate)


class KQLSplitState(enum.Enum):
Expand Down Expand Up @@ -544,6 +543,16 @@ def is_mutating(self) -> bool:
"""
return self._parsed.startswith(".") and not self._parsed.startswith(".show")

def get_predicates(self) -> list[exp.Predicate]:
"""
Return the predicates for a SQL statement.
>>> statement = KustoKQLStatement("StormEvents | where InjuriesDirect > 50")
>>> statement.get_predicates()
["InjuriesDirect > 50"]
"""
return NotImplementedError()


class SQLScript:
"""
Expand Down

0 comments on commit 0ac0335

Please sign in to comment.