Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions core/rule_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2563,6 +2563,11 @@ def variablize_all_subtrees(rule: dict, subtrees: list) -> dict:
new_rule_rewrite_json = json.loads(new_rule['rewrite_json'])

for subtree in subtrees:
# Skip subtrees that are already variablized; otherwise they would be variablized again, causing an infinite loop
# e.g., {"value": "V22"} should not be variablized again
if len(subtree) == 1 and 'value' in subtree and QueryRewriter.is_var(subtree['value']):
Copy link

Copilot AI Sep 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The condition len(subtree) == 1 assumes a specific structure for variablized subtrees. This magic number should be documented or replaced with a more explicit check that describes what constitutes a variablized subtree structure.

Suggested change
if len(subtree) == 1 and 'value' in subtree and QueryRewriter.is_var(subtree['value']):
# Explicitly check for a variablized subtree structure: a dict with only the 'value' key whose value is a variable
if isinstance(subtree, dict) and set(subtree.keys()) == {'value'} and QueryRewriter.is_var(subtree['value']):

Copilot uses AI. Check for mistakes.
continue

# Find a variable name for the given subtree
#
new_rule_mapping, newVarInternal = RuleGenerator.findNextVarInternal(new_rule_mapping)
Expand Down
11 changes: 11 additions & 0 deletions data/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,17 @@
'database': 'postgresql'
},

{
'id': 91,
'key': 'aggregation_to_filtered_subquery',
'name': 'Aggregation to Filtered Subquery',
'pattern': '''SELECT <x2>.<x9>, DATE(<x2>.<x3>), CASE WHEN SUM(CASE WHEN <x2>.<x4> = <x5> THEN <x5> ELSE <x6> END) >= <x5> THEN <x5> ELSE <x6> END FROM <x8> AS <x2> GROUP BY <<x7>>, DATE(<x2>.<x3>)''',
'constraints': '',
'rewrite': '''SELECT <x2>.<x9>, <x2>.<x3> FROM (SELECT <x9>, DATE(<x3>) FROM <x8> WHERE <x4> = <x5>) AS <x2> GROUP BY <<x7>>, <x2>.<x3>''',
'actions': '',
'database': 'postgresql'
},

