Skip to content

Commit 503856c

Browse files
authored
Merge f422b01 into 97167eb
2 parents 97167eb + f422b01 commit 503856c

File tree

2 files changed

+75
-99
lines changed

2 files changed

+75
-99
lines changed

ydb/tests/functional/suite_tests/test_base.py

Lines changed: 73 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import itertools
33
import json
44
import abc
5-
import collections
65
import os
76
import random
87
import string
@@ -41,21 +40,53 @@ def mute_sdk_loggers():
4140
mute_sdk_loggers()
4241

4342

44-
@enum.unique
45-
class StatementTypes(enum.Enum):
46-
Skipped = 'statement skipped'
47-
Ok = 'statement ok'
48-
Error = 'statement error'
49-
Query = 'statement query'
50-
StreamQuery = 'statement stream query'
51-
ImportTableData = 'statement import table data'
43+
class StatementDefinition:
44+
@enum.unique
45+
class Type(enum.Enum):
46+
Skipped = 'statement skipped'
47+
Ok = 'statement ok'
48+
Error = 'statement error'
49+
Query = 'statement query'
50+
StreamQuery = 'statement stream query'
51+
ImportTableData = 'statement import table data'
5252

53+
def __init__(self, suite: str, at_line: int, type: Type, text: [str]):
54+
self.suite_name = suite
55+
self.at_line = at_line
56+
self.s_type = type
57+
self.text = text
5358

54-
def get_statement_type(line):
55-
for s_type in list(StatementTypes):
56-
if s_type.value in line.lower():
57-
return s_type
58-
raise RuntimeError("Can't find statement type for line %s" % line)
59+
def __str__(self):
60+
return f'''StatementDefinition:
61+
suite: {self.suite_name}
62+
line: {self.at_line}
63+
type: {self.s_type}
64+
text:
65+
''' + '\n'.join([f' {row}' for row in self.text.split('\n')])
66+
67+
@staticmethod
68+
def _parse_statement_type(statement_line: str) -> Type:
69+
for t in list(StatementDefinition.Type):
70+
if t.value in statement_line.lower():
71+
return t
72+
return None
73+
74+
@staticmethod
75+
def parse(suite: str, at_line: int, lines: list[str]):
76+
if not lines or not lines[0]:
77+
raise RuntimeError(f'Invalid statement in {suite}, at line: {at_line}')
78+
type = StatementDefinition._parse_statement_type(lines[0])
79+
if type is None:
80+
raise RuntimeError(f'Unknown statement type in {suite}, at line: {at_line}')
81+
lines.pop(0)
82+
at_line += 1
83+
statement_lines = []
84+
for line in lines:
85+
if line.startswith('side effect: '): # side effects are not supported yet
86+
pass
87+
else:
88+
statement_lines.append(line)
89+
return StatementDefinition(suite, at_line, type, "\n".join(statement_lines))
5990

6091

6192
def get_token(length=10):
@@ -67,12 +98,6 @@ def get_source_path(*args):
6798
return os.path.join(arcadia_root, test_source_path(os.path.join(*args)))
6899

69100

70-
def is_empty_line(line):
71-
if line.split():
72-
return False
73-
return True
74-
75-
76101
def get_lines(suite_path):
77102
with open(suite_path) as reader:
78103
for line_idx, line in enumerate(reader.readlines()):
@@ -97,79 +122,31 @@ def get_test_suites(directory):
97122
return suites
98123

99124

100-
def get_single_statement(lines):
125+
def split_by_statement(lines):
101126
statement_lines = []
127+
statement_start_line_idx = 0
102128
for line_idx, line in lines:
103-
if is_empty_line(line):
104-
statement = "\n".join(statement_lines)
105-
return statement
106-
statement_lines.append(line)
107-
return "\n".join(statement_lines)
108-
109-
110-
class ParsedStatement(collections.namedtuple('ParsedStatement', ["at_line", "s_type", "suite_name", "text"])):
111-
def get_fields(self):
112-
return self._fields
113-
114-
def __str__(self):
115-
result = ["", "Parsed Statement"]
116-
for field in self.get_fields():
117-
value = str(getattr(self, field))
118-
if field != 'text':
119-
result.append(' ' * 4 + '%s: %s,' % (field, value))
120-
else:
121-
result.append(' ' * 4 + '%s:' % field)
122-
result.extend([' ' * 8 + row for row in value.split('\n')])
123-
return "\n".join(result)
129+
if line:
130+
if line.startswith("statement "):
131+
statement_start_line_idx = line_idx
132+
statement_lines = [line]
133+
elif statement_lines:
134+
statement_lines.append(line)
135+
else:
136+
if statement_lines:
137+
yield (statement_start_line_idx, statement_lines)
138+
statement_lines = []
139+
if statement_lines:
140+
yield (statement_start_line_idx, statement_lines)
124141

