Skip to content

Commit 3f630cd

Browse files
author
Peng Ren
committed
Added classic insert statement with column list and works in SQLAlchemy
1 parent a5231cb commit 3f630cd

13 files changed

+1125
-171
lines changed

pymongosql/sql/ast.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,10 @@ def visitInsertStatement(self, ctx: PartiQLParser.InsertStatementContext) -> Any
137137
self._insert_parse_result = InsertParseResult.for_visitor()
138138
handler = self._handlers.get("insert")
139139
if handler:
140-
return handler.handle_visitor(ctx, self._insert_parse_result)
140+
handler.handle_visitor(ctx, self._insert_parse_result)
141+
# Continue visiting children to process columnList and values
142+
self.visitChildren(ctx)
143+
return self._insert_parse_result
141144
return self.visitChildren(ctx)
142145

143146
def visitInsertStatementLegacy(self, ctx: PartiQLParser.InsertStatementLegacyContext) -> Any:
@@ -151,6 +154,22 @@ def visitInsertStatementLegacy(self, ctx: PartiQLParser.InsertStatementLegacyCon
151154
return handler.handle_visitor(ctx, self._insert_parse_result)
152155
return self.visitChildren(ctx)
153156

157+
def visitColumnList(self, ctx: PartiQLParser.ColumnListContext) -> Any:
158+
"""Handle column list in INSERT statements."""
159+
if self._current_operation == "insert":
160+
handler = self._handlers.get("insert")
161+
if handler:
162+
return handler.handle_column_list(ctx, self._insert_parse_result)
163+
return self.visitChildren(ctx)
164+
165+
def visitValues(self, ctx: PartiQLParser.ValuesContext) -> Any:
166+
"""Handle VALUES clause in INSERT statements."""
167+
if self._current_operation == "insert":
168+
handler = self._handlers.get("insert")
169+
if handler:
170+
return handler.handle_values(ctx, self._insert_parse_result)
171+
return self.visitChildren(ctx)
172+
154173
def visitFromClauseSimpleExplicit(self, ctx: PartiQLParser.FromClauseSimpleExplicitContext) -> Any:
155174
"""Handle FROM clause (explicit form) in DELETE statements."""
156175
if self._current_operation == "delete":

pymongosql/sql/insert_handler.py

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,17 @@ def can_handle(self, ctx: Any) -> bool:
3939
def handle_visitor(self, ctx: Any, parse_result: InsertParseResult) -> InsertParseResult:
4040
try:
4141
collection = self._extract_collection(ctx)
42-
value_text = self._extract_value_text(ctx)
4342

43+
# Check if this is a VALUES clause INSERT (new syntax)
44+
if hasattr(ctx, "values") and ctx.values():
45+
_logger.debug("Processing INSERT with VALUES clause")
46+
parse_result.collection = collection
47+
parse_result.insert_type = "values"
48+
# Return parse_result - visitor will call handle_column_list and handle_values
49+
return parse_result
50+
51+
# Otherwise, handle legacy value expression INSERT
52+
value_text = self._extract_value_text(ctx)
4453
documents = self._parse_value_expr(value_text)
4554
param_style, param_count = self._detect_parameter_style(documents)
4655

@@ -58,6 +67,137 @@ def handle_visitor(self, ctx: Any, parse_result: InsertParseResult) -> InsertPar
5867
parse_result.error_message = str(exc)
5968
return parse_result
6069

70+
def handle_column_list(self, ctx: Any, parse_result: InsertParseResult) -> Optional[List[str]]:
71+
"""Extract column names from columnList context."""
72+
_logger.debug("InsertHandler processing column list")
73+
try:
74+
columns = []
75+
column_names = ctx.columnName()
76+
if column_names:
77+
if not isinstance(column_names, list):
78+
column_names = [column_names]
79+
for col_name_ctx in column_names:
80+
# columnName contains a symbolPrimitive
81+
symbol = col_name_ctx.symbolPrimitive()
82+
if symbol:
83+
columns.append(symbol.getText())
84+
parse_result.insert_columns = columns
85+
_logger.debug(f"Extracted columns for INSERT: {columns}")
86+
return columns
87+
except Exception as e:
88+
_logger.warning(f"Error processing column list: {e}")
89+
parse_result.has_errors = True
90+
parse_result.error_message = str(e)
91+
return None
92+
93+
def handle_values(self, ctx: Any, parse_result: InsertParseResult) -> Optional[List[List[Any]]]:
94+
"""Extract value rows from VALUES clause."""
95+
_logger.debug("InsertHandler processing VALUES clause")
96+
try:
97+
rows = []
98+
value_rows = ctx.valueRow()
99+
if value_rows:
100+
if not isinstance(value_rows, list):
101+
value_rows = [value_rows]
102+
for value_row_ctx in value_rows:
103+
row_values = self._extract_value_row(value_row_ctx)
104+
rows.append(row_values)
105+
106+
parse_result.insert_values = rows
107+
108+
# Convert rows to documents
109+
columns = parse_result.insert_columns
110+
documents = self._convert_rows_to_documents(columns, rows)
111+
parse_result.insert_documents = documents
112+
113+
# Detect parameter style
114+
param_style, param_count = self._detect_parameter_style(documents)
115+
parse_result.parameter_style = param_style
116+
parse_result.parameter_count = param_count
117+
parse_result.has_errors = False
118+
parse_result.error_message = None
119+
120+
_logger.debug(f"Extracted {len(rows)} value rows for INSERT")
121+
return rows
122+
except Exception as e:
123+
_logger.warning(f"Error processing VALUES clause: {e}")
124+
parse_result.has_errors = True
125+
parse_result.error_message = str(e)
126+
return None
127+
128+
def _extract_value_row(self, value_row_ctx: Any) -> List[Any]:
129+
"""Extract values from a single valueRow."""
130+
row_values = []
131+
exprs = value_row_ctx.expr()
132+
if exprs:
133+
if not isinstance(exprs, list):
134+
exprs = [exprs]
135+
for expr_ctx in exprs:
136+
value = self._parse_expression_value(expr_ctx)
137+
row_values.append(value)
138+
return row_values
139+
140+
def _parse_expression_value(self, expr_ctx: Any) -> Any:
141+
"""Parse a single expression value from the parse tree."""
142+
if not expr_ctx:
143+
return None
144+
145+
text = expr_ctx.getText()
146+
147+
# Handle NULL
148+
if text.upper() == "NULL":
149+
return None
150+
151+
# Handle boolean literals
152+
if text.upper() == "TRUE":
153+
return True
154+
if text.upper() == "FALSE":
155+
return False
156+
157+
# Handle string literals (quoted)
158+
if (text.startswith("'") and text.endswith("'")) or (text.startswith('"') and text.endswith('"')):
159+
return text[1:-1]
160+
161+
# Handle numeric literals
162+
try:
163+
if "." in text:
164+
return float(text)
165+
return int(text)
166+
except ValueError:
167+
pass
168+
169+
# Handle parameters (? or :name)
170+
if text == "?":
171+
return "?"
172+
if text.startswith(":"):
173+
return text
174+
175+
# For complex expressions, return as-is
176+
return text
177+
178+
def _convert_rows_to_documents(self, columns: Optional[List[str]], rows: List[List[Any]]) -> List[Dict[str, Any]]:
179+
"""Convert rows to MongoDB documents."""
180+
documents = []
181+
182+
for row in rows:
183+
doc = {}
184+
185+
if columns:
186+
# Use explicit column names
187+
if len(row) != len(columns):
188+
raise ValueError(f"Column count ({len(columns)}) does not match value count ({len(row)})")
189+
190+
for col, val in zip(columns, row):
191+
doc[col] = val
192+
else:
193+
# Generate automatic column names (col0, col1, ...)
194+
for idx, val in enumerate(row):
195+
doc[f"col{idx}"] = val
196+
197+
documents.append(doc)
198+
199+
return documents
200+
61201
def _extract_collection(self, ctx: Any) -> str:
62202
if hasattr(ctx, "symbolPrimitive") and ctx.symbolPrimitive():
63203
return ctx.symbolPrimitive().getText()

pymongosql/sql/partiql/PartiQLLexer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Generated from PartiQLLexer.g4 by ANTLR 4.13.2
1+
# Generated from PartiQLLexer.g4 by ANTLR 4.13.0
22
from antlr4 import *
33
from io import StringIO
44
import sys
@@ -1678,7 +1678,7 @@ class PartiQLLexer(Lexer):
16781678

16791679
def __init__(self, input=None, output:TextIO = sys.stdout):
16801680
super().__init__(input, output)
1681-
self.checkVersion("4.13.2")
1681+
self.checkVersion("4.13.0")
16821682
self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache())
16831683
self._actions = None
16841684
self._predicates = None