{
'id': 8090,
'key': 'test_rule_wetune_90',
Expand Down
34 changes: 34 additions & 0 deletions tests/test_query_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1393,6 +1393,40 @@ def test_over_partial_matching():
_q1, _rewrite_path = QueryRewriter.rewrite(q0, rules)
assert format(parse(q1)) == format(parse(_q1))

def test_rewrite_aggregation_to_subquery():
    """Check the 'aggregation_to_filtered_subquery' rule on a sample query.

    The input query is rewritten with that single rule, and the result is
    compared with the expected SQL after both have been passed through
    parse() and format().
    """
    # The SQL literals stay flush-left so that their contents are unchanged.
    original_sql = '''
SELECT
t1.CPF,
DATE(t1.data) AS data,
CASE WHEN SUM(CASE WHEN t1.login_ok = true
THEN 1
ELSE 0
END) >= 1
THEN true
ELSE false
END
FROM db_risco.site_rn_login AS t1
GROUP BY t1.CPF, DATE(t1.data)
'''
    expected_sql = '''
SELECT
t1.CPF,
t1.data
FROM (
SELECT
CPF,
DATE(data)
FROM db_risco.site_rn_login
WHERE login_ok = true
) t1
GROUP BY t1.CPF, t1.data
'''

    # Only one rule takes part in this rewrite.
    rules = [get_rule('aggregation_to_filtered_subquery')]
    rewritten_sql, _path = QueryRewriter.rewrite(original_sql, rules)

    # Compare the parsed-and-formatted forms rather than the raw strings.
    expected = format(parse(expected_sql))
    actual = format(parse(rewritten_sql))
    assert expected == actual


# TODO - TBI
#
Expand Down
38 changes: 38 additions & 0 deletions tests/test_rule_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1935,9 +1935,14 @@ def test_generate_general_rule_8():

assert StringUtil.strim(RuleGenerator._fingerPrint(rule['pattern'])) == StringUtil.strim(RuleGenerator._fingerPrint('''
CAST(<x1> AS DATE)
''')) or StringUtil.strim(RuleGenerator._fingerPrint(rule['rewrite'])) == StringUtil.strim(RuleGenerator._fingerPrint('''
CAST(<<y>> AS DATE)
'''))

assert StringUtil.strim(RuleGenerator._fingerPrint(rule['rewrite'])) == StringUtil.strim(RuleGenerator._fingerPrint('''
<x1>
''')) or StringUtil.strim(RuleGenerator._fingerPrint(rule['rewrite'])) == StringUtil.strim(RuleGenerator._fingerPrint('''
<<y>>
'''))


Comment on lines 1936 to 1948
Copy link

Copilot AI Sep 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test logic is duplicated and confusing with overlapping assertions. The first assertion checks both pattern and rewrite fingerprints, while the second assertion only checks rewrite fingerprints. Consider restructuring these assertions to be more clear and avoid the duplicated rewrite fingerprint check.

Suggested change
assert StringUtil.strim(RuleGenerator._fingerPrint(rule['pattern'])) == StringUtil.strim(RuleGenerator._fingerPrint('''
CAST(<x1> AS DATE)
''')) or StringUtil.strim(RuleGenerator._fingerPrint(rule['rewrite'])) == StringUtil.strim(RuleGenerator._fingerPrint('''
CAST(<<y>> AS DATE)
'''))
assert StringUtil.strim(RuleGenerator._fingerPrint(rule['rewrite'])) == StringUtil.strim(RuleGenerator._fingerPrint('''
<x1>
''')) or StringUtil.strim(RuleGenerator._fingerPrint(rule['rewrite'])) == StringUtil.strim(RuleGenerator._fingerPrint('''
<<y>>
'''))
pattern_fp = StringUtil.strim(RuleGenerator._fingerPrint(rule['pattern']))
rewrite_fp = StringUtil.strim(RuleGenerator._fingerPrint(rule['rewrite']))
expected_pattern_fps = (
StringUtil.strim(RuleGenerator._fingerPrint('''
CAST(<x1> AS DATE)
''')),
StringUtil.strim(RuleGenerator._fingerPrint('''
CAST(<<y>> AS DATE)
'''))
)
expected_rewrite_fps = (
StringUtil.strim(RuleGenerator._fingerPrint('''
<x1>
''')),
StringUtil.strim(RuleGenerator._fingerPrint('''
<<y>>
'''))
)
assert pattern_fp in expected_pattern_fps
assert rewrite_fp in expected_rewrite_fps

Copilot uses AI. Check for mistakes.
Expand Down Expand Up @@ -2248,6 +2253,39 @@ def test_generate_general_rule_21():
assert q0_rule == "FROM <x1> NATURAL JOIN (<x2>) WHERE <<x3>> AND <x1>.<x4> = 4"
assert q1_rule == "FROM <x1> INNER JOIN <x2> ON <x1>.<x4> = <x2>.<x4> WHERE <<x3>>"

def test_generate_general_rule_22():
    """Generate a general rule from an aggregation / filtered-subquery pair.

    The rule produced by RuleGenerator.generate_general_rule() must be a
    dict, and its pattern and rewrite templates (after variable names are
    unified) must equal the expected templates below.
    """
    # The SQL literals stay flush-left so that their contents are unchanged.
    source_sql = """SELECT
t1.CPF,
DATE(t1.data),
CASE WHEN SUM(CASE WHEN t1.login_ok = true
THEN 1
ELSE 0
END) >= 1
THEN true
ELSE false
END
FROM db_risco.site_rn_login AS t1
GROUP BY t1.CPF, DATE(t1.data)"""

    target_sql = """SELECT
t1.CPF,
t1.data
FROM (
SELECT
CPF,
DATE(data)
FROM db_risco.site_rn_login
WHERE login_ok = true
) t1
GROUP BY t1.CPF, t1.data"""

    expected_pattern = "SELECT <<x1>>, DATE(<x2>.<x3>), CASE WHEN SUM(CASE WHEN <x2>.<x4> = <x5> THEN <x5> ELSE <x6> END) >= <x5> THEN <x5> ELSE <x6> END FROM <x2> GROUP BY <<x7>>, DATE(<x2>.<x3>)"
    expected_rewrite = "SELECT <<x1>>, <x2>.<x3> FROM (SELECT <x8>, DATE(<x3>) FROM <x2> WHERE <x4> = <x5>) AS t1 GROUP BY <<x7>>, <x2>.<x3>"

    rule = RuleGenerator.generate_general_rule(source_sql, target_sql)
    assert type(rule) is dict

    # Unify the variable names of both templates before comparing them.
    raw_pattern = rule['pattern']
    raw_rewrite = rule['rewrite']
    pattern_template, rewrite_template = unify_variable_names(raw_pattern, raw_rewrite)
    assert pattern_template == expected_pattern
    assert rewrite_template == expected_rewrite


# def test_suggest_rules_bf_1():
# examples = [
Expand Down