forked from taoyds/spider
-
Notifications
You must be signed in to change notification settings - Fork 0
/
script_generate_db_from_schema.py
101 lines (88 loc) · 2.96 KB
/
script_generate_db_from_schema.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import sqlite3
import argparse
from datasets import load_dataset
# Parse command-line arguments.
parser = argparse.ArgumentParser(
    description="Generate SQLite databases from the ViText2SQL table schemas."
)
parser.add_argument(
    "--level",
    type=str,
    default="syllable",
    # The schema fix-ups applied later in this script are only defined for
    # these two levels, so reject anything else up front with a clear message.
    choices=("syllable", "word"),
    help="Level of processing (default: syllable)",
)
args = parser.parse_args()

# Tokenization level selected on the command line; it picks which
# "<level>-level--tables" parquet file is downloaded below.
LEVEL = args.level

# SECURITY: a personal Hugging Face access token used to be hard-coded here.
# Credentials must never live in source control; the token is now read from
# the HF_TOKEN environment variable (None means "no explicit token").
# NOTE(review): the previously committed token should be revoked and rotated.
tables_dataset = load_dataset(
    "parquet",
    data_files={
        "train": f"https://huggingface.co/datasets/TeeA/VinAIResearch-ViText2SQL/resolve/main/{LEVEL}-level--tables/train-00000-of-00001.parquet"
    },
    token=os.environ.get("HF_TOKEN"),
)
# Pair each database id with its schema script, row by row, from the
# "train" split of the tables dataset.
_tables_split = tables_dataset["train"]
global_schema = dict(zip(_tables_split["db_id"], _tables_split["schema"]))

print(f"Generating {LEVEL} level dbs")

# Root directory that receives one sub-directory per generated database.
databases_dir = f"benchmark/db/{LEVEL}-level"
os.makedirs(databases_dir, exist_ok=True)
# The original author noted that the unmodified schemas of "academic", "yelp",
# "bike_1" and "cre_Drama_Workshop_Groups" "went wrong", so the fragments
# below are removed from those schemas before they are executed.
# NOTE(review): presumably these are duplicated column/table definitions that
# SQLite rejects -- confirm against the upstream dataset.

# db_id -> list of {level: text fragment}; each fragment is deleted from the
# schema text (every occurrence), in the order listed.
_FRAGMENTS_TO_REMOVE = {
    "academic": [
        {
            "syllable": ', "số lượng trích dẫn" number',
            "word": ', "số_lượng trích_dẫn" number',
        },
    ],
    "yelp": [
        {
            "syllable": '"id doanh nghiệp" number, ',
            "word": '"id doanh_nghiệp" number, ',
        },
        {
            "syllable": '"id người tiêu dùng" number, ',
            "word": '"id người tiêu_dùng" number, ',
        },
    ],
    "bike_1": [
        {
            "syllable": '"áp suất mực nước biển tối đa" number, ',
            "word": '"áp_suất mực nước_biển tối_đa" number, ',
        },
    ],
}

# db_id -> {level: marker}; every ";"-separated statement that contains the
# marker is dropped from the schema entirely.
_STATEMENTS_TO_DROP = {
    "cre_Drama_Workshop_Groups": {
        "syllable": 'CREATE TABLE "khách hàng" ("id khách hàng" text',
        "word": 'CREATE TABLE "khách_hàng" ("id khách_hàng" text',
    },
}


def _patch_schema(db_id, schema, level):
    """Return *schema* with the known-problematic parts for *db_id* removed.

    Args:
        db_id: Identifier of the database the schema belongs to.
        schema: SQL script text (statements separated by ";").
        level: Key selecting which variant of each fragment/marker to use
            ("syllable" or "word").

    Returns:
        The patched schema text; unchanged when *db_id* has no fix-ups.

    Raises:
        KeyError: If *db_id* has fix-ups but none is defined for *level*.
    """
    for fragment_by_level in _FRAGMENTS_TO_REMOVE.get(db_id, ()):
        schema = schema.replace(fragment_by_level[level], "")

    marker_by_level = _STATEMENTS_TO_DROP.get(db_id)
    if marker_by_level is not None:
        marker = marker_by_level[level]
        schema = ";".join(
            statement for statement in schema.split(";") if marker not in statement
        )
    return schema


for db_id, schema in global_schema.items():
    os.makedirs(databases_dir + "/" + db_id, exist_ok=True)

    # Patch before opening the database so a fix-up error cannot leave a
    # connection dangling.
    schema = _patch_schema(db_id, schema, LEVEL)
    print(schema)

    # Create (or open) the SQLite file for this database.
    conn = sqlite3.connect(f"{databases_dir}/{db_id}/{db_id}.sqlite")
    try:
        # Execute the schema script to create the tables, then persist.
        conn.executescript(schema)
        conn.commit()
    finally:
        # Always release the connection, even when the schema fails to run.
        conn.close()