Skip to content

Commit aef9bad

Browse files
Various additional query builder fixes
1 parent 37f732e commit aef9bad

File tree

6 files changed

+156
-62
lines changed

6 files changed

+156
-62
lines changed

elasticsearch/esql/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
from ..dsl import E # noqa: F401
1819
from .esql import ESQL, and_, not_, or_ # noqa: F401

elasticsearch/esql/esql.py

Lines changed: 85 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717

1818
import json
19+
import re
1920
from abc import ABC, abstractmethod
2021
from typing import Any, Dict, Optional, Tuple, Type, Union
2122

@@ -111,6 +112,29 @@ def render(self) -> str:
111112
def _render_internal(self) -> str:
112113
pass
113114

115+
@staticmethod
def _format_index(index: IndexType) -> str:
    """Return the text used for an index in a rendered query.

    :param index: The index to render. When the object exposes an ``_index``
                  attribute, the ``_name`` of that attribute is used; any
                  other value is rendered with ``str()``.
    :return: The index name as a string.
    """
    if hasattr(index, "_index"):
        return index._index._name
    return str(index)
118+
119+
@staticmethod
def _format_id(id: FieldType, allow_patterns: bool = False) -> str:
    """Render a field or column identifier for inclusion in a query.

    :param id: The identifier to render. It is converted with ``str()``
               first, so objects such as ``InstrumentedField`` are accepted.
    :param allow_patterns: When ``True``, an identifier containing ``*`` is
                           returned verbatim, because patterns cannot be
                           escaped.
    :return: The identifier unchanged when it starts with a letter, ``_`` or
             ``@`` and continues with letters, digits, ``_`` or ``.`` only;
             otherwise the identifier wrapped in backticks, with any embedded
             backtick doubled.
    """
    s = str(id)  # in case it is an InstrumentedField
    if allow_patterns and "*" in s:
        return s  # patterns cannot be escaped
    if re.fullmatch(r"[a-zA-Z_@][a-zA-Z0-9_\.]*", s):
        return s
    # This identifier needs to be escaped. str.replace() returns a new string
    # rather than modifying ``s``, so the result must be kept; otherwise an
    # embedded backtick would terminate the quoted identifier early.
    escaped = s.replace("`", "``")
    return f"`{escaped}`"
129+
130+
@staticmethod
def _format_expr(expr: ExpressionType) -> str:
    """Render an expression for inclusion in a query.

    :param expr: The expression to render.
    :return: ``str(expr)`` when the expression is a ``str`` or an
             ``InstrumentedExpression``; for any other value, its JSON
             encoding as produced by ``json.dumps()``.
    """
    if isinstance(expr, (str, InstrumentedExpression)):
        return str(expr)
    return json.dumps(expr)
137+
114138
def _is_forked(self) -> bool:
115139
if self.__class__.__name__ == "Fork":
116140
return True
@@ -427,7 +451,7 @@ def sample(self, probability: float) -> "Sample":
427451
"""
428452
return Sample(self, probability)
429453

430-
def sort(self, *columns: FieldType) -> "Sort":
454+
def sort(self, *columns: ExpressionType) -> "Sort":
431455
"""The ``SORT`` processing command sorts a table on one or more columns.
432456
433457
:param columns: The columns to sort on.
@@ -570,15 +594,12 @@ def metadata(self, *fields: FieldType) -> "From":
570594
return self
571595

572596
def _render_internal(self) -> str:
573-
indices = [
574-
index if isinstance(index, str) else index._index._name
575-
for index in self._indices
576-
]
597+
indices = [self._format_index(index) for index in self._indices]
577598
s = f'{self.__class__.__name__.upper()} {", ".join(indices)}'
578599
if self._metadata_fields:
579600
s = (
580601
s
581-
+ f' METADATA {", ".join([str(field) for field in self._metadata_fields])}'
602+
+ f' METADATA {", ".join([self._format_id(field) for field in self._metadata_fields])}'
582603
)
583604
return s
584605

@@ -594,7 +615,11 @@ class Row(ESQLBase):
594615
def __init__(self, **params: ExpressionType):
595616
super().__init__()
596617
self._params = {
597-
k: json.dumps(v) if not isinstance(v, InstrumentedExpression) else v
618+
self._format_id(k): (
619+
json.dumps(v)
620+
if not isinstance(v, InstrumentedExpression)
621+
else self._format_expr(v)
622+
)
598623
for k, v in params.items()
599624
}
600625

@@ -615,7 +640,7 @@ def __init__(self, item: str):
615640
self._item = item
616641

617642
def _render_internal(self) -> str:
618-
return f"SHOW {self._item}"
643+
return f"SHOW {self._format_id(self._item)}"
619644

620645

621646
class Branch(ESQLBase):
@@ -667,11 +692,11 @@ def as_(self, type_name: str, pvalue_name: str) -> "ChangePoint":
667692
return self
668693

