Skip to content

Commit 1aa8ff8

Browse files
committed
Fix adql translation for 'contains' and 'intersects'
1 parent f320b66 commit 1aa8ff8

File tree

3 files changed

+54
-17
lines changed

3 files changed

+54
-17
lines changed

src/queryparser/adql/ADQLParser.g4

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ comp_op: EQ | NOT_EQ | LTH | GTH | GRET | LEET ;
4141
comparison_predicate: value_expression comp_op value_expression ;
4242
concatenation_operator: CONCAT ;
4343
contains: CONTAINS LPAREN geometry_value_expression COMMA geometry_value_expression RPAREN ;
44-
contains_predicate: INT EQ contains;
44+
contains_predicate: INT EQ contains | contains EQ INT;
4545
coord_sys: string_value_expression ;
4646
coord_value: point_value | column_reference ;
4747
//coord1: COORD1 LPAREN coord_value RPAREN ;
@@ -80,7 +80,7 @@ in_predicate: value_expression ( NOT )? IN in_predicate_value
8080
in_predicate_value: table_subquery | LPAREN in_value_list RPAREN ;
8181
in_value_list: value_expression ( COMMA value_expression )* ;
8282
intersects: INTERSECTS LPAREN geometry_value_expression COMMA geometry_value_expression RPAREN ;
83-
intersects_predicate: INT EQ intersects;
83+
intersects_predicate: INT EQ intersects | intersects EQ INT ;
8484
join_column_list: column_name_list ;
8585
join_condition: ON search_condition ;
8686
join_specification: join_condition | named_columns_join ;

src/queryparser/adql/adqltranslator.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,17 @@
2222
'SQRT', 'TAN', 'TRUNCATE')
2323

2424

25-
def _remove_children(ctx):
25+
def _removeFirstChild(ctx):
26+
if ctx.children is not None:
27+
del ctx.children[0]
28+
29+
30+
def _remove_children(ctx, reverse=False):
2631
for _ in range(ctx.getChildCount() - 1):
27-
ctx.removeLastChild()
32+
if reverse:
33+
_removeFirstChild(ctx)
34+
else:
35+
ctx.removeLastChild()
2836

2937

3038
def _convert_values(ctx, cidx, output_sql):
@@ -268,8 +276,6 @@ def visitCircle(self, ctx):
268276
def visitPolygon(self, ctx):
269277
pars = []
270278