pymongosql/sql/partiql/PartiQLParser.g4

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,8 @@ insertCommandReturning
166166

167167
// See the Grammar at https://github.com/partiql/partiql-docs/blob/main/RFCs/0011-partiql-insert.md#2-proposed-grammar-and-semantics
168168
insertStatement
169-
: INSERT INTO symbolPrimitive asIdent? value=expr onConflict?
170-
| INSERT INTO symbolPrimitive columnList? values onConflict?
169+
: INSERT INTO symbolPrimitive columnList? values onConflict?
170+
| INSERT INTO symbolPrimitive asIdent? value=expr onConflict?
171171
;
172172

173173
columnList

pymongosql/sql/partiql/PartiQLParser.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Generated from PartiQLParser.g4 by ANTLR 4.13.2
1+
# Generated from PartiQLParser.g4 by ANTLR 4.13.0
22
# encoding: utf-8
33
from antlr4 import *
44
from io import StringIO
@@ -303,21 +303,21 @@ def serializedATN():
303303
513,511,1,0,0,0,513,514,1,0,0,0,514,516,1,0,0,0,515,517,3,66,33,
304304
0,516,515,1,0,0,0,516,517,1,0,0,0,517,519,1,0,0,0,518,520,3,86,43,
305305
0,519,518,1,0,0,0,519,520,1,0,0,0,520,57,1,0,0,0,521,522,5,112,0,
306-
0,522,523,5,117,0,0,523,525,3,12,6,0,524,526,3,6,3,0,525,524,1,0,
307-
0,0,525,526,1,0,0,0,526,527,1,0,0,0,527,529,3,188,94,0,528,530,3,
308-
62,31,0,529,528,1,0,0,0,529,530,1,0,0,0,530,542,1,0,0,0,531,532,
309-
5,112,0,0,532,533,5,117,0,0,533,535,3,12,6,0,534,536,3,60,30,0,535,
310-
534,1,0,0,0,535,536,1,0,0,0,536,537,1,0,0,0,537,539,3,220,110,0,
311-
538,540,3,62,31,0,539,538,1,0,0,0,539,540,1,0,0,0,540,542,1,0,0,
312-
0,541,521,1,0,0,0,541,531,1,0,0,0,542,59,1,0,0,0,543,544,5,294,0,
313-
0,544,549,3,24,12,0,545,546,5,270,0,0,546,548,3,24,12,0,547,545,
314-
1,0,0,0,548,551,1,0,0,0,549,547,1,0,0,0,549,550,1,0,0,0,550,552,
315-
1,0,0,0,551,549,1,0,0,0,552,553,5,295,0,0,553,61,1,0,0,0,554,555,
316-
5,147,0,0,555,557,5,244,0,0,556,558,3,68,34,0,557,556,1,0,0,0,557,
317-
558,1,0,0,0,558,559,1,0,0,0,559,560,3,72,36,0,560,63,1,0,0,0,561,
318-
562,5,112,0,0,562,563,5,117,0,0,563,564,3,46,23,0,564,565,5,218,
319-
0,0,565,568,3,188,94,0,566,567,5,13,0,0,567,569,3,188,94,0,568,566,
320-
1,0,0,0,568,569,1,0,0,0,569,571,1,0,0,0,570,572,3,66,33,0,571,570,
306+
0,522,523,5,117,0,0,523,525,3,12,6,0,524,526,3,60,30,0,525,524,1,
307+
0,0,0,525,526,1,0,0,0,526,527,1,0,0,0,527,529,3,220,110,0,528,530,
308+
3,62,31,0,529,528,1,0,0,0,529,530,1,0,0,0,530,542,1,0,0,0,531,532,
309+
5,112,0,0,532,533,5,117,0,0,533,535,3,12,6,0,534,536,3,6,3,0,535,
310+
534,1,0,0,0,535,536,1,0,0,0,536,537,1,0,0,0,537,539,3,188,94,0,538,
311+
540,3,62,31,0,539,538,1,0,0,0,539,540,1,0,0,0,540,542,1,0,0,0,541,
312+
521,1,0,0,0,541,531,1,0,0,0,542,59,1,0,0,0,543,544,5,294,0,0,544,
313+
549,3,24,12,0,545,546,5,270,0,0,546,548,3,24,12,0,547,545,1,0,0,
314+
0,548,551,1,0,0,0,549,547,1,0,0,0,549,550,1,0,0,0,550,552,1,0,0,
315+
0,551,549,1,0,0,0,552,553,5,295,0,0,553,61,1,0,0,0,554,555,5,147,
316+
0,0,555,557,5,244,0,0,556,558,3,68,34,0,557,556,1,0,0,0,557,558,
317+
1,0,0,0,558,559,1,0,0,0,559,560,3,72,36,0,560,63,1,0,0,0,561,562,
318+
5,112,0,0,562,563,5,117,0,0,563,564,3,46,23,0,564,565,5,218,0,0,
319+
565,568,3,188,94,0,566,567,5,13,0,0,567,569,3,188,94,0,568,566,1,
320+
0,0,0,568,569,1,0,0,0,569,571,1,0,0,0,570,572,3,66,33,0,571,570,
321321
1,0,0,0,571,572,1,0,0,0,572,65,1,0,0,0,573,574,5,147,0,0,574,575,
322322
5,244,0,0,575,576,5,225,0,0,576,577,3,188,94,0,577,578,5,245,0,0,
323323
578,579,5,250,0,0,579,67,1,0,0,0,580,581,5,294,0,0,581,586,3,12,
@@ -1384,7 +1384,7 @@ class PartiQLParser ( Parser ):
13841384

