Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

revert databricks information_schema #782

Merged
merged 1 commit into from
Nov 22, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 30 additions & 39 deletions data_diff/databases/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,47 +139,38 @@ def create_connection(self):
raise ConnectionError(*e.args) from e

def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
    """Return the table schema as ``{column_name: (name, type_name, decimal_digits, None, None)}``.

    Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for
    Databricks SQL, so the information_schema query can fail there:
    https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html
    When it does, fall back to the cursor.columns() metadata call.

    Raises:
        RuntimeError: if the table does not exist or has no columns.
    """
    try:
        table_schema = super().query_table_schema(path)
    except Exception:
        # Was a bare `except:`, which would also swallow KeyboardInterrupt/SystemExit.
        logging.warning("Failed to get schema from information_schema, falling back to legacy approach.")
        table_schema = {}

    if table_schema:
        return table_schema

    # Legacy approach. It can lose type parameters, e.g. VARCHAR(255) is
    # reported without its length instead of the expected plain VARCHAR.
    catalog, schema, table = self._normalize_table_path(path)
    # Open the connection only when the fallback is actually needed; the
    # original opened it up front and leaked it on the success path.
    conn = self.create_connection()
    try:
        with conn.cursor() as cursor:
            cursor.columns(catalog_name=catalog, schema_name=schema, table_name=table)
            rows = cursor.fetchall()
    finally:
        conn.close()

    if not rows:
        raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")

    table_schema = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows}
    # One dict entry per fetched row — duplicate column names would be a driver bug.
    assert len(table_schema) == len(rows)
    return table_schema

def select_table_schema(self, path: DbPath) -> str:
    """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)"""
    database, schema, name = self._normalize_table_path(path)
    # Qualify information_schema.columns with the catalog only when one was given.
    source_parts = ([database] if database else []) + ["information_schema", "columns"]
    source = ".".join(source_parts)
    return (
        "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
        f"FROM {source} "
        f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
    )
# NOTE(review): these are the added diff lines of the reverted query_table_schema
# body; the `def` line is shared diff context above and is not repeated here.
# Resolve the 3-part path into catalog / schema / table components.
catalog, schema, table = self._normalize_table_path(path)
with conn.cursor() as cursor:
# Ask the driver for column metadata instead of querying information_schema
# (not available on Databricks SQL).
cursor.columns(catalog_name=catalog, schema_name=schema, table_name=table)
try:
rows = cursor.fetchall()
finally:
# Always release the connection, even if fetchall() raises.
conn.close()
if not rows:
# No metadata rows means the table is missing (or genuinely has no columns).
raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")

# Map column name -> (name, type_name, decimal_digits, None, None);
# the trailing Nones pad the tuple to the shape the information_schema path returns.
d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows}
# One entry per row — duplicate column names would be a driver bug.
assert len(d) == len(rows)
return d

# def select_table_schema(self, path: DbPath) -> str:
# """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)"""
# database, schema, name = self._normalize_table_path(path)
# info_schema_path = ["information_schema", "columns"]
# if database:
# info_schema_path.insert(0, database)

# return (
# "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
# f"FROM {'.'.join(info_schema_path)} "
# f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
# )

def _process_table_schema(
self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None
Expand Down