Skip to content

Commit d10f443

Browse files
authored
Merge 06e1308 into 3e28916
2 parents 3e28916 + 06e1308 commit d10f443

File tree

2 files changed

+64
-99
lines changed

2 files changed

+64
-99
lines changed

ydb/tests/functional/suite_tests/test_base.py

Lines changed: 62 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import itertools
33
import json
44
import abc
5-
import collections
65
import os
76
import random
87
import string
@@ -41,21 +40,45 @@ def mute_sdk_loggers():
4140
mute_sdk_loggers()
4241

4342

44-
@enum.unique
45-
class StatementTypes(enum.Enum):
46-
Skipped = 'statement skipped'
47-
Ok = 'statement ok'
48-
Error = 'statement error'
49-
Query = 'statement query'
50-
StreamQuery = 'statement stream query'
51-
ImportTableData = 'statement import table data'
43+
class StatementDefinition:
44+
@enum.unique
45+
class Type(enum.Enum):
46+
Skipped = 'statement skipped'
47+
Ok = 'statement ok'
48+
Error = 'statement error'
49+
Query = 'statement query'
50+
StreamQuery = 'statement stream query'
51+
ImportTableData = 'statement import table data'
5252

53+
def __init__(self, suite: str, at_line: int, type: Type, text: [str]):
54+
self.suite_name = suite
55+
self.at_line = at_line
56+
self.s_type = type
57+
self.text = text
5358

54-
def get_statement_type(line):
55-
for s_type in list(StatementTypes):
56-
if s_type.value in line.lower():
57-
return s_type
58-
raise RuntimeError("Can't find statement type for line %s" % line)
59+
@staticmethod
60+
def _parse_statement_type(statement_line: str) -> Type:
61+
for t in list(StatementDefinition.Type):
62+
if t.value in statement_line.lower():
63+
return t
64+
return None
65+
66+
@staticmethod
67+
def parse(suite: str, at_line: int, lines: list[str]):
68+
if not lines or not lines[0]:
69+
raise RuntimeError(f'Invalid statement in {suite}, at line: {at_line}')
70+
type = StatementDefinition._parse_statement_type(lines[0])
71+
if type is None:
72+
raise RuntimeError(f'Unknown statement type in {suite}, at line: {at_line}')
73+
lines.pop(0)
74+
at_line += 1
75+
statement_lines = []
76+
for line in lines:
77+
if line.startswith('side effect: '): # side effects are not supported yet
78+
pass
79+
else:
80+
statement_lines.append(line)
81+
return StatementDefinition(suite, at_line, type, "\n".join(statement_lines))
5982

6083

6184
def get_token(length=10):
@@ -67,12 +90,6 @@ def get_source_path(*args):
6790
return os.path.join(arcadia_root, test_source_path(os.path.join(*args)))
6891

6992

70-
def is_empty_line(line):
71-
if line.split():
72-
return False
73-
return True
74-
75-
7693
def get_lines(suite_path):
7794
with open(suite_path) as reader:
7895
for line_idx, line in enumerate(reader.readlines()):
@@ -97,79 +114,28 @@ def get_test_suites(directory):
97114
return suites
98115

99116

100-
def get_single_statement(lines):
117+
def split_by_empty_lines(lines):
101118
statement_lines = []
119+
statement_start_line_idx = 0
102120
for line_idx, line in lines:
103-
if is_empty_line(line):
104-
statement = "\n".join(statement_lines)
105-
return statement
106-
statement_lines.append(line)
107-
return "\n".join(statement_lines)
108-
109-
110-
class ParsedStatement(collections.namedtuple('ParsedStatement', ["at_line", "s_type", "suite_name", "text"])):
111-
def get_fields(self):
112-
return self._fields
113-
114-
def __str__(self):
115-
result = ["", "Parsed Statement"]
116-
for field in self.get_fields():
117-
value = str(getattr(self, field))
118-
if field != 'text':
119-
result.append(' ' * 4 + '%s: %s,' % (field, value))
120-
else:
121-
result.append(' ' * 4 + '%s:' % field)
122-
result.extend([' ' * 8 + row for row in value.split('\n')])
123-
return "\n".join(result)
121+
if line:
122+
if not statement_lines:
123+
statement_start_line_idx = line_idx
124+
statement_lines.append(line)
125+
elif statement_lines:
126+
yield (statement_start_line_idx, statement_lines)
127+
statement_lines = []
128+
if statement_lines:
129+
yield (statement_start_line_idx, statement_lines)
124130

125131