125142

126143
def get_statements(suite_path, suite_name):
127-
lines = get_lines(suite_path)
128-
for line_idx, line in lines:
129-
if is_empty_line(line) or not is_statement_definition(line):
130-
# empty line or junk lines
131-
continue
132-
text = get_single_statement(lines)
133-
yield ParsedStatement(
134-
line_idx,
135-
get_statement_type(line),
144+
for statement_start_line_idx, statement_lines in split_by_statement(get_lines(suite_path)):
145+
yield StatementDefinition.parse(
136146
suite_name,
137-
text)
138-
139-
140-
def is_side_effect(statement_line):
141-
return statement_line.startswith('side effect: ')
142-
143-
144-
def parse_side_effect(se_line):
145-
pieces = se_line.split(':')
146-
if len(pieces) < 3:
147-
raise RuntimeError("Invalid side effect description: %s" % se_line)
148-
se_type = pieces[1].strip()
149-
se_description = ':'.join(pieces[2:])
150-
se_description = se_description.strip()
151-
152-
return se_type, se_description
153-
154-
155-
def get_statement_and_side_effects(statement_text):
156-
statement_lines = statement_text.split('\n')
157-
side_effects = {}
158-
filtered = []
159-
for statement_line in statement_lines:
160-
if not is_side_effect(statement_line):
161-
filtered.append(statement_line)
162-
continue
163-
164-
se_type, se_description = parse_side_effect(statement_line)
165-
166-
side_effects[se_type] = se_description
167-
168-
return '\n'.join(filtered), side_effects
169-
170-
171-
def is_statement_definition(line):
172-
return line.startswith("statement")
147+
statement_start_line_idx,
148+
statement_lines,
149+
)
173150

174151

175152
def format_yql_statement(lines_or_statement, table_path_prefix):
@@ -308,13 +285,15 @@ def assert_statement_import_table_data(self, statement):
308285

309286
def assert_statement(self, parsed_statement):
310287
start_time = time.time()
288+
print(f"qqqw_executing: {parsed_statement.s_type}, {parsed_statement.text}")
289+
311290
from_type = {
312-
StatementTypes.Ok: self.assert_statement_ok,
313-
StatementTypes.Query: self.assert_statement_query,
314-
StatementTypes.StreamQuery: self.assert_statement_stream_query,
315-
StatementTypes.Error: (lambda x: x),
316-
StatementTypes.ImportTableData: self.assert_statement_import_table_data,
317-
StatementTypes.Skipped: lambda x: x
291+
StatementDefinition.Type.Ok: self.assert_statement_ok,
292+
StatementDefinition.Type.Query: self.assert_statement_query,
293+
StatementDefinition.Type.StreamQuery: self.assert_statement_stream_query,
294+
StatementDefinition.Type.Error: (lambda x: x),
295+
StatementDefinition.Type.ImportTableData: self.assert_statement_import_table_data,
296+
StatementDefinition.Type.Skipped: lambda x: x
318297
}
319298
assert_method = from_type.get(parsed_statement.s_type)
320299
assert_method(parsed_statement)
@@ -331,10 +310,8 @@ def assert_statement_ok(self, statement):
331310
)
332311

333312
def assert_statement_error(self, statement):
334-
# not supported yet
335-
statement_text, side_effects = get_statement_and_side_effects(statement.text)
336313
assert_that(
337-
lambda: self.execute_query(statement_text),
314+
lambda: self.execute_query(statement.text),
338315
raises(
339316
ydb.Error
340317
)

ydb/tests/functional/suite_tests/test_sql_logic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
from hamcrest import assert_that, raises
77

8-
from test_base import BaseSuiteRunner, get_token, get_test_suites, safe_execute, get_statement_and_side_effects
8+
from test_base import BaseSuiteRunner, get_token, get_test_suites, safe_execute
99

1010
"""
1111
This module is a specific runner of sqllogic tests. Test suites for this
@@ -38,8 +38,7 @@ def assert_statement_ok(self, statement):
3838
safe_execute(lambda: self.__execute_sqlitedb(statement.text), statement)
3939

4040
def assert_statement_error(self, statement):
41-
statement_text, side_effects = get_statement_and_side_effects(statement.text)
42-
assert_that(lambda: self.__execute_sqlitedb(statement_text), raises(sqlite3.Error), str(statement))
41+
assert_that(lambda: self.__execute_sqlitedb(statement.text), raises(sqlite3.Error), str(statement))
4342
super(TestSQLLogic, self).assert_statement_error(statement)
4443

4544
def get_query_and_output(self, statement_text):

0 commit comments

Comments
 (0)