22import itertools
33import json
44import abc
5- import collections
65import os
76import random
87import string
@@ -41,21 +40,53 @@ def mute_sdk_loggers():
4140mute_sdk_loggers ()
4241
4342
44- @enum .unique
45- class StatementTypes (enum .Enum ):
46- Skipped = 'statement skipped'
47- Ok = 'statement ok'
48- Error = 'statement error'
49- Query = 'statement query'
50- StreamQuery = 'statement stream query'
51- ImportTableData = 'statement import table data'
43+ class StatementDefinition :
44+ @enum .unique
45+ class Type (enum .Enum ):
46+ Skipped = 'statement skipped'
47+ Ok = 'statement ok'
48+ Error = 'statement error'
49+ Query = 'statement query'
50+ StreamQuery = 'statement stream query'
51+ ImportTableData = 'statement import table data'
5252
53+ def __init__ (self , suite : str , at_line : int , type : Type , text : [str ]):
54+ self .suite_name = suite
55+ self .at_line = at_line
56+ self .s_type = type
57+ self .text = text
5358
54- def get_statement_type (line ):
55- for s_type in list (StatementTypes ):
56- if s_type .value in line .lower ():
57- return s_type
58- raise RuntimeError ("Can't find statement type for line %s" % line )
59+ def __str__ (self ):
60+ return f'''StatementDefinition:
61+ suite: { self .suite_name }
62+ line: { self .at_line }
63+ type: { self .s_type }
64+ text:
65+ ''' + '\n ' .join ([f' { row } ' for row in self .text .split ('\n ' )])
66+
67+ @staticmethod
68+ def _parse_statement_type (statement_line : str ) -> Type :
69+ for t in list (StatementDefinition .Type ):
70+ if t .value in statement_line .lower ():
71+ return t
72+ return None
73+
74+ @staticmethod
75+ def parse (suite : str , at_line : int , lines : list [str ]):
76+ if not lines or not lines [0 ]:
77+ raise RuntimeError (f'Invalid statement in { suite } , at line: { at_line } ' )
78+ type = StatementDefinition ._parse_statement_type (lines [0 ])
79+ if type is None :
80+ raise RuntimeError (f'Unknown statement type in { suite } , at line: { at_line } ' )
81+ lines .pop (0 )
82+ at_line += 1
83+ statement_lines = []
84+ for line in lines :
85+ if line .startswith ('side effect: ' ): # side effects are not supported yet
86+ pass
87+ else :
88+ statement_lines .append (line )
89+ return StatementDefinition (suite , at_line , type , "\n " .join (statement_lines ))
5990
6091
6192def get_token (length = 10 ):
@@ -67,12 +98,6 @@ def get_source_path(*args):
6798 return os .path .join (arcadia_root , test_source_path (os .path .join (* args )))
6899
69100
70- def is_empty_line (line ):
71- if line .split ():
72- return False
73- return True
74-
75-
76101def get_lines (suite_path ):
77102 with open (suite_path ) as reader :
78103 for line_idx , line in enumerate (reader .readlines ()):
@@ -97,79 +122,31 @@ def get_test_suites(directory):
97122 return suites
98123
99124
100- def get_single_statement (lines ):
125+ def split_by_statement (lines ):
101126 statement_lines = []
127+ statement_start_line_idx = 0
102128 for line_idx , line in lines :
103- if is_empty_line (line ):
104- statement = "\n " .join (statement_lines )
105- return statement
106- statement_lines .append (line )
107- return "\n " .join (statement_lines )
108-
109-
110- class ParsedStatement (collections .namedtuple ('ParsedStatement' , ["at_line" , "s_type" , "suite_name" , "text" ])):
111- def get_fields (self ):
112- return self ._fields
113-
114- def __str__ (self ):
115- result = ["" , "Parsed Statement" ]
116- for field in self .get_fields ():
117- value = str (getattr (self , field ))
118- if field != 'text' :
119- result .append (' ' * 4 + '%s: %s,' % (field , value ))
120- else :
121- result .append (' ' * 4 + '%s:' % field )
122- result .extend ([' ' * 8 + row for row in value .split ('\n ' )])
123- return "\n " .join (result )
129+ if line :
130+ if line .startswith ("statement " ):
131+ statement_start_line_idx = line_idx
132+ statement_lines = [line ]
133+ elif statement_lines :
134+ statement_lines .append (line )
135+ else :
136+ if statement_lines :
137+ yield (statement_start_line_idx , statement_lines )
138+ statement_lines = []
139+ if statement_lines :
140+ yield (statement_start_line_idx , statement_lines )
124141
125142
126143def get_statements (suite_path , suite_name ):
127- lines = get_lines (suite_path )
128- for line_idx , line in lines :
129- if is_empty_line (line ) or not is_statement_definition (line ):
130- # empty line or junk lines
131- continue
132- text = get_single_statement (lines )
133- yield ParsedStatement (
134- line_idx ,
135- get_statement_type (line ),
144+ for statement_start_line_idx , statement_lines in split_by_statement (get_lines (suite_path )):
145+ yield StatementDefinition .parse (
136146 suite_name ,
137- text )
138-
139-
140- def is_side_effect (statement_line ):
141- return statement_line .startswith ('side effect: ' )
142-
143-
144- def parse_side_effect (se_line ):
145- pieces = se_line .split (':' )
146- if len (pieces ) < 3 :
147- raise RuntimeError ("Invalid side effect description: %s" % se_line )
148- se_type = pieces [1 ].strip ()
149- se_description = ':' .join (pieces [2 :])
150- se_description = se_description .strip ()
151-
152- return se_type , se_description
153-
154-
155- def get_statement_and_side_effects (statement_text ):
156- statement_lines = statement_text .split ('\n ' )
157- side_effects = {}
158- filtered = []
159- for statement_line in statement_lines :
160- if not is_side_effect (statement_line ):
161- filtered .append (statement_line )
162- continue
163-
164- se_type , se_description = parse_side_effect (statement_line )
165-
166- side_effects [se_type ] = se_description
167-
168- return '\n ' .join (filtered ), side_effects
169-
170-
171- def is_statement_definition (line ):
172- return line .startswith ("statement" )
147+ statement_start_line_idx ,
148+ statement_lines ,
149+ )
173150
174151
175152def format_yql_statement (lines_or_statement , table_path_prefix ):
@@ -308,13 +285,15 @@ def assert_statement_import_table_data(self, statement):
308285
309286 def assert_statement (self , parsed_statement ):
310287 start_time = time .time ()
288+ print (f"qqqw_executing: { parsed_statement .s_type } , { parsed_statement .text } " )
289+
311290 from_type = {
312- StatementTypes .Ok : self .assert_statement_ok ,
313- StatementTypes .Query : self .assert_statement_query ,
314- StatementTypes .StreamQuery : self .assert_statement_stream_query ,
315- StatementTypes .Error : (lambda x : x ),
316- StatementTypes .ImportTableData : self .assert_statement_import_table_data ,
317- StatementTypes .Skipped : lambda x : x
291+ StatementDefinition . Type .Ok : self .assert_statement_ok ,
292+ StatementDefinition . Type .Query : self .assert_statement_query ,
293+ StatementDefinition . Type .StreamQuery : self .assert_statement_stream_query ,
294+ StatementDefinition . Type .Error : (lambda x : x ),
295+ StatementDefinition . Type .ImportTableData : self .assert_statement_import_table_data ,
296+ StatementDefinition . Type .Skipped : lambda x : x
318297 }
319298 assert_method = from_type .get (parsed_statement .s_type )
320299 assert_method (parsed_statement )
@@ -331,10 +310,8 @@ def assert_statement_ok(self, statement):
331310 )
332311
333312 def assert_statement_error (self , statement ):
334- # not supported yet
335- statement_text , side_effects = get_statement_and_side_effects (statement .text )
336313 assert_that (
337- lambda : self .execute_query (statement_text ),
314+ lambda : self .execute_query (statement . text ),
338315 raises (
339316 ydb .Error
340317 )
0 commit comments