
Commit b5760d4

Draft of the insert script for the bus dataset.

1 parent: cc98fa2

2 files changed: +275 -1 lines


soda-core/src/soda_core/common/sql_dialect.py

Lines changed: 5 additions & 1 deletion
@@ -200,7 +200,11 @@ def build_drop_table_sql(self, drop_table: DROP_TABLE | DROP_TABLE_IF_EXISTS, ad
     #########################################################
     # INSERT INTO
     #########################################################
-    def build_insert_sql(self, insert: InsertSqlStatement) -> str:
+    def build_insert_sql(self, insert: INSERT_INTO | INSERT_INTO_VIA_SELECT) -> str:
+        if isinstance(insert, INSERT_INTO) and insert.values:
+            for row in insert.values:
+                assert len(insert.columns) == len(row), "The number of columns and values must match"
+
         columns: str = "(" + ", ".join(self.quote_column(c) for c in insert.columns) + ") " if insert.columns else ""
         if insert.values:
             values: str = ",\n    ".join(
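
For context, a minimal sketch of how the changed method might be exercised (a hedged illustration, not code from this commit; the sql_dialect instance and passing plain strings for columns are assumptions — the test file below passes column objects produced by convert_to_standard_column):

    from soda_core.common.sql_ast import INSERT_INTO, LITERAL, VALUES_ROW

    insert = INSERT_INTO(
        fully_qualified_table_name='"my_schema"."my_table"',
        columns=["id", "name"],
        values=[VALUES_ROW([LITERAL(1), LITERAL("first")])],
    )
    # The new guard asserts len(columns) == len(row) for every row, failing fast
    # with "The number of columns and values must match" instead of emitting bad SQL.
    sql: str = sql_dialect.build_insert_sql(insert)
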
Lines changed: 270 additions & 0 deletions
@@ -0,0 +1,270 @@
import datetime
import logging

import pytest

try:
    import pandas as pd
except ImportError:
    pass
from helpers.data_source_test_helper import DataSourceTestHelper
from soda_core.common.data_source_impl import DataSourceImpl
from soda_core.common.data_source_results import QueryResult
from soda_core.common.logging_constants import soda_logger
from soda_core.common.sql_ast import (
    COUNT,
    CREATE_TABLE_COLUMN,
    CREATE_TABLE_IF_NOT_EXISTS,
    DROP_TABLE_IF_EXISTS,
    FROM,
    INSERT_INTO,
    LITERAL,
    SELECT,
    STAR,
    VALUES_ROW,
    DBDataType,
)
from soda_core.common.sql_dialect import SqlDialect

logger: logging.Logger = soda_logger

CSV_FILE_LOCATION = "Bus_Breakdown_and_Delays.csv"
SCHEMA_NAME = "observability_testing"
TABLE_NAME = "bus_breakdown"
BATCH_SIZES = {  # Depends on the database. Not all databases support very large batch inserts.
    "postgres": 10000,
    "sqlserver": 1,  # For sqlserver: batch size of 1, as it does auto conversions that fail when inserting multiple rows at once.
    "mysql": 10000,
    "bigquery": 2500,  # BigQuery limits the query size to 1024KB, so we need to use a smaller batch size. This takes a while to run!
    "snowflake": 10000,
    "oracle": 10000,
}
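
# Note: BATCH_SIZES is keyed by DataSourceImpl.type_name; the insert loop below
# issues one INSERT statement per batch_size converted rows.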

# Map the columns to data types
COLUMN_TO_DATA_TYPE_MAPPING: dict[str, DBDataType] = {
    "School_Year": DBDataType.TEXT,
    "Busbreakdown_ID": DBDataType.INTEGER,
    "Run_Type": DBDataType.TEXT,
    "Bus_No": DBDataType.TEXT,
    "Route_Number": DBDataType.TEXT,
    "Reason": DBDataType.TEXT,
    "Schools_Serviced": DBDataType.TEXT,
    "Occurred_On": DBDataType.TIMESTAMP,
    "Created_On": DBDataType.TIMESTAMP,
    "Boro": DBDataType.TEXT,
    "Bus_Company_Name": DBDataType.TEXT,
    "How_Long_Delayed": DBDataType.TEXT,
    "Number_Of_Students_On_The_Bus": DBDataType.INTEGER,
    "Has_Contractor_Notified_Schools": DBDataType.TEXT,
    "Has_Contractor_Notified_Parents": DBDataType.TEXT,
    "Have_You_Alerted_OPT": DBDataType.TEXT,
    "Informed_On": DBDataType.TIMESTAMP,
    "Incident_Number": DBDataType.TEXT,
    "Last_Updated_On": DBDataType.TIMESTAMP,
    "Breakdown_or_Running_Late": DBDataType.TEXT,
    "School_Age_or_PreK": DBDataType.TEXT,
}

TIMESTAMP_COLUMNS = ["Occurred_On", "Created_On", "Informed_On", "Last_Updated_On"]
INTEGER_COLUMNS = ["Busbreakdown_ID", "Number_Of_Students_On_The_Bus"]


def convert_timestamp_to_datetime(timestamp: str) -> datetime.datetime:
    # The timestamp is in the format "01/01/2021 10:00:00 AM" (month/day/year, 12-hour clock)
    return datetime.datetime.strptime(timestamp, "%m/%d/%Y %I:%M:%S %p")
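
# Example (an illustrative value, not a row from the CSV):
#   convert_timestamp_to_datetime("09/21/2015 07:45:00 AM")
#   -> datetime.datetime(2015, 9, 21, 7, 45)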


def convert_to_values_row(row) -> VALUES_ROW:
    result_list: list[LITERAL] = []

    # First we extract all the values
    all_values = {
        "School_Year": row["School_Year"],
        "Busbreakdown_ID": int(row["Busbreakdown_ID"]),
        "Run_Type": row["Run_Type"],
        "Bus_No": row["Bus_No"],
        "Route_Number": row["Route_Number"],
        "Reason": row["Reason"],
        "Schools_Serviced": row["Schools_Serviced"],
        "Occurred_On": row["Occurred_On"],
        "Created_On": row["Created_On"],
        "Boro": row["Boro"],
        "Bus_Company_Name": row["Bus_Company_Name"],
        "How_Long_Delayed": row["How_Long_Delayed"],
        "Number_Of_Students_On_The_Bus": row["Number_Of_Students_On_The_Bus"],
        "Has_Contractor_Notified_Schools": row["Has_Contractor_Notified_Schools"],
        "Has_Contractor_Notified_Parents": row["Has_Contractor_Notified_Parents"],
        "Have_You_Alerted_OPT": row["Have_You_Alerted_OPT"],
        "Informed_On": row["Informed_On"],
        "Incident_Number": row["Incident_Number"],
        "Last_Updated_On": row["Last_Updated_On"],
        "Breakdown_or_Running_Late": row["Breakdown_or_Running_Late"],
        "School_Age_or_PreK": row["School_Age_or_PreK"],
    }
    # Then we convert the values to literals (note: order must be maintained!)
    for key, value in all_values.items():
        # Check for NaN: pandas converts empty strings to NaN, which databases cannot handle -> convert to NULL.
        if pd.isnull(value):
            result_list.append(LITERAL(None))
        else:
            if key in TIMESTAMP_COLUMNS:
                value = convert_timestamp_to_datetime(value)
            elif key in INTEGER_COLUMNS:
                value = int(value)
            else:
                value = str(value)
            # Make sure that the value is of the correct type; sometimes we get errors with this.
            if COLUMN_TO_DATA_TYPE_MAPPING[key] == DBDataType.TIMESTAMP:
                assert isinstance(value, datetime.datetime)
            elif COLUMN_TO_DATA_TYPE_MAPPING[key] == DBDataType.INTEGER:
                assert isinstance(value, int)
            elif COLUMN_TO_DATA_TYPE_MAPPING[key] == DBDataType.TEXT:
                assert isinstance(value, str)
            else:
                raise ValueError(f"Unknown column type: {COLUMN_TO_DATA_TYPE_MAPPING[key]}")
            result_list.append(LITERAL(value))

    return VALUES_ROW(result_list)
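
# Usage sketch (hedged; `df` is the dataframe loaded in the test below):
#   convert_to_values_row(df.iloc[0])  # -> a VALUES_ROW of 21 LITERALs, in COLUMN_TO_DATA_TYPE_MAPPING order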


@pytest.mark.skip(
    reason="This test is a hack to upload the bus breakdown dataset to the test database. It should not be considered a part of the test suite."
)
def test_full_create_insert_drop_ast(data_source_test_helper: DataSourceTestHelper):
    """
    This is a very hacky way to upload a dataset (specifically the bus breakdown dataset) to a database.
    It seemed like the easiest way to do this quickly, as we already have the connection, SQL dialect, etc. for each data source.
    You will see some hacks in this code, such as manually setting the dataset_prefix to the schema name, so we can reuse the existing test helper.
    If you have the time, feel free to refactor this :).
    Bus dataset downloaded from: https://catalog.data.gov/dataset/bus-breakdown-and-delays

    Note: this test requires pandas to be installed!
    """

    data_source_impl: DataSourceImpl = data_source_test_helper.data_source_impl
    data_source_type: str = data_source_impl.type_name
    sql_dialect: SqlDialect = data_source_impl.sql_dialect
    dataset_prefix = data_source_test_helper.dataset_prefix

    # Create the schema
    dataset_prefix[data_source_impl.sql_dialect.get_schema_prefix_index()] = SCHEMA_NAME
    data_source_test_helper.dataset_prefix = dataset_prefix
    data_source_test_helper.create_test_schema_if_not_exists()

    # Build the fully qualified table name
    my_table_name = sql_dialect.qualify_dataset_name(dataset_prefix, TABLE_NAME)

    # Drop the table if it exists
    drop_table_sql = sql_dialect.build_drop_table_sql(DROP_TABLE_IF_EXISTS(fully_qualified_table_name=my_table_name))
    data_source_impl.execute_update(drop_table_sql)

    # Create the columns
    create_table_columns = [
        CREATE_TABLE_COLUMN(
            name="School_Year", type=COLUMN_TO_DATA_TYPE_MAPPING["School_Year"], length=255, nullable=True
        ),
        CREATE_TABLE_COLUMN(name="Busbreakdown_ID", type=COLUMN_TO_DATA_TYPE_MAPPING["Busbreakdown_ID"], nullable=True),
        CREATE_TABLE_COLUMN(name="Run_Type", type=COLUMN_TO_DATA_TYPE_MAPPING["Run_Type"], length=255, nullable=True),
        CREATE_TABLE_COLUMN(name="Bus_No", type=COLUMN_TO_DATA_TYPE_MAPPING["Bus_No"], length=255, nullable=True),
        CREATE_TABLE_COLUMN(
            name="Route_Number", type=COLUMN_TO_DATA_TYPE_MAPPING["Route_Number"], length=255, nullable=True
        ),
        CREATE_TABLE_COLUMN(name="Reason", type=COLUMN_TO_DATA_TYPE_MAPPING["Reason"], length=255, nullable=True),
        CREATE_TABLE_COLUMN(
            name="Schools_Serviced", type=COLUMN_TO_DATA_TYPE_MAPPING["Schools_Serviced"], length=255, nullable=True
        ),
        CREATE_TABLE_COLUMN(name="Occurred_On", type=COLUMN_TO_DATA_TYPE_MAPPING["Occurred_On"], nullable=True),
        CREATE_TABLE_COLUMN(name="Created_On", type=COLUMN_TO_DATA_TYPE_MAPPING["Created_On"], nullable=True),
        CREATE_TABLE_COLUMN(name="Boro", type=COLUMN_TO_DATA_TYPE_MAPPING["Boro"], length=255, nullable=True),
        CREATE_TABLE_COLUMN(
            name="Bus_Company_Name", type=COLUMN_TO_DATA_TYPE_MAPPING["Bus_Company_Name"], length=255, nullable=True
        ),
        CREATE_TABLE_COLUMN(
            name="How_Long_Delayed", type=COLUMN_TO_DATA_TYPE_MAPPING["How_Long_Delayed"], length=255, nullable=True
        ),
        CREATE_TABLE_COLUMN(
            name="Number_Of_Students_On_The_Bus",
            type=COLUMN_TO_DATA_TYPE_MAPPING["Number_Of_Students_On_The_Bus"],
            nullable=True,
        ),
        CREATE_TABLE_COLUMN(
            name="Has_Contractor_Notified_Schools",
            type=COLUMN_TO_DATA_TYPE_MAPPING["Has_Contractor_Notified_Schools"],
            length=255,
            nullable=True,
        ),
        CREATE_TABLE_COLUMN(
            name="Has_Contractor_Notified_Parents",
            type=COLUMN_TO_DATA_TYPE_MAPPING["Has_Contractor_Notified_Parents"],
            length=255,
            nullable=True,
        ),
        CREATE_TABLE_COLUMN(
            name="Have_You_Alerted_OPT",
            type=COLUMN_TO_DATA_TYPE_MAPPING["Have_You_Alerted_OPT"],
            length=255,
            nullable=True,
        ),
        CREATE_TABLE_COLUMN(name="Informed_On", type=COLUMN_TO_DATA_TYPE_MAPPING["Informed_On"], nullable=True),
        CREATE_TABLE_COLUMN(
            name="Incident_Number", type=COLUMN_TO_DATA_TYPE_MAPPING["Incident_Number"], length=255, nullable=True
        ),
        CREATE_TABLE_COLUMN(name="Last_Updated_On", type=COLUMN_TO_DATA_TYPE_MAPPING["Last_Updated_On"], nullable=True),
        CREATE_TABLE_COLUMN(
            name="Breakdown_or_Running_Late",
            type=COLUMN_TO_DATA_TYPE_MAPPING["Breakdown_or_Running_Late"],
            length=255,
            nullable=True,
        ),
        CREATE_TABLE_COLUMN(
            name="School_Age_or_PreK", type=COLUMN_TO_DATA_TYPE_MAPPING["School_Age_or_PreK"], length=255, nullable=True
        ),
    ]

    standard_columns = [column.convert_to_standard_column() for column in create_table_columns]

    # First create the table
    create_table_sql = sql_dialect.build_create_table_sql(
        CREATE_TABLE_IF_NOT_EXISTS(
            fully_qualified_table_name=my_table_name,
            columns=create_table_columns,
        )
    )
    data_source_impl.execute_update(create_table_sql)

    # Read the csv file into a pandas dataframe
    logger.info("Reading the csv file into a pandas dataframe")
    df = pd.read_csv(CSV_FILE_LOCATION, index_col=False)
    # Convert the dataframe to a list of values rows.
    # We could speed this up with pandarallel if needed.
    logger.info("Converting the dataframe to a list of values rows")
    values_rows = df.apply(lambda x: convert_to_values_row(x), axis=1)
    logger.info(f"Number of values rows: {len(values_rows)}")

    # Then insert into the table.
    # We do this in batches; for some databases we get errors if we do everything at once.
    batch_size = BATCH_SIZES[data_source_type]
    number_of_batches = (len(values_rows) + batch_size - 1) // batch_size  # Round up so a partial final batch is counted
    for i in range(0, len(values_rows), batch_size):
        batch_values_rows = values_rows[i : i + batch_size].tolist()  # Plain list, not a pandas Series slice
        insert_into_sql = sql_dialect.build_insert_sql(
            INSERT_INTO(
                fully_qualified_table_name=my_table_name,
                values=batch_values_rows,
                columns=standard_columns,
            )
        )
        logger.info(f"Executing the insert into sql for batch {i // batch_size + 1} of {number_of_batches}")
        data_source_impl.execute_update(insert_into_sql)
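
    # Each batch produces SQL roughly of this shape (a sketch; quoting and literal
    # formatting are dialect-specific, and the values shown are placeholders):
    #   INSERT INTO "observability_testing"."bus_breakdown" ("School_Year", ..., "School_Age_or_PreK") VALUES
    #     (<School_Year>, <Busbreakdown_ID>, ...),
    #     ... up to batch_size rows ...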

    # Build a select count(*) query to verify the number of rows inserted
    select_star_query = sql_dialect.build_select_sql(
        [SELECT(COUNT(STAR())), FROM(my_table_name[1:-1])]  # Strip the outer quotes, as the table name will be quoted again
    )
    select_star_result: QueryResult = data_source_impl.execute_query(select_star_query)
    logger.info("Verifying that the number of rows inserted is correct")
    logger.info(f"Select star result: {select_star_result}")
    assert select_star_result.rows[0][0] == len(values_rows)
    assert select_star_result.rows[0][0] == len(df)
    logger.info("Successfully uploaded the bus breakdown dataset to the database")