271-
for ch in ctx.children:
272-
print(ch.getText())
273279
for j in range(2, len(ctx.children), 2):
274280
par = self._convert_values(ctx, j)
275281
# only append coordinates
@@ -351,18 +357,27 @@ def visitCentroid(self, ctx):
351357
'''
352358

353359
def visitContains_predicate(self, ctx):
354-
comp_value = ctx.children[0].getText()
355-
if comp_value == '1' or comp_value == '0':
360+
comp_value_l = ctx.children[0].getText()
361+
comp_value_r = ctx.children[2].getText()
362+
if comp_value_l == '1' or comp_value_l == '0':
356363
self.visitContains(ctx.children[2])
357364
ctx_text = self.contexts[ctx.children[2]]
358365
if self.output_sql == 'mysql':
359-
ctx_text = f"{comp_value} = {ctx_text}"
360-
elif self.output_sql == 'postgresql' and comp_value == '0':
366+
ctx_text = f"{comp_value_l} = {ctx_text}"
367+
elif self.output_sql == 'postgresql' and comp_value_l == '0':
368+
ctx_text = ctx_text.replace('@', '!@')
369+
_remove_children(ctx)
370+
elif comp_value_r == '1' or comp_value_r == '0':
371+
self.visitContains(ctx.children[0])
372+
ctx_text = self.contexts[ctx.children[0]]
373+
if self.output_sql == 'mysql':
374+
ctx_text = f"{comp_value_r} = {ctx_text}"
375+
elif self.output_sql == 'postgresql' and comp_value_r == '0':
361376
ctx_text = ctx_text.replace('@', '!@')
377+
_remove_children(ctx, reverse=True)
362378
else:
363379
raise QueryError('The function CONTAINS allows comparison to 1 or 0 only.')
364380

365-
_remove_children(ctx)
366381
self.contexts[ctx] = ctx_text
367382

368383

@@ -415,21 +430,32 @@ def visitDistance(self, ctx):
415430
ctx.removeLastChild()
416431
self.contexts[ctx] = ctx_text
417432

433+
418434
def visitIntersects_predicate(self, ctx):
419-
comp_value = ctx.children[0].getText()
420-
if comp_value == '1' or comp_value == '0':
435+
comp_value_l = ctx.children[0].getText()
436+
comp_value_r = ctx.children[2].getText()
437+
if comp_value_l == '1' or comp_value_l == '0':
421438
self.visitIntersects(ctx.children[2])
422439
ctx_text = self.contexts[ctx.children[2]]
423440
if self.output_sql == 'mysql':
424-
ctx_text = f"{comp_value} = {ctx_text}"
425-
elif self.output_sql == 'postgresql' and comp_value == '0':
441+
ctx_text = f"{comp_value_l} = {ctx_text}"
442+
elif self.output_sql == 'postgresql' and comp_value_l == '0':
426443
ctx_text = ctx_text.replace('&&', '!&&')
444+
_remove_children(ctx)
445+
elif comp_value_r == '1' or comp_value_r == '0':
446+
self.visitIntersects(ctx.children[0])
447+
ctx_text = self.contexts[ctx.children[0]]
448+
if self.output_sql == 'mysql':
449+
ctx_text = f"{comp_value_r} = {ctx_text}"
450+
elif self.output_sql == 'postgresql' and comp_value_r == '0':
451+
ctx_text = ctx_text.replace('&&', '!&&')
452+
_remove_children(ctx, reverse=True)
427453
else:
428454
raise QueryError('The function INTERSECTS allows comparison to 1 or 0 only.')
429455

430-
_remove_children(ctx)
431456
self.contexts[ctx] = ctx_text
432457

458+
433459
def visitIntersects(self, ctx):
434460
arg = (self.contexts[ctx.children[2].children[0]],
435461
self.contexts[ctx.children[4].children[0]])

src/queryparser/testing/tests.yaml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -847,6 +847,10 @@ adql_postgresql_tests:
847847
- SELECT TOP 10 ra, dec FROM db.tab WHERE 1=CONTAINS(POINT('ICRS', ra, dec), BOX('ICRS', -3.0, 5.0, 4.0, 10.0));
848848
- SELECT ra, dec FROM db.tab WHERE spoint(RADIANS(ra), RADIANS(dec)) @ sbox(spoint(RADIANS(-5.000000000000),RADIANS(0.000000000000)),spoint(RADIANS(-1.000000000000),RADIANS(10.000000000000))) LIMIT 10;
849849

850+
-
851+
- SELECT TOP 10 ra, dec FROM db.tab WHERE CONTAINS(POINT('ICRS', ra, dec), BOX('ICRS', -3.0, 5.0, 4.0, 10.0)) = 1;
852+
- SELECT ra, dec FROM db.tab WHERE spoint(RADIANS(ra), RADIANS(dec)) @ sbox(spoint(RADIANS(-5.000000000000),RADIANS(0.000000000000)),spoint(RADIANS(-1.000000000000),RADIANS(10.000000000000))) LIMIT 10;
853+
850854
-
851855
- SELECT TOP 10 ra, dec FROM db.tab WHERE 1=CONTAINS(POINT(ra, dec), BOX(-3.0, 5.0, 4.0, 10.0));
852856
- SELECT ra, dec FROM db.tab WHERE spoint(RADIANS(ra), RADIANS(dec)) @ sbox(spoint(RADIANS(-5.000000000000),RADIANS(0.000000000000)),spoint(RADIANS(-1.000000000000),RADIANS(10.000000000000))) LIMIT 10;
@@ -855,6 +859,10 @@ adql_postgresql_tests:
855859
- SELECT TOP 10 ra, dec FROM db.tab WHERE 1=CONTAINS(POINT(ra, dec), CIRCLE(POINT(-3.0, 4.0), 10.0));
856860
- SELECT ra, dec FROM db.tab WHERE spoint(RADIANS(ra), RADIANS(dec)) @ scircle(spoint(RADIANS(-3.0), RADIANS(4.0)), RADIANS(10.0)) LIMIT 10;
857861

862+
-
863+
- SELECT TOP 10 ra, dec FROM db.tab WHERE CONTAINS(POINT(ra, dec), CIRCLE(POINT(-3.0, 4.0), 10.0))=0;
864+
- SELECT ra, dec FROM db.tab WHERE spoint(RADIANS(ra), RADIANS(dec)) !@ scircle(spoint(RADIANS(-3.0), RADIANS(4.0)), RADIANS(10.0)) LIMIT 10;
865+
858866
-
859867
- SELECT TOP 10 ra, dec FROM db.tab WHERE 0=CONTAINS(POINT(ra, dec), CIRCLE(POINT(-3.0, 4.0), 10.0));
860868
- SELECT ra, dec FROM db.tab WHERE spoint(RADIANS(ra), RADIANS(dec)) !@ scircle(spoint(RADIANS(-3.0), RADIANS(4.0)), RADIANS(10.0)) LIMIT 10;
@@ -864,7 +872,7 @@ adql_postgresql_tests:
864872
- SELECT LOG(ra), LN(dec) FROM db.tab WHERE spoint(RADIANS(ra), RADIANS(dec)) @ spoly('{(10.0d,-10.5d),(20.0d,20.5d),(30.0d,30.5d)}') LIMIT 10;
865873

866874
-
867-
- SELECT TOP 10 LOG10(ra), LOG(dec) FROM db.tab WHERE 1=CONTAINS(POINT(ra, dec), POLYGON(10.0, -10.5, 20.0, 20.5, 30.0, 30.5));
875+
- SELECT TOP 10 LOG10(ra), LOG(dec) FROM db.tab WHERE CONTAINS(POINT(ra, dec), POLYGON(10.0, -10.5, 20.0, 20.5, 30.0, 30.5))=1;
868876
- SELECT LOG(ra), LN(dec) FROM db.tab WHERE spoint(RADIANS(ra), RADIANS(dec)) @ spoly('{(10.0d,-10.5d),(20.0d,20.5d),(30.0d,30.5d)}') LIMIT 10;
869877

870878
-
@@ -883,6 +891,9 @@ adql_postgresql_tests:
883891
- SELECT TOP 10 LOG10(ra), LOG(dec) FROM db.tab WHERE 0=INTERSECTS(POINT('ICRS', ra, dec), POLYGON('ICRS', 10.0, -10.5, 20.0, 20.5, 30.0, 30.5));
884892
- SELECT LOG(ra), LN(dec) FROM db.tab WHERE spoint(RADIANS(ra), RADIANS(dec)) !&& spoly('{(10.0d,-10.5d),(20.0d,20.5d),(30.0d,30.5d)}') LIMIT 10;
885893

894+
-
895+
- SELECT TOP 10 LOG10(ra), LOG(dec) FROM db.tab WHERE INTERSECTS(POINT('ICRS', ra, dec), POLYGON('ICRS', 10.0, -10.5, 20.0, 20.5, 30.0, 30.5)) = 0;
896+
- SELECT LOG(ra), LN(dec) FROM db.tab WHERE spoint(RADIANS(ra), RADIANS(dec)) !&& spoly('{(10.0d,-10.5d),(20.0d,20.5d),(30.0d,30.5d)}') LIMIT 10;
886897

887898

888899
# Each test below consists of:

0 commit comments

Comments
 (0)