|
1 | 1 | import os |
| 2 | +import re |
| 3 | +import asyncio |
| 4 | +import sqlite3 |
2 | 5 | import threading |
3 | 6 | from typing import Tuple, Any, List, Set |
4 | 7 | from itertools import product |
@@ -119,39 +122,49 @@ def result_eq(result1: List[Tuple], result2: List[Tuple], order_matters: bool) - |
119 | 122 | return False |
120 | 123 |
|
121 | 124 |
|
122 | | -def clean_tmp_f(f_prefix: str): |
123 | | - with threadLock: |
124 | | - for suffix in ('.in', '.out'): |
125 | | - f_path = f_prefix + suffix |
126 | | - if os.path.exists(f_path): |
127 | | - os.unlink(f_path) |
| 125 | +def replace_cur_year(query: str) -> str: |
| 126 | + return re.sub( |
| 127 | + "YEAR\s*\(\s*CURDATE\s*\(\s*\)\s*\)\s*", "2020", query, flags=re.IGNORECASE |
| 128 | + ) |
128 | 129 |
|
129 | 130 |
|
130 | | -# we need a wrapper, because simple timeout will not stop the database connection |
131 | | -def exec_on_db(sqlite_path: str, query: str, process_id: str = '', timeout: int = TIMEOUT) -> Tuple[str, Any]: |
132 | | - f_prefix = None |
133 | | - with threadLock: |
134 | | - while f_prefix is None or os.path.exists(f_prefix + '.in'): |
135 | | - process_id += str(time.time()) |
136 | | - process_id += str(random.randint(0, 10000000000)) |
137 | | - f_prefix = os.path.join(EXEC_TMP_DIR, process_id) |
138 | | - pkl.dump((sqlite_path, query), open(f_prefix + '.in', 'wb')) |
| 131 | +# get the database cursor for a sqlite database path |
| 132 | +def get_cursor_from_path(sqlite_path: str): |
139 | 133 | try: |
140 | | - subprocess.call(['python3', 'exec_subprocess.py', f_prefix], timeout=timeout, stderr=open('runerr.log', 'a')) |
| 134 | + if not os.path.exists(sqlite_path): |
| 135 | + print("Openning a new connection %s" % sqlite_path) |
| 136 | + connection = sqlite3.connect(sqlite_path) |
141 | 137 | except Exception as e: |
142 | | - print(e) |
143 | | - clean_tmp_f(f_prefix) |
144 | | - return 'exception', e |
145 | | - result_path = f_prefix + '.out' |
146 | | - returned_val = ('exception', TimeoutError) |
| 138 | + print(sqlite_path) |
| 139 | + raise e |
| 140 | + connection.text_factory = lambda b: b.decode(errors="ignore") |
| 141 | + cursor = connection.cursor() |
| 142 | + return cursor |
| 143 | + |
| 144 | + |
| 145 | +async def exec_on_db_(sqlite_path: str, query: str) -> Tuple[str, Any]: |
| 146 | + query = replace_cur_year(query) |
| 147 | + cursor = get_cursor_from_path(sqlite_path) |
147 | 148 | try: |
148 | | - if os.path.exists(result_path): |
149 | | - returned_val = pkl.load(open(result_path, 'rb')) |
150 | | - except: |
151 | | - pass |
152 | | - clean_tmp_f(f_prefix) |
153 | | - return returned_val |
| 149 | + cursor.execute(query) |
| 150 | + result = cursor.fetchall() |
| 151 | + cursor.close() |
| 152 | + cursor.connection.close() |
| 153 | + return "result", result |
| 154 | + except Exception as e: |
| 155 | + cursor.close() |
| 156 | + cursor.connection.close() |
| 157 | + return "exception", e |
154 | 158 |
|
| 159 | +async def exec_on_db( |
| 160 | + sqlite_path: str, query: str, process_id: str = "", timeout: int = TIMEOUT |
| 161 | +) -> Tuple[str, Any]: |
| 162 | + try: |
| 163 | + return await asyncio.wait_for(exec_on_db_(sqlite_path, query), timeout) |
| 164 | + except asyncio.TimeoutError: |
| 165 | + return ('exception', TimeoutError) |
| 166 | + except Exception as e: |
| 167 | + return ("exception", e) |
155 | 168 |
|
156 | 169 |
|
157 | 170 | # postprocess the model predictions to avoid execution errors |
@@ -208,8 +221,8 @@ def eval_exec_match(db: str, p_str: str, g_str: str, plug_value: bool, keep_dist |
208 | 221 | ranger = db_paths |
209 | 222 |
|
210 | 223 | for db_path in ranger: |
211 | | - g_flag, g_denotation = exec_on_db(db_path, g_str) |
212 | | - p_flag, p_denotation = exec_on_db(db_path, pred) |
| 224 | + g_flag, g_denotation = asyncio.run(exec_on_db(db_path, g_str)) |
| 225 | + p_flag, p_denotation = asyncio.run(exec_on_db(db_path, pred)) |
213 | 226 |
|
214 | 227 | # we should expect the gold to be succesfully executed on the database |
215 | 228 | assert g_flag != 'exception', 'gold query %s has error on database file %s' % (g_str, db_path) |
|
0 commit comments