126132
def get_statements(suite_path, suite_name):
127-
lines = get_lines(suite_path)
128-
for line_idx, line in lines:
129-
if is_empty_line(line) or not is_statement_definition(line):
130-
# empty line or junk lines
131-
continue
132-
text = get_single_statement(lines)
133-
yield ParsedStatement(
134-
line_idx,
135-
get_statement_type(line),
133+
for statement_start_line_idx, statement_lines in split_by_empty_lines(get_lines(suite_path)):
134+
yield StatementDefinition.parse(
136135
suite_name,
137-
text)
138-
139-
140-
def is_side_effect(statement_line):
141-
return statement_line.startswith('side effect: ')
142-
143-
144-
def parse_side_effect(se_line):
145-
pieces = se_line.split(':')
146-
if len(pieces) < 3:
147-
raise RuntimeError("Invalid side effect description: %s" % se_line)
148-
se_type = pieces[1].strip()
149-
se_description = ':'.join(pieces[2:])
150-
se_description = se_description.strip()
151-
152-
return se_type, se_description
153-
154-
155-
def get_statement_and_side_effects(statement_text):
156-
statement_lines = statement_text.split('\n')
157-
side_effects = {}
158-
filtered = []
159-
for statement_line in statement_lines:
160-
if not is_side_effect(statement_line):
161-
filtered.append(statement_line)
162-
continue
163-
164-
se_type, se_description = parse_side_effect(statement_line)
165-
166-
side_effects[se_type] = se_description
167-
168-
return '\n'.join(filtered), side_effects
169-
170-
171-
def is_statement_definition(line):
172-
return line.startswith("statement")
136+
statement_start_line_idx,
137+
statement_lines,
138+
)
173139

174140

175141
def format_yql_statement(lines_or_statement, table_path_prefix):
@@ -308,13 +274,15 @@ def assert_statement_import_table_data(self, statement):
308274

309275
def assert_statement(self, parsed_statement):
310276
start_time = time.time()
277+
print(f"qqqw_executing: {parsed_statement.s_type}, {parsed_statement.text}")
278+
311279
from_type = {
312-
StatementTypes.Ok: self.assert_statement_ok,
313-
StatementTypes.Query: self.assert_statement_query,
314-
StatementTypes.StreamQuery: self.assert_statement_stream_query,
315-
StatementTypes.Error: (lambda x: x),
316-
StatementTypes.ImportTableData: self.assert_statement_import_table_data,
317-
StatementTypes.Skipped: lambda x: x
280+
StatementDefinition.Type.Ok: self.assert_statement_ok,
281+
StatementDefinition.Type.Query: self.assert_statement_query,
282+
StatementDefinition.Type.StreamQuery: self.assert_statement_stream_query,
283+
StatementDefinition.Type.Error: (lambda x: x),
284+
StatementDefinition.Type.ImportTableData: self.assert_statement_import_table_data,
285+
StatementDefinition.Type.Skipped: lambda x: x
318286
}
319287
assert_method = from_type.get(parsed_statement.s_type)
320288
assert_method(parsed_statement)
@@ -331,10 +299,8 @@ def assert_statement_ok(self, statement):
331299
)
332300

333301
def assert_statement_error(self, statement):
334-
# not supported yet
335-
statement_text, side_effects = get_statement_and_side_effects(statement.text)
336302
assert_that(
337-
lambda: self.execute_query(statement_text),
303+
lambda: self.execute_query(statement.text),
338304
raises(
339305
ydb.Error
340306
)

ydb/tests/functional/suite_tests/test_sql_logic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
from hamcrest import assert_that, raises
77

8-
from test_base import BaseSuiteRunner, get_token, get_test_suites, safe_execute, get_statement_and_side_effects
8+
from test_base import BaseSuiteRunner, get_token, get_test_suites, safe_execute
99

1010
"""
1111
This module is a specific runner of sqllogic tests. Test suites for this
@@ -38,8 +38,7 @@ def assert_statement_ok(self, statement):
3838
safe_execute(lambda: self.__execute_sqlitedb(statement.text), statement)
3939

4040
def assert_statement_error(self, statement):
41-
statement_text, side_effects = get_statement_and_side_effects(statement.text)
42-
assert_that(lambda: self.__execute_sqlitedb(statement_text), raises(sqlite3.Error), str(statement))
41+
assert_that(lambda: self.__execute_sqlitedb(statement.text), raises(sqlite3.Error), str(statement))
4342
super(TestSQLLogic, self).assert_statement_error(statement)
4443

4544
def get_query_and_output(self, statement_text):

0 commit comments

Comments
 (0)