669694
def _render_internal(self) -> str:
    """Render the ``CHANGE_POINT`` command.

    :return: ``CHANGE_POINT <value>``, followed by ``ON <key>`` when a key is
             set, and by ``AS <type_name>, <pvalue_name>`` when at least one
             of the two output names is set. An output name that is not set
             falls back to ``type`` or ``pvalue`` respectively.
    """
    key = "" if not self._key else f" ON {self._format_id(self._key)}"
    # The defaults are applied before formatting: _format_id() always returns
    # a non-empty string (``None`` becomes "None", "" becomes two backticks),
    # so an ``or`` placed after the call could never select the default.
    names = (
        ""
        if not self._type_name and not self._pvalue_name
        else f' AS {self._format_id(self._type_name or "type")}, {self._format_id(self._pvalue_name or "pvalue")}'
    )
    return f"CHANGE_POINT {self._value}{key}{names}"
677702

@@ -709,12 +734,13 @@ def with_(self, inference_id: str) -> "Completion":
709734
def _render_internal(self) -> str:
710735
if self._inference_id is None:
711736
raise ValueError("The completion command requires an inference ID")
737+
with_ = {"inference_id": self._inference_id}
712738
if self._named_prompt:
713739
column = list(self._named_prompt.keys())[0]
714740
prompt = list(self._named_prompt.values())[0]
715-
return f"COMPLETION {column} = {prompt} WITH {self._inference_id}"
741+
return f"COMPLETION {self._format_id(column)} = {self._format_id(prompt)} WITH {json.dumps(with_)}"
716742
else:
717-
return f"COMPLETION {self._prompt[0]} WITH {self._inference_id}"
743+
return f"COMPLETION {self._format_id(self._prompt[0])} WITH {json.dumps(with_)}"
718744

719745

720746
class Dissect(ESQLBase):
@@ -742,9 +768,13 @@ def append_separator(self, separator: str) -> "Dissect":
742768

743769
def _render_internal(self) -> str:
744770
sep = (
745-
"" if self._separator is None else f' APPEND_SEPARATOR="{self._separator}"'
771+
""
772+
if self._separator is None
773+
else f" APPEND_SEPARATOR={json.dumps(self._separator)}"
774+
)
775+
return (
776+
f"DISSECT {self._format_id(self._input)} {json.dumps(self._pattern)}{sep}"
746777
)
747-
return f"DISSECT {self._input} {json.dumps(self._pattern)}{sep}"
748778

749779

750780
class Drop(ESQLBase):
@@ -760,7 +790,7 @@ def __init__(self, parent: ESQLBase, *columns: FieldType):
760790
self._columns = columns
761791

762792
def _render_internal(self) -> str:
763-
return f'DROP {", ".join([str(col) for col in self._columns])}'
793+
return f'DROP {", ".join([self._format_id(col, allow_patterns=True) for col in self._columns])}'
764794

765795

766796
class Enrich(ESQLBase):
@@ -814,12 +844,18 @@ def with_(self, *fields: FieldType, **named_fields: FieldType) -> "Enrich":
814844
return self
815845

816846
def _render_internal(self) -> str:
817-
on = "" if self._match_field is None else f" ON {self._match_field}"
847+
on = (
848+
""
849+
if self._match_field is None
850+
else f" ON {self._format_id(self._match_field)}"
851+
)
818852
with_ = ""
819853
if self._named_fields:
820-
with_ = f' WITH {", ".join([f"{name} = {field}" for name, field in self._named_fields.items()])}'
854+
with_ = f' WITH {", ".join([f"{self._format_id(name)} = {self._format_id(field)}" for name, field in self._named_fields.items()])}'
821855
elif self._fields is not None:
822-
with_ = f' WITH {", ".join([str(field) for field in self._fields])}'
856+
with_ = (
857+
f' WITH {", ".join([self._format_id(field) for field in self._fields])}'
858+
)
823859
return f"ENRICH {self._policy}{on}{with_}"
824860

825861

