Skip to content

Commit 48cb78e

Browse files
authored
Merge pull request taoyds#6 from ElementAI/faster-sql-execution
use asyncio instead of subprocess to limit execution time
2 parents e9e4372 + 14784c3 commit 48cb78e

File tree

1 file changed

+42
-29
lines changed

1 file changed

+42
-29
lines changed

exec_eval.py

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import os
2+
import re
3+
import asyncio
4+
import sqlite3
25
import threading
36
from typing import Tuple, Any, List, Set
47
from itertools import product
@@ -119,39 +122,49 @@ def result_eq(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -
119122
return False
120123

121124

122-
def clean_tmp_f(f_prefix: str):
123-
with threadLock:
124-
for suffix in ('.in', '.out'):
125-
f_path = f_prefix + suffix
126-
if os.path.exists(f_path):
127-
os.unlink(f_path)
125+
def replace_cur_year(query: str) -> str:
126+
return re.sub(
127+
"YEAR\s*\(\s*CURDATE\s*\(\s*\)\s*\)\s*", "2020", query, flags=re.IGNORECASE
128+
)
128129

129130

130-
# we need a wrapper, because simple timeout will not stop the database connection
131-
def exec_on_db(sqlite_path: str, query: str, process_id: str = '', timeout: int = TIMEOUT) -> Tuple[str, Any]:
132-
f_prefix = None
133-
with threadLock:
134-
while f_prefix is None or os.path.exists(f_prefix + '.in'):
135-
process_id += str(time.time())
136-
process_id += str(random.randint(0, 10000000000))
137-
f_prefix = os.path.join(EXEC_TMP_DIR, process_id)
138-
pkl.dump((sqlite_path, query), open(f_prefix + '.in', 'wb'))
131+
# get the database cursor for a sqlite database path
132+
def get_cursor_from_path(sqlite_path: str):
139133
try:
140-
subprocess.call(['python3', 'exec_subprocess.py', f_prefix], timeout=timeout, stderr=open('runerr.log', 'a'))
134+
if not os.path.exists(sqlite_path):
135+
print("Openning a new connection %s" % sqlite_path)
136+
connection = sqlite3.connect(sqlite_path)
141137
except Exception as e:
142-
print(e)
143-
clean_tmp_f(f_prefix)
144-
return 'exception', e
145-
result_path = f_prefix + '.out'
146-
returned_val = ('exception', TimeoutError)
138+
print(sqlite_path)
139+
raise e
140+
connection.text_factory = lambda b: b.decode(errors="ignore")
141+
cursor = connection.cursor()
142+
return cursor
143+
144+
145+
async def exec_on_db_(sqlite_path: str, query: str) -> Tuple[str, Any]:
146+
query = replace_cur_year(query)
147+
cursor = get_cursor_from_path(sqlite_path)
147148
try:
148-
if os.path.exists(result_path):
149-
returned_val = pkl.load(open(result_path, 'rb'))
150-
except:
151-
pass
152-
clean_tmp_f(f_prefix)
153-
return returned_val
149+
cursor.execute(query)
150+
result = cursor.fetchall()
151+
cursor.close()
152+
cursor.connection.close()
153+
return "result", result
154+
except Exception as e:
155+
cursor.close()
156+
cursor.connection.close()
157+
return "exception", e
154158

159+
async def exec_on_db(
160+
sqlite_path: str, query: str, process_id: str = "", timeout: int = TIMEOUT
161+
) -> Tuple[str, Any]:
162+
try:
163+
return await asyncio.wait_for(exec_on_db_(sqlite_path, query), timeout)
164+
except asyncio.TimeoutError:
165+
return ('exception', TimeoutError)
166+
except Exception as e:
167+
return ("exception", e)
155168

156169

157170
# postprocess the model predictions to avoid execution errors
@@ -208,8 +221,8 @@ def eval_exec_match(db: str, p_str: str, g_str: str, plug_value: bool, keep_dist
208221
ranger = db_paths
209222

210223
for db_path in ranger:
211-
g_flag, g_denotation = exec_on_db(db_path, g_str)
212-
p_flag, p_denotation = exec_on_db(db_path, pred)
224+
g_flag, g_denotation = asyncio.run(exec_on_db(db_path, g_str))
225+
p_flag, p_denotation = asyncio.run(exec_on_db(db_path, pred))
213226

214227
# we should expect the gold to be succesfully executed on the database
215228
assert g_flag != 'exception', 'gold query %s has error on database file %s' % (g_str, db_path)

0 commit comments

Comments
 (0)