import datetime
import logging

import pytest

try:
    # pandas is an optional test dependency; it is only needed by the (skipped) upload test below.
    import pandas as pd
except ImportError:
    pass
from helpers.data_source_test_helper import DataSourceTestHelper
from soda_core.common.data_source_impl import DataSourceImpl
from soda_core.common.data_source_results import QueryResult
from soda_core.common.logging_constants import soda_logger
from soda_core.common.sql_ast import (
    COUNT,
    CREATE_TABLE_COLUMN,
    CREATE_TABLE_IF_NOT_EXISTS,
    DROP_TABLE_IF_EXISTS,
    FROM,
    INSERT_INTO,
    LITERAL,
    SELECT,
    STAR,
    VALUES_ROW,
    DBDataType,
)
from soda_core.common.sql_dialect import SqlDialect

logger: logging.Logger = soda_logger

CSV_FILE_LOCATION = "Bus_Breakdown_and_Delays.csv"
SCHEMA_NAME = "observability_testing"
TABLE_NAME = "bus_breakdown"
BATCH_SIZES = {  # Depends on the database: not all databases support very large batch inserts.
    "postgres": 10000,
    "sqlserver": 1,  # SQL Server applies automatic conversions that fail when inserting multiple rows at once.
    "mysql": 10000,
    "bigquery": 2500,  # BigQuery limits the query size to 1024KB, so we need a smaller batch size. This takes a while to run!
    "snowflake": 10000,
    "oracle": 10000,
}
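
# For a sense of scale (hypothetical row count, not the real CSV size): a 150,000-row file
# on bigquery at batch size 2500 would be split into 150000 / 2500 = 60 INSERT statements.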

# Map the columns to data types
COLUMN_TO_DATA_TYPE_MAPPING: dict[str, DBDataType] = {
    "School_Year": DBDataType.TEXT,
    "Busbreakdown_ID": DBDataType.INTEGER,
    "Run_Type": DBDataType.TEXT,
    "Bus_No": DBDataType.TEXT,
    "Route_Number": DBDataType.TEXT,
    "Reason": DBDataType.TEXT,
    "Schools_Serviced": DBDataType.TEXT,
    "Occurred_On": DBDataType.TIMESTAMP,
    "Created_On": DBDataType.TIMESTAMP,
    "Boro": DBDataType.TEXT,
    "Bus_Company_Name": DBDataType.TEXT,
    "How_Long_Delayed": DBDataType.TEXT,
    "Number_Of_Students_On_The_Bus": DBDataType.INTEGER,
    "Has_Contractor_Notified_Schools": DBDataType.TEXT,
    "Has_Contractor_Notified_Parents": DBDataType.TEXT,
    "Have_You_Alerted_OPT": DBDataType.TEXT,
    "Informed_On": DBDataType.TIMESTAMP,
    "Incident_Number": DBDataType.TEXT,
    "Last_Updated_On": DBDataType.TIMESTAMP,
    "Breakdown_or_Running_Late": DBDataType.TEXT,
    "School_Age_or_PreK": DBDataType.TEXT,
}

TIMESTAMP_COLUMNS = ["Occurred_On", "Created_On", "Informed_On", "Last_Updated_On"]
INTEGER_COLUMNS = ["Busbreakdown_ID", "Number_Of_Students_On_The_Bus"]


def convert_timestamp_to_datetime(timestamp: str) -> datetime.datetime:
    # Timestamps in the CSV look like "01/01/2021 10:00:00 AM" (month/day/year)
    return datetime.datetime.strptime(timestamp, "%m/%d/%Y %I:%M:%S %p")
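
# Illustrative example of the conversion above (hypothetical input string):
#   convert_timestamp_to_datetime("09/02/2015 07:45:00 AM") -> datetime.datetime(2015, 9, 2, 7, 45)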


def convert_to_values_row(row) -> VALUES_ROW:
    result_list: list[LITERAL] = []

    # First we extract all the values
    all_values = {
        "School_Year": row["School_Year"],
        "Busbreakdown_ID": row["Busbreakdown_ID"],  # Converted to int below, after the NaN check
        "Run_Type": row["Run_Type"],
        "Bus_No": row["Bus_No"],
        "Route_Number": row["Route_Number"],
        "Reason": row["Reason"],
        "Schools_Serviced": row["Schools_Serviced"],
        "Occurred_On": row["Occurred_On"],
        "Created_On": row["Created_On"],
        "Boro": row["Boro"],
        "Bus_Company_Name": row["Bus_Company_Name"],
        "How_Long_Delayed": row["How_Long_Delayed"],
        "Number_Of_Students_On_The_Bus": row["Number_Of_Students_On_The_Bus"],
        "Has_Contractor_Notified_Schools": row["Has_Contractor_Notified_Schools"],
        "Has_Contractor_Notified_Parents": row["Has_Contractor_Notified_Parents"],
        "Have_You_Alerted_OPT": row["Have_You_Alerted_OPT"],
        "Informed_On": row["Informed_On"],
        "Incident_Number": row["Incident_Number"],
        "Last_Updated_On": row["Last_Updated_On"],
        "Breakdown_or_Running_Late": row["Breakdown_or_Running_Late"],
        "School_Age_or_PreK": row["School_Age_or_PreK"],
    }
    # Then we convert the values to literals (note: order must be maintained!)
    for key, value in all_values.items():
        # Check for NaN first: pandas converts empty CSV fields to NaN, which databases cannot handle,
        # so we turn those into NULL.
        if pd.isnull(value):
            result_list.append(LITERAL(None))
        else:
            if key in TIMESTAMP_COLUMNS:
                value = convert_timestamp_to_datetime(value)
            elif key in INTEGER_COLUMNS:
                value = int(value)
            else:
                value = str(value)
            # Make sure the value has the expected Python type; this occasionally catches conversion errors.
            if COLUMN_TO_DATA_TYPE_MAPPING[key] == DBDataType.TIMESTAMP:
                assert isinstance(value, datetime.datetime)
            elif COLUMN_TO_DATA_TYPE_MAPPING[key] == DBDataType.INTEGER:
                assert isinstance(value, int)
            elif COLUMN_TO_DATA_TYPE_MAPPING[key] == DBDataType.TEXT:
                assert isinstance(value, str)
            else:
                raise ValueError(f"Unknown column type: {COLUMN_TO_DATA_TYPE_MAPPING[key]}")
            result_list.append(LITERAL(value))

    return VALUES_ROW(result_list)
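
# Rough usage sketch (hypothetical sample values, not taken from the real CSV):
#   convert_to_values_row(df.iloc[0])
#   -> VALUES_ROW([LITERAL("2015-2016"), LITERAL(1227538), LITERAL("Special Ed AM Run"), ...])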


@pytest.mark.skip(
    reason="This test is a hack to upload the bus breakdown dataset to the test database. It should not be considered part of the test suite."
)
def test_full_create_insert_drop_ast(data_source_test_helper: DataSourceTestHelper):
    """
    This is a very hacky way to upload a dataset (specifically the bus breakdown dataset) to a database.
    It is the quickest approach, as we already have the connection, SQL dialect, etc. for each data source.
    You will see some hacks in this code, such as manually setting the dataset_prefix to the schema name so that we can reuse the existing test helper.
    If you have the time, feel free to refactor this :).
    Bus dataset downloaded from: https://catalog.data.gov/dataset/bus-breakdown-and-delays

    Note: this test requires pandas to be installed!
    """

    data_source_impl: DataSourceImpl = data_source_test_helper.data_source_impl
    data_source_type: str = data_source_impl.type_name
    sql_dialect: SqlDialect = data_source_impl.sql_dialect
    dataset_prefix = data_source_test_helper.dataset_prefix

    # Create the schema
    dataset_prefix[sql_dialect.get_schema_prefix_index()] = SCHEMA_NAME
    data_source_test_helper.dataset_prefix = dataset_prefix
    data_source_test_helper.create_test_schema_if_not_exists()
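
    # Assumption about the helper internals: dataset_prefix is a list such as [database, schema],
    # so overwriting the schema slot makes the helper create SCHEMA_NAME instead of its default test schema.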

    # Build the fully qualified table name
    my_table_name = sql_dialect.qualify_dataset_name(dataset_prefix, TABLE_NAME)

    # Drop the table if it already exists
    drop_table_sql = sql_dialect.build_drop_table_sql(DROP_TABLE_IF_EXISTS(fully_qualified_table_name=my_table_name))
    data_source_impl.execute_update(drop_table_sql)
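
    # For postgres, the generated SQL would look roughly like (assumed rendering):
    #   DROP TABLE IF EXISTS "observability_testing"."bus_breakdown"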

    # Create the columns
    create_table_columns = [
        CREATE_TABLE_COLUMN(
            name="School_Year", type=COLUMN_TO_DATA_TYPE_MAPPING["School_Year"], length=255, nullable=True
        ),
        CREATE_TABLE_COLUMN(name="Busbreakdown_ID", type=COLUMN_TO_DATA_TYPE_MAPPING["Busbreakdown_ID"], nullable=True),
        CREATE_TABLE_COLUMN(name="Run_Type", type=COLUMN_TO_DATA_TYPE_MAPPING["Run_Type"], length=255, nullable=True),
        CREATE_TABLE_COLUMN(name="Bus_No", type=COLUMN_TO_DATA_TYPE_MAPPING["Bus_No"], length=255, nullable=True),
        CREATE_TABLE_COLUMN(
            name="Route_Number", type=COLUMN_TO_DATA_TYPE_MAPPING["Route_Number"], length=255, nullable=True
        ),
        CREATE_TABLE_COLUMN(name="Reason", type=COLUMN_TO_DATA_TYPE_MAPPING["Reason"], length=255, nullable=True),
        CREATE_TABLE_COLUMN(
            name="Schools_Serviced", type=COLUMN_TO_DATA_TYPE_MAPPING["Schools_Serviced"], length=255, nullable=True
        ),
        CREATE_TABLE_COLUMN(name="Occurred_On", type=COLUMN_TO_DATA_TYPE_MAPPING["Occurred_On"], nullable=True),
        CREATE_TABLE_COLUMN(name="Created_On", type=COLUMN_TO_DATA_TYPE_MAPPING["Created_On"], nullable=True),
        CREATE_TABLE_COLUMN(name="Boro", type=COLUMN_TO_DATA_TYPE_MAPPING["Boro"], length=255, nullable=True),
        CREATE_TABLE_COLUMN(
            name="Bus_Company_Name", type=COLUMN_TO_DATA_TYPE_MAPPING["Bus_Company_Name"], length=255, nullable=True
        ),
        CREATE_TABLE_COLUMN(
            name="How_Long_Delayed", type=COLUMN_TO_DATA_TYPE_MAPPING["How_Long_Delayed"], length=255, nullable=True
        ),
        CREATE_TABLE_COLUMN(
            name="Number_Of_Students_On_The_Bus",
            type=COLUMN_TO_DATA_TYPE_MAPPING["Number_Of_Students_On_The_Bus"],
            nullable=True,
        ),
        CREATE_TABLE_COLUMN(
            name="Has_Contractor_Notified_Schools",
            type=COLUMN_TO_DATA_TYPE_MAPPING["Has_Contractor_Notified_Schools"],
            length=255,
            nullable=True,
        ),
        CREATE_TABLE_COLUMN(
            name="Has_Contractor_Notified_Parents",
            type=COLUMN_TO_DATA_TYPE_MAPPING["Has_Contractor_Notified_Parents"],
            length=255,
            nullable=True,
        ),
        CREATE_TABLE_COLUMN(
            name="Have_You_Alerted_OPT",
            type=COLUMN_TO_DATA_TYPE_MAPPING["Have_You_Alerted_OPT"],
            length=255,
            nullable=True,
        ),
        CREATE_TABLE_COLUMN(name="Informed_On", type=COLUMN_TO_DATA_TYPE_MAPPING["Informed_On"], nullable=True),
        CREATE_TABLE_COLUMN(
            name="Incident_Number", type=COLUMN_TO_DATA_TYPE_MAPPING["Incident_Number"], length=255, nullable=True
        ),
        CREATE_TABLE_COLUMN(name="Last_Updated_On", type=COLUMN_TO_DATA_TYPE_MAPPING["Last_Updated_On"], nullable=True),
        CREATE_TABLE_COLUMN(
            name="Breakdown_or_Running_Late",
            type=COLUMN_TO_DATA_TYPE_MAPPING["Breakdown_or_Running_Late"],
            length=255,
            nullable=True,
        ),
        CREATE_TABLE_COLUMN(
            name="School_Age_or_PreK", type=COLUMN_TO_DATA_TYPE_MAPPING["School_Age_or_PreK"], length=255, nullable=True
        ),
    ]
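
    # Note: length=255 is only passed for TEXT columns; INTEGER and TIMESTAMP columns rely on the dialect defaults.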

    standard_columns = [column.convert_to_standard_column() for column in create_table_columns]

    # Create the table
    create_table_sql = sql_dialect.build_create_table_sql(
        CREATE_TABLE_IF_NOT_EXISTS(
            fully_qualified_table_name=my_table_name,
            columns=create_table_columns,
        )
    )
    data_source_impl.execute_update(create_table_sql)
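
    # For postgres, the generated statement would look roughly like (assumed rendering, with TEXT
    # columns of length 255 presumably mapped to VARCHAR(255)):
    #   CREATE TABLE IF NOT EXISTS "observability_testing"."bus_breakdown" ("School_Year" VARCHAR(255), ...)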
    # Read the csv file into a pandas dataframe
    logger.info("Reading the csv file into a pandas dataframe")
    df = pd.read_csv(CSV_FILE_LOCATION, index_col=False)
    # Convert the dataframe to a list of values rows.
    # We could speed this up with pandarallel if needed.
    logger.info("Converting the dataframe to a list of values rows")
    values_rows = df.apply(convert_to_values_row, axis=1)
    logger.info(f"Number of values rows: {len(values_rows)}")

    # Then insert into the table.
    # We do this in batches: for some databases we get errors if we insert everything at once.
    batch_size = BATCH_SIZES[data_source_type]
    total_batches = (len(values_rows) + batch_size - 1) // batch_size  # Ceiling division, so a partial final batch is counted
    for i in range(0, len(values_rows), batch_size):
        batch_values_rows = values_rows[i : i + batch_size]
        insert_into_sql = sql_dialect.build_insert_into_sql(
            INSERT_INTO(
                fully_qualified_table_name=my_table_name,
                values=batch_values_rows,
                columns=standard_columns,
            )
        )
        logger.info(f"Executing the insert into sql for batch {i // batch_size + 1} of {total_batches}")
        data_source_impl.execute_update(insert_into_sql)
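
    # Each batch renders to a single multi-row statement, roughly of this shape (assumed rendering;
    # quoting and literal formatting are dialect-specific):
    #   INSERT INTO <qualified table> (<columns>) VALUES (...), (...), ...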

    # Build a SELECT COUNT(*) query to verify the number of rows inserted
    select_star_query = sql_dialect.build_select_sql(
        [SELECT(COUNT(STAR())), FROM(my_table_name[1:-1])]  # Strip the outer quotes, as the table name will be quoted again
    )
    select_star_result: QueryResult = data_source_impl.execute_query(select_star_query)
    logger.info("Verifying that the number of rows inserted is correct")
    logger.info(f"Select star result: {select_star_result}")
    assert select_star_result.rows[0][0] == len(values_rows)
    assert select_star_result.rows[0][0] == len(df)
    logger.info("Successfully uploaded the bus breakdown dataset to the database")