Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: helper functions for RLS #19055

Merged
merged 9 commits into from
Mar 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 202 additions & 0 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
remove_quotes,
Token,
TokenList,
Where,
)
from sqlparse.tokens import (
CTE,
Expand Down Expand Up @@ -458,3 +459,204 @@ def validate_filter_clause(clause: str) -> None:
)
if open_parens > 0:
raise QueryClauseValidationException("Unclosed parenthesis in filter clause")


class InsertRLSState(str, Enum):
"""
State machine that scans for WHERE and ON clauses referencing tables.
"""

SCANNING = "SCANNING"
SEEN_SOURCE = "SEEN_SOURCE"
FOUND_TABLE = "FOUND_TABLE"


def has_table_query(token_list: TokenList) -> bool:
"""
Return if a stament has a query reading from a table.

>>> has_table_query(sqlparse.parse("COUNT(*)")[0])
False
>>> has_table_query(sqlparse.parse("SELECT * FROM table")[0])
True

Note that queries reading from constant values return false:

>>> has_table_query(sqlparse.parse("SELECT * FROM (SELECT 1)")[0])
False

"""
state = InsertRLSState.SCANNING
for token in token_list.tokens:

# # Recurse into child token list
if isinstance(token, TokenList) and has_table_query(token):
return True

# Found a source keyword (FROM/JOIN)
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
state = InsertRLSState.SEEN_SOURCE

# Found identifier/keyword after FROM/JOIN
elif state == InsertRLSState.SEEN_SOURCE and (
isinstance(token, sqlparse.sql.Identifier) or token.ttype == Keyword
):
return True

# Found nothing, leaving source
elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
state = InsertRLSState.SCANNING

return False


def add_table_name(rls: TokenList, table: str) -> None:
betodealmeida marked this conversation as resolved.
Show resolved Hide resolved
"""
Modify a RLS expression ensuring columns are fully qualified.
"""
tokens = rls.tokens[:]
while tokens:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You likely could use flatten here. It uses a generator so likely a copy should be made given you're mutating the tokens, i.e.,

for token in list(rls.flatten()):
    if imt(token, i=Identifier) and token.get_parent_name() is None:
        ...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue, if we call .flatten() we would never get an Identifier.

token = tokens.pop(0)

if isinstance(token, Identifier) and token.get_parent_name() is None:
token.tokens = [
Token(Name, table),
Token(Punctuation, "."),
Token(Name, token.get_name()),
]
elif isinstance(token, TokenList):
tokens.extend(token.tokens)


def matches_table_name(candidate: Token, table: str) -> bool:
"""
Returns if the token represents a reference to the table.

Tables can be fully qualified with periods.

Note that in theory a table should be represented as an identifier, but due to
sqlparse's aggressive list of keywords (spanning multiple dialects) often it gets
classified as a keyword.
"""
if not isinstance(candidate, Identifier):
candidate = Identifier([Token(Name, candidate.value)])

target = sqlparse.parse(table)[0].tokens[0]
if not isinstance(target, Identifier):
target = Identifier([Token(Name, target.value)])

# match from right to left, splitting on the period, eg, schema.table == table
for left, right in zip(candidate.tokens[::-1], target.tokens[::-1]):
if left.value != right.value:
return False

return True


def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
"""
Update a statement inplace applying an RLS associated with a given table.
"""
# make sure the identifier has the table name
add_table_name(rls, table)

state = InsertRLSState.SCANNING
for token in token_list.tokens:

# Recurse into child token list
if isinstance(token, TokenList):
i = token_list.tokens.index(token)
token_list.tokens[i] = insert_rls(token, table, rls)

# Found a source keyword (FROM/JOIN)
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
state = InsertRLSState.SEEN_SOURCE

# Found identifier/keyword after FROM/JOIN, test for table
elif state == InsertRLSState.SEEN_SOURCE and (
isinstance(token, Identifier) or token.ttype == Keyword
):
if matches_table_name(token, table):
state = InsertRLSState.FOUND_TABLE

# Found WHERE clause, insert RLS. Note that we insert it even it already exists,
# to be on the safe side: it could be present in a clause like `1=1 OR RLS`.
elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where):
token.tokens[1:1] = [Token(Whitespace, " "), Token(Punctuation, "(")]
token.tokens.extend(
[
Token(Punctuation, ")"),
Token(Whitespace, " "),
Token(Keyword, "AND"),
Token(Whitespace, " "),
]
+ rls.tokens
)
state = InsertRLSState.SCANNING

# Found ON clause, insert RLS. The logic for ON is more complicated than the logic
# for WHERE because in the former the comparisons are siblings, while on the
# latter they are children.
elif (
state == InsertRLSState.FOUND_TABLE
and token.ttype == Keyword
and token.value.upper() == "ON"
):
tokens = [
Token(Whitespace, " "),
rls,
Token(Whitespace, " "),
Token(Keyword, "AND"),
Token(Whitespace, " "),
Token(Punctuation, "("),
]
i = token_list.tokens.index(token)
token.parent.tokens[i + 1 : i + 1] = tokens
i += len(tokens) + 2

# close parenthesis after last existing comparison
j = 0
for j, sibling in enumerate(token_list.tokens[i:]):
# scan until we hit a non-comparison keyword (like ORDER BY) or a WHERE
if (
sibling.ttype == Keyword
and not imt(
sibling, m=[(Keyword, "AND"), (Keyword, "OR"), (Keyword, "NOT")]
)
or isinstance(sibling, Where)
):
j -= 1
break
token.parent.tokens[i + j + 1 : i + j + 1] = [
Token(Whitespace, " "),
Token(Punctuation, ")"),
Token(Whitespace, " "),
]

state = InsertRLSState.SCANNING

# Found table but no WHERE clause found, insert one
elif state == InsertRLSState.FOUND_TABLE and token.ttype != Whitespace:
i = token_list.tokens.index(token)
token_list.tokens[i:i] = [
Token(Whitespace, " "),
Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls]),
Token(Whitespace, " "),
]

state = InsertRLSState.SCANNING

# Found nothing, leaving source
elif state == InsertRLSState.SEEN_SOURCE and token.ttype != Whitespace:
state = InsertRLSState.SCANNING

# found table at the end of the statement; append a WHERE clause
if state == InsertRLSState.FOUND_TABLE:
token_list.tokens.extend(
[
Token(Whitespace, " "),
Where([Token(Keyword, "WHERE"), Token(Whitespace, " "), rls]),
]
)

return token_list
Loading