From 91257e96feeec4b92c9166a742a80ad5b76f5997 Mon Sep 17 00:00:00 2001 From: rizar Date: Tue, 22 Dec 2020 19:05:18 +0000 Subject: [PATCH 1/7] support distinct(bla) and FROM foo, bar --- process_sql.py | 30 ++++++++++++++---- .../test_parsing.cpython-37-pytest-6.2.1.pyc | Bin 0 -> 2415 bytes test/db.sqlite | Bin 0 -> 12288 bytes test/test_parsing.py | 28 ++++++++++++++++ 4 files changed, 51 insertions(+), 7 deletions(-) create mode 100644 test/__pycache__/test_parsing.cpython-37-pytest-6.2.1.pyc create mode 100644 test/db.sqlite create mode 100644 test/test_parsing.py diff --git a/process_sql.py b/process_sql.py index 778f854..1d43bf7 100644 --- a/process_sql.py +++ b/process_sql.py @@ -195,14 +195,23 @@ def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None): """ :returns next idx, column id """ - tok = toks[start_idx] - if tok == "*": - return start_idx + 1, schema.idMap[tok] + idx = start_idx + tok = toks[idx] + in_parentheses = False + col_id = None + + if tok == '(': + in_parentheses = True + idx += 1 + tok = toks[idx] + + if tok in ['1', '*']: + col_id = schema.idMap['*'] if "." in tok: # if token is a composite alias, col = tok.split(".") key = tables_with_alias[alias] + "." + col - return start_idx + 1, schema.idMap[key] + col_id = schema.idMap[key] assert ( default_tables is not None and len(default_tables) > 0 @@ -212,9 +221,16 @@ def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None): table = tables_with_alias[alias] if tok in schema.schema[table]: key = table + "." + tok - return start_idx + 1, schema.idMap[key] + col_id = schema.idMap[key] + + assert col_id, "Error col: {}".format(tok) + + if in_parentheses: + assert toks[idx + 1] == ')' + return idx + 2, col_id + else: + return idx + 1, col_id - assert False, "Error col: {}".format(tok) def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): @@ -444,7 +460,7 @@ def parse_from(toks, start_idx, tables_with_alias, schema): idx, sql = parse_sql(toks, idx, tables_with_alias, schema) table_units.append((TABLE_TYPE["sql"], sql)) else: - if idx < len_ and toks[idx] == "join": + if idx < len_ and toks[idx] in [",", "join"]: idx += 1 # skip join idx, table_unit, table_name = parse_table_unit( toks, idx, tables_with_alias, schema diff --git a/test/__pycache__/test_parsing.cpython-37-pytest-6.2.1.pyc b/test/__pycache__/test_parsing.cpython-37-pytest-6.2.1.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4474be6e2f06da5298c2cceeb5662776e17996d GIT binary patch literal 2415 zcmbVOPj4GV6rY)0uh)N*rlmAcAsD}aap29Y|0S_dl(qcZnKy4{e)IPI-tOm>N)f^J z+h?!nA2vdN>CWUL(0BrWxeLG$!x8eaxs%u;2&35cb6Ag%$cgiQK6ZUKF8BqoS8lIIFJ6uuOI^=PXI{Vo@Z*0$_mUfz47Bn{DFUtT5 zDa`OO{9rNC7*rLPs?HA9ZmSzA>*QM7S*85GVFHNdq?S#CTENpYX}*VdLx8I9&k!pZ>SB1^dslw7Nvc_$9fCMiwiu;?N2)v~A^rM`386Lb!f{r)R^CBcekH97-2C zyuh*yGcLp8&nb()uf~FvS>;Sjstfs?T#gSg7ZXM=uozyA(>O9q3z!gj@6t^EoiZW9 zxf&COAib9x+a7tX^Vzcmp4|o&3VGS%Lav95(q?1u!T#2lTbs?==H54r=GqrK`^}xk zW^--oyKB9pk33o1*r@5|15c*qj!61lMwRHR zLyx2ucxZSqs0LDUp^WE4@C4(-!Tp5{#v207GX$C-)4B7EPj~KE+l|5d<3$%7X#PBd zyXX97+s?arMWH2Yw3-&5^p1>{1Cd^%tssi1;62ewVo*F!3seSet}SU9Jn2C{j8xc_ zsS_pOEYn+59wz;WnOU?GB?m!7yFtv;>QfV*FzIdyk%;sHj=C6Nm@`4|Ar7f#M^jNSbFgVYTHl*DOCgZ8f!`yY**Iwi_aaKEcFb1H6!p5!6go|=JlpCnf4zwxWL>Aa zzi}WCKmY**5I_I{1Q0*~0R#|eqQENa^acZc2yVUg@1^?*pJnAYms#s^VJ6myojlBp zT-V)kZhU;TKaJVH4No;07Mt{EE};lpFV?)- zQQe!zrk_9n0R#|0009ILKmY**5I_Kd1`E*tH~4Zn8UhF)fB*srAb Date: Wed, 23 Dec 2020 00:52:55 +0000 Subject: [PATCH 2/7] support <> --- process_sql.py | 42 +++++++++--------- .../test_parsing.cpython-37-pytest-6.2.1.pyc | Bin 2415 -> 3129 bytes test/test_parsing.py | 22 ++++++--- 3 files changed, 36 insertions(+), 28 deletions(-) diff --git a/process_sql.py b/process_sql.py index 1d43bf7..77e7c1a 100644 --- a/process_sql.py +++ b/process_sql.py @@ -41,20 +41,21 @@ ) JOIN_KEYWORDS = ("join", "on", "as") -WHERE_OPS = ( - "not", - "between", - "=", - ">", - "<", - ">=", - "<=", - "!=", - "in", - "like", - "is", - "exists", -) +WHERE_OPS = { + "not": 0, + "between": 1, + "=": 2, + ">": 3, + "<": 4, + ">=": 5, + "<=": 6, + "!=": 7, + "in": 8, + "like": 9, + "is": 10, + "exists": 11, + "<>": 7 +} UNIT_OPS = ("none", "-", "+", "*", "/") AGG_OPS = ("none", "max", "min", "count", "sum", "avg") TABLE_TYPE = { @@ -162,14 +163,14 @@ def tokenize(string): if toks[i] in vals: toks[i] = vals[toks[i]] - # find if there exists !=, >=, <= - eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="] + # find if there exists !=, >=, <=, <>` + eq_idxs = [idx for idx, tok in enumerate(toks) if tok in ("=", ">")] eq_idxs.reverse() prefix = ("!", ">", "<") for eq_idx in eq_idxs: pre_tok = toks[eq_idx - 1] if pre_tok in prefix: - toks = toks[: eq_idx - 1] + [pre_tok + "="] + toks[eq_idx + 1 :] + toks = toks[: eq_idx - 1] + [pre_tok + toks[eq_idx]] + toks[eq_idx + 1 :] return toks @@ -375,12 +376,11 @@ def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=N assert ( idx < len_ and toks[idx] in WHERE_OPS ), "Error condition: idx: {}, tok: {}".format(idx, toks[idx]) - op_id = WHERE_OPS.index(toks[idx]) + op_id = WHERE_OPS[toks[idx]] idx += 1 val1 = val2 = None - if op_id == WHERE_OPS.index( - "between" - ): # between..and... special case: dual values + if op_id == WHERE_OPS['between']: + # between..and... special case: dual values idx, val1 = parse_value( toks, idx, tables_with_alias, schema, default_tables ) diff --git a/test/__pycache__/test_parsing.cpython-37-pytest-6.2.1.pyc b/test/__pycache__/test_parsing.cpython-37-pytest-6.2.1.pyc index f4474be6e2f06da5298c2cceeb5662776e17996d..e7ff8bf5fbbc58e598116f7032e3c48e47ad84ac 100644 GIT binary patch delta 523 zcma)&y-EW?6ov1d{ZC8`K|>^qqjQl+2;_Mck91gXMU%MF8eNK5 zEOZ7N5r(Uf3%3q$Fnl$E7lv5I#Bm=2Fx;x2FP@-weaY+);RG7&uU#X&;Rc=bc z@`XcsD<96GpU&s#b-4d{-+y^O{pBsT$j~HoN+{;Eg=a%Hl~Y_EjIEZF#5CWi-na!`a6`Ko=ta=d+n%VySBJB=CKQ>et|PH@36?5 zbUe##k5g7Hm*rlK)0)Gny?;>JqmB$w;*RSL$#mYJya09aprs5q)5#kZccz0GF=dU! QdZb`N9QqVfHKV5W2J(4_$N&HU delta 73 zcmdlf@m`3}iI8U00n=Kh "bar"')['where'] == ground_truth + assert get_sql(test_schema(), + 'SELECT * FROM papers WHERE papers.title != "bar"')['where'] == ground_truth From a1a80579d8099f15a0e35f0a16d5c4e824b22605 Mon Sep 17 00:00:00 2001 From: rizar Date: Wed, 23 Dec 2020 14:09:59 +0000 Subject: [PATCH 3/7] support inner join --- process_sql.py | 5 +++++ test/test_parsing.py | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/process_sql.py b/process_sql.py index 77e7c1a..fe3e8d8 100644 --- a/process_sql.py +++ b/process_sql.py @@ -456,12 +456,17 @@ def parse_from(toks, start_idx, tables_with_alias, schema): isBlock = True idx += 1 + if toks[idx] == 'as': + raise DerivedTableAliasError() + if toks[idx] == "select": idx, sql = parse_sql(toks, idx, tables_with_alias, schema) table_units.append((TABLE_TYPE["sql"], sql)) else: if idx < len_ and toks[idx] in [",", "join"]: idx += 1 # skip join + if idx + 1 < len_ and toks[idx:idx + 2] == ["inner", "join"]: + idx += 2 # skip join idx, table_unit, table_name = parse_table_unit( toks, idx, tables_with_alias, schema ) diff --git a/test/test_parsing.py b/test/test_parsing.py index 80989c2..d7d73fc 100644 --- a/test/test_parsing.py +++ b/test/test_parsing.py @@ -19,11 +19,13 @@ def test_parse_col(): 'SELECT DISTINCT papers.id FROM papers')['select'] == ground_truth -def test_comma_joins(): +def test_joins(): ground_truth = {'conds': [], 'table_units': [('table_unit', '__papers__'), ('table_unit', '__coauthored__')]} assert get_sql(test_schema(), 'SELECT * FROM papers JOIN coauthored')['from'] == ground_truth + assert get_sql(test_schema(), + 'SELECT * FROM papers INNER JOIN coauthored')['from'] == ground_truth assert get_sql(test_schema(), 'SELECT * FROM papers, coauthored')['from'] == ground_truth From d30f6244aeca42d46be82691a0a65eab2df1d00c Mon Sep 17 00:00:00 2001 From: rizar Date: Wed, 23 Dec 2020 14:11:14 +0000 Subject: [PATCH 4/7] improve exception handling --- process_sql.py | 19 +++++++++++++++++- .../test_parsing.cpython-37-pytest-6.2.1.pyc | Bin 3129 -> 3319 bytes 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/process_sql.py b/process_sql.py index fe3e8d8..535e971 100644 --- a/process_sql.py +++ b/process_sql.py @@ -68,6 +68,16 @@ ORDER_OPS = ("desc", "asc") +class DerivedFieldAliasError(ValueError): + pass + +class DerivedTableAliasError(ValueError): + pass + +class ParenthesesInConditionError(ValueError): + pass + + class Schema: """ Simple schema which maps table&column to a unique identifier @@ -224,7 +234,11 @@ def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None): key = table + "." + tok col_id = schema.idMap[key] - assert col_id, "Error col: {}".format(tok) + if col_id is None: + if tok == 'as': + raise DerivedFieldAliasError(toks[idx + 1]) + else: + assert "Error col: {}".format(tok) if in_parentheses: assert toks[idx + 1] == ')' @@ -365,6 +379,9 @@ def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=N conds = [] while idx < len_: + if toks[idx] == '(': + raise ParenthesesInConditionError() + idx, val_unit = parse_val_unit( toks, idx, tables_with_alias, schema, default_tables ) diff --git a/test/__pycache__/test_parsing.cpython-37-pytest-6.2.1.pyc b/test/__pycache__/test_parsing.cpython-37-pytest-6.2.1.pyc index e7ff8bf5fbbc58e598116f7032e3c48e47ad84ac..a7b48e13cc20288389e40eb9d6eb29dfe0c5e735 100644 GIT binary patch delta 192 zcmdlf@m-SFiI6F(^K9(myvj33e#@pN0A%EpFfCxOVOYpGfw8D;vOK#rBj4m`_8r<YnU;1fsdg`ZgL=_kPlhe2_F>(U{vS2fJ delta 97 zcmew^xl@AIiI6F(1?1%1>6|xx~l`09=|I A$p8QV From 55b74897f57a04208f731a8b6f86a23a2581f968 Mon Sep 17 00:00:00 2001 From: rizar Date: Wed, 23 Dec 2020 14:49:01 +0000 Subject: [PATCH 5/7] add ValueListError --- process_sql.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/process_sql.py b/process_sql.py index 535e971..d153996 100644 --- a/process_sql.py +++ b/process_sql.py @@ -77,6 +77,9 @@ class DerivedTableAliasError(ValueError): class ParenthesesInConditionError(ValueError): pass +class ValueListError(ValueError): + pass + class Schema: """ @@ -367,6 +370,8 @@ def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None) idx = end_idx if isBlock: + if toks[idx] == ',': + raise ValueListError() assert toks[idx] == ")" idx += 1 From 3c60030eb02c58c39615dea4101b7d68b0325c8b Mon Sep 17 00:00:00 2001 From: rizar Date: Wed, 23 Dec 2020 18:06:37 +0000 Subject: [PATCH 6/7] if --> elif for joins --- process_sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/process_sql.py b/process_sql.py index d153996..cb123b3 100644 --- a/process_sql.py +++ b/process_sql.py @@ -487,7 +487,7 @@ def parse_from(toks, start_idx, tables_with_alias, schema): else: if idx < len_ and toks[idx] in [",", "join"]: idx += 1 # skip join - if idx + 1 < len_ and toks[idx:idx + 2] == ["inner", "join"]: + elif idx + 1 < len_ and toks[idx:idx + 2] == ["inner", "join"]: idx += 2 # skip join idx, table_unit, table_name = parse_table_unit( toks, idx, tables_with_alias, schema From 2623570122f51cd54e0229f8e7222dc52ac9834a Mon Sep 17 00:00:00 2001 From: rizar Date: Wed, 23 Dec 2020 19:12:46 +0000 Subject: [PATCH 7/7] assert message --> assert False, message --- process_sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/process_sql.py b/process_sql.py index cb123b3..7e0a813 100644 --- a/process_sql.py +++ b/process_sql.py @@ -241,7 +241,7 @@ def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None): if tok == 'as': raise DerivedFieldAliasError(toks[idx + 1]) else: - assert "Error col: {}".format(tok) + assert False, "Error col: {}".format(tok) if in_parentheses: assert toks[idx + 1] == ')'