13851385
def __init__(self, input:TokenStream, output:TextIO = sys.stdout):
13861386
super().__init__(input, output)
1387-
self.checkVersion("4.13.2")
1387+
self.checkVersion("4.13.0")
13881388
self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache)
13891389
self._predicates = None
13901390

@@ -3969,24 +3969,24 @@ def symbolPrimitive(self):
39693969
return self.getTypedRuleContext(PartiQLParser.SymbolPrimitiveContext,0)
39703970

39713971

3972-
def expr(self):
3973-
return self.getTypedRuleContext(PartiQLParser.ExprContext,0)
3972+
def values(self):
3973+
return self.getTypedRuleContext(PartiQLParser.ValuesContext,0)
39743974

39753975

3976-
def asIdent(self):
3977-
return self.getTypedRuleContext(PartiQLParser.AsIdentContext,0)
3976+
def columnList(self):
3977+
return self.getTypedRuleContext(PartiQLParser.ColumnListContext,0)
39783978

39793979

39803980
def onConflict(self):
39813981
return self.getTypedRuleContext(PartiQLParser.OnConflictContext,0)
39823982

39833983

3984-
def values(self):
3985-
return self.getTypedRuleContext(PartiQLParser.ValuesContext,0)
3984+
def expr(self):
3985+
return self.getTypedRuleContext(PartiQLParser.ExprContext,0)
39863986