@@ -832,7 +868,10 @@ class Eval(ESQLBase):
832868
"""
833869

834870
def __init__(
835-
self, parent: ESQLBase, *columns: FieldType, **named_columns: FieldType
871+
self,
872+
parent: ESQLBase,
873+
*columns: ExpressionType,
874+
**named_columns: ExpressionType,
836875
):
837876
if columns and named_columns:
838877
raise ValueError(
@@ -844,10 +883,13 @@ def __init__(
844883
def _render_internal(self) -> str:
845884
if isinstance(self._columns, dict):
846885
cols = ", ".join(
847-
[f"{name} = {value}" for name, value in self._columns.items()]
886+
[
887+
f"{self._format_id(name)} = {self._format_expr(value)}"
888+
for name, value in self._columns.items()
889+
]
848890
)
849891
else:
850-
cols = ", ".join([f"{col}" for col in self._columns])
892+
cols = ", ".join([f"{self._format_expr(col)}" for col in self._columns])
851893
return f"EVAL {cols}"
852894

853895

@@ -900,7 +942,7 @@ def __init__(self, parent: ESQLBase, input: FieldType, pattern: str):
900942
self._pattern = pattern
901943

902944
def _render_internal(self) -> str:
903-
return f"GROK {self._input} {json.dumps(self._pattern)}"
945+
return f"GROK {self._format_id(self._input)} {json.dumps(self._pattern)}"
904946

905947

906948
class Keep(ESQLBase):
@@ -916,7 +958,7 @@ def __init__(self, parent: ESQLBase, *columns: FieldType):
916958
self._columns = columns
917959

918960
def _render_internal(self) -> str:
919-
return f'KEEP {", ".join([f"{col}" for col in self._columns])}'
961+
return f'KEEP {", ".join([f"{self._format_id(col, allow_patterns=True)}" for col in self._columns])}'
920962

921963

922964
class Limit(ESQLBase):
@@ -932,7 +974,7 @@ def __init__(self, parent: ESQLBase, max_number_of_rows: int):
932974
self._max_number_of_rows = max_number_of_rows
933975

934976
def _render_internal(self) -> str:
935-
return f"LIMIT {self._max_number_of_rows}"
977+
return f"LIMIT {json.dumps(self._max_number_of_rows)}"
936978

937979

938980
class LookupJoin(ESQLBase):
@@ -967,7 +1009,9 @@ def _render_internal(self) -> str:
9671009
if isinstance(self._lookup_index, str)
9681010
else self._lookup_index._index._name
9691011
)
970-
return f"LOOKUP JOIN {index} ON {self._field}"
1012+
return (
1013+
f"LOOKUP JOIN {self._format_index(index)} ON {self._format_id(self._field)}"
1014+
)
9711015

9721016

9731017
class MvExpand(ESQLBase):
@@ -983,7 +1027,7 @@ def __init__(self, parent: ESQLBase, column: FieldType):
9831027
self._column = column
9841028

9851029
def _render_internal(self) -> str:
986-
return f"MV_EXPAND {self._column}"
1030+
return f"MV_EXPAND {self._format_id(self._column)}"
9871031

9881032

9891033
class Rename(ESQLBase):
@@ -999,7 +1043,7 @@ def __init__(self, parent: ESQLBase, **columns: FieldType):
9991043
self._columns = columns
10001044

10011045
def _render_internal(self) -> str:
1002-
return f'RENAME {", ".join([f"{old_name} AS {new_name}" for old_name, new_name in self._columns.items()])}'
1046+
return f'RENAME {", ".join([f"{self._format_id(old_name)} AS {self._format_id(new_name)}" for old_name, new_name in self._columns.items()])}'
10031047

10041048

10051049
class Sample(ESQLBase):
@@ -1015,7 +1059,7 @@ def __init__(self, parent: ESQLBase, probability: float):
10151059
self._probability = probability
10161060

10171061
def _render_internal(self) -> str:
1018-
return f"SAMPLE {self._probability}"
1062+
return f"SAMPLE {json.dumps(self._probability)}"
10191063

10201064

10211065
class Sort(ESQLBase):
@@ -1026,12 +1070,16 @@ class Sort(ESQLBase):
10261070
in a single expression.
10271071
"""
10281072

1029-
def __init__(self, parent: ESQLBase, *columns: FieldType):
1073+
def __init__(self, parent: ESQLBase, *columns: ExpressionType):
10301074
super().__init__(parent)
10311075
self._columns = columns
10321076

10331077
def _render_internal(self) -> str:
1034-
return f'SORT {", ".join([f"{col}" for col in self._columns])}'
1078+
sorts = [
1079+
" ".join([self._format_id(term) for term in str(col).split(" ")])
1080+
for col in self._columns
1081+
]
1082+
return f'SORT {", ".join([f"{sort}" for sort in sorts])}'
10351083

10361084

10371085
class Stats(ESQLBase):
@@ -1062,14 +1110,17 @@ def by(self, *grouping_expressions: ExpressionType) -> "Stats":
10621110

10631111
def _render_internal(self) -> str:
10641112
if isinstance(self._expressions, dict):
1065-
exprs = [f"{key} = {value}" for key, value in self._expressions.items()]
1113+
exprs = [
1114+
f"{self._format_id(key)} = {self._format_expr(value)}"
1115+
for key, value in self._expressions.items()
1116+
]
10661117
else:
1067-
exprs = [f"{expr}" for expr in self._expressions]
1118+
exprs = [f"{self._format_expr(expr)}" for expr in self._expressions]
10681119
expression_separator = ",\n "
10691120
by = (
10701121
""
10711122
if self._grouping_expressions is None
1072-
else f'\n BY {", ".join([f"{expr}" for expr in self._grouping_expressions])}'
1123+
else f'\n BY {", ".join([f"{self._format_expr(expr)}" for expr in self._grouping_expressions])}'
10731124
)
10741125
return f'STATS {expression_separator.join([f"{expr}" for expr in exprs])}{by}'
10751126

@@ -1087,7 +1138,7 @@ def __init__(self, parent: ESQLBase, *expressions: ExpressionType):
10871138
self._expressions = expressions
10881139

10891140
def _render_internal(self) -> str:
1090-
return f'WHERE {" AND ".join([f"{expr}" for expr in self._expressions])}'
1141+
return f'WHERE {" AND ".join([f"{self._format_expr(expr)}" for expr in self._expressions])}'
10911142

10921143

10931144
def and_(*expressions: InstrumentedExpression) -> "InstrumentedExpression":

0 commit comments

Comments
 (0)