Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 71 additions & 96 deletions ydb/tests/functional/suite_tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import itertools
import json
import abc
import collections
import os
import random
import string
Expand Down Expand Up @@ -41,21 +40,53 @@ def mute_sdk_loggers():
mute_sdk_loggers()


@enum.unique
class StatementTypes(enum.Enum):
Skipped = 'statement skipped'
Ok = 'statement ok'
Error = 'statement error'
Query = 'statement query'
StreamQuery = 'statement stream query'
ImportTableData = 'statement import table data'
class StatementDefinition:
@enum.unique
class Type(enum.Enum):
Skipped = 'statement skipped'
Ok = 'statement ok'
Error = 'statement error'
Query = 'statement query'
StreamQuery = 'statement stream query'
ImportTableData = 'statement import table data'

def __init__(self, suite: str, at_line: int, type: Type, text: [str]):
self.suite_name = suite
self.at_line = at_line
self.s_type = type
self.text = text

def get_statement_type(line):
for s_type in list(StatementTypes):
if s_type.value in line.lower():
return s_type
raise RuntimeError("Can't find statement type for line %s" % line)
def __str__(self):
return f'''StatementDefinition:
suite: {self.suite_name}
line: {self.at_line}
type: {self.s_type}
text:
''' + '\n'.join([f' {row}' for row in self.text.split('\n')])

@staticmethod
def _parse_statement_type(statement_line: str) -> Type:
for t in list(StatementDefinition.Type):
if t.value in statement_line.lower():
return t
return None

@staticmethod
def parse(suite: str, at_line: int, lines: list[str]):
if not lines or not lines[0]:
raise RuntimeError(f'Invalid statement in {suite}, at line: {at_line}')
type = StatementDefinition._parse_statement_type(lines[0])
if type is None:
raise RuntimeError(f'Unknown statement type in {suite}, at line: {at_line}')
lines.pop(0)
at_line += 1
statement_lines = []
for line in lines:
if line.startswith('side effect: '): # side effects are not supported yet
pass
else:
statement_lines.append(line)
return StatementDefinition(suite, at_line, type, "\n".join(statement_lines))


def get_token(length=10):
Expand All @@ -67,12 +98,6 @@ def get_source_path(*args):
return os.path.join(arcadia_root, test_source_path(os.path.join(*args)))


def is_empty_line(line):
if line.split():
return False
return True


def get_lines(suite_path):
with open(suite_path) as reader:
for line_idx, line in enumerate(reader.readlines()):
Expand All @@ -97,79 +122,31 @@ def get_test_suites(directory):
return suites


def get_single_statement(lines):
def split_by_statement(lines):
statement_lines = []
statement_start_line_idx = 0
for line_idx, line in lines:
if is_empty_line(line):
statement = "\n".join(statement_lines)
return statement
statement_lines.append(line)
return "\n".join(statement_lines)


class ParsedStatement(collections.namedtuple('ParsedStatement', ["at_line", "s_type", "suite_name", "text"])):
def get_fields(self):
return self._fields

def __str__(self):
result = ["", "Parsed Statement"]
for field in self.get_fields():
value = str(getattr(self, field))
if field != 'text':
result.append(' ' * 4 + '%s: %s,' % (field, value))
else:
result.append(' ' * 4 + '%s:' % field)
result.extend([' ' * 8 + row for row in value.split('\n')])
return "\n".join(result)
if line:
if line.startswith("statement "):
statement_start_line_idx = line_idx
statement_lines = [line]
elif statement_lines:
statement_lines.append(line)
else:
if statement_lines:
yield (statement_start_line_idx, statement_lines)
statement_lines = []
if statement_lines:
yield (statement_start_line_idx, statement_lines)


def get_statements(suite_path, suite_name):
lines = get_lines(suite_path)
for line_idx, line in lines:
if is_empty_line(line) or not is_statement_definition(line):
# empty line or junk lines
continue
text = get_single_statement(lines)
yield ParsedStatement(
line_idx,
get_statement_type(line),
for statement_start_line_idx, statement_lines in split_by_statement(get_lines(suite_path)):
yield StatementDefinition.parse(
suite_name,
text)


def is_side_effect(statement_line):
return statement_line.startswith('side effect: ')


def parse_side_effect(se_line):
pieces = se_line.split(':')
if len(pieces) < 3:
raise RuntimeError("Invalid side effect description: %s" % se_line)
se_type = pieces[1].strip()
se_description = ':'.join(pieces[2:])
se_description = se_description.strip()

return se_type, se_description


def get_statement_and_side_effects(statement_text):
statement_lines = statement_text.split('\n')
side_effects = {}
filtered = []
for statement_line in statement_lines:
if not is_side_effect(statement_line):
filtered.append(statement_line)
continue

se_type, se_description = parse_side_effect(statement_line)

side_effects[se_type] = se_description

return '\n'.join(filtered), side_effects


def is_statement_definition(line):
return line.startswith("statement")
statement_start_line_idx,
statement_lines,
)


def format_yql_statement(lines_or_statement, table_path_prefix):
Expand Down Expand Up @@ -309,12 +286,12 @@ def assert_statement_import_table_data(self, statement):
def assert_statement(self, parsed_statement):
start_time = time.time()
from_type = {
StatementTypes.Ok: self.assert_statement_ok,
StatementTypes.Query: self.assert_statement_query,
StatementTypes.StreamQuery: self.assert_statement_stream_query,
StatementTypes.Error: (lambda x: x),
StatementTypes.ImportTableData: self.assert_statement_import_table_data,
StatementTypes.Skipped: lambda x: x
StatementDefinition.Type.Ok: self.assert_statement_ok,
StatementDefinition.Type.Query: self.assert_statement_query,
StatementDefinition.Type.StreamQuery: self.assert_statement_stream_query,
StatementDefinition.Type.Error: (lambda x: x),
StatementDefinition.Type.ImportTableData: self.assert_statement_import_table_data,
StatementDefinition.Type.Skipped: lambda x: x
}
assert_method = from_type.get(parsed_statement.s_type)
assert_method(parsed_statement)
Expand All @@ -331,10 +308,8 @@ def assert_statement_ok(self, statement):
)

def assert_statement_error(self, statement):
# not supported yet
statement_text, side_effects = get_statement_and_side_effects(statement.text)
assert_that(
lambda: self.execute_query(statement_text),
lambda: self.execute_query(statement.text),
raises(
ydb.Error
)
Expand Down
5 changes: 2 additions & 3 deletions ydb/tests/functional/suite_tests/test_sql_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from hamcrest import assert_that, raises

from test_base import BaseSuiteRunner, get_token, get_test_suites, safe_execute, get_statement_and_side_effects
from test_base import BaseSuiteRunner, get_token, get_test_suites, safe_execute

"""
This module is a specific runner of sqllogic tests. Test suites for this
Expand Down Expand Up @@ -38,8 +38,7 @@ def assert_statement_ok(self, statement):
safe_execute(lambda: self.__execute_sqlitedb(statement.text), statement)

def assert_statement_error(self, statement):
statement_text, side_effects = get_statement_and_side_effects(statement.text)
assert_that(lambda: self.__execute_sqlitedb(statement_text), raises(sqlite3.Error), str(statement))
assert_that(lambda: self.__execute_sqlitedb(statement.text), raises(sqlite3.Error), str(statement))
super(TestSQLLogic, self).assert_statement_error(statement)

def get_query_and_output(self, statement_text):
Expand Down
Loading