39873987

3988-
def columnList(self):
3989-
return self.getTypedRuleContext(PartiQLParser.ColumnListContext,0)
3988+
def asIdent(self):
3989+
return self.getTypedRuleContext(PartiQLParser.AsIdentContext,0)
39903990

39913991

39923992
def getRuleIndex(self):
@@ -4029,13 +4029,13 @@ def insertStatement(self):
40294029
self.state = 525
40304030
self._errHandler.sync(self)
40314031
_la = self._input.LA(1)
4032-
if _la==10:
4032+
if _la==294:
40334033
self.state = 524
4034-
self.asIdent()
4034+
self.columnList()
40354035

40364036

40374037
self.state = 527
4038-
localctx.value = self.expr()
4038+
self.values()
40394039
self.state = 529
40404040
self._errHandler.sync(self)
40414041
_la = self._input.LA(1)
@@ -4057,13 +4057,13 @@ def insertStatement(self):
40574057
self.state = 535
40584058
self._errHandler.sync(self)
40594059
_la = self._input.LA(1)
4060-
if _la==294:
4060+
if _la==10:
40614061
self.state = 534
4062-
self.columnList()
4062+
self.asIdent()
40634063

40644064

40654065
self.state = 537
4066-
self.values()
4066+
localctx.value = self.expr()
40674067
self.state = 539
40684068
self._errHandler.sync(self)
40694069
_la = self._input.LA(1)

pymongosql/sql/partiql/PartiQLParserListener.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Generated from PartiQLParser.g4 by ANTLR 4.13.2
1+
# Generated from PartiQLParser.g4 by ANTLR 4.13.0
22
from antlr4 import *
33
if "." in __name__:
44
from .PartiQLParser import PartiQLParser

pymongosql/sql/partiql/PartiQLParserVisitor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Generated from PartiQLParser.g4 by ANTLR 4.13.2
1+
# Generated from PartiQLParser.g4 by ANTLR 4.13.0
22
from antlr4 import *
33
if "." in __name__:
44
from .PartiQLParser import PartiQLParser

0 commit comments

Comments
 (0)