Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Move common ABCs and types to database_types.py; Fix type annotations #98

Merged
merged 2 commits into from
Jun 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 26 additions & 172 deletions data_diff/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@
from itertools import zip_longest
import re
from abc import ABC, abstractmethod
from runtype import dataclass
import logging
from typing import Sequence, Tuple, Optional, List
from typing import Sequence, Tuple, Optional, List, Type
from concurrent.futures import ThreadPoolExecutor
import threading
from typing import Dict

import dsnparse
import sys

from runtype import dataclass

from .sql import DbPath, SqlOrStr, Compiler, Explain, Select
from .database_types import *


logger = logging.getLogger("database")
Expand Down Expand Up @@ -109,149 +110,6 @@ def _query_conn(conn, sql_code: str) -> list:
return c.fetchall()


class ColType:
    """Root of the column-type hierarchy; it declares no fields of its own."""


@dataclass
class PrecisionType(ColType):
    """Column type described by an optional precision and a rounding flag."""

    # Precision of the column; None is permitted by the annotation.
    precision: Optional[int]
    # Rounding flag read by the normalize_timestamp() implementations.
    # NOTE(review): presumably True means the database rounds (rather than
    # truncates) to `precision` -- confirm against the drivers.
    rounds: bool


class TemporalType(PrecisionType):
    """Base for date/time column types; normalized via normalize_timestamp()."""


class Timestamp(TemporalType):
    """Temporal column type for timestamps; adds nothing to TemporalType."""


class TimestampTZ(TemporalType):
    """Temporal column type for timestamps with a time zone; adds nothing to TemporalType."""


class Datetime(TemporalType):
    """Temporal column type for datetimes; adds nothing to TemporalType."""


@dataclass
class NumericType(ColType):
    """Base for numeric column types; normalized via normalize_number()."""

    # Number of fractional digits (after the dot) that take part in the
    # comparison.
    precision: int


class Float(NumericType):
    """Numeric column type for floats; adds nothing to NumericType."""


class Decimal(NumericType):
    """Numeric column type for decimals; adds nothing to NumericType."""


@dataclass
class Integer(Decimal):
    """Decimal column type whose precision is required to be 0."""

    def __post_init__(self):
        # An integer has no fractional digits, so any other precision is a
        # programming error.
        assert self.precision == 0


@dataclass
class UnknownColType(ColType):
    """Column type that is neither temporal nor numeric; holds a text description."""

    # NOTE(review): presumably the raw type string reported by the database
    # -- confirm against the schema-parsing code.
    text: str


class AbstractDatabase(ABC):
    """Interface every database driver implements.

    Most methods return SQL text (as ``str``) rather than executing anything;
    ``_query``, ``query_table_schema`` and ``close`` are the ones that talk to
    the database instance.
    """

    @abstractmethod
    def quote(self, s: str):
        """Quote an SQL name (implementation specific)."""

    @abstractmethod
    def to_string(self, s: str) -> str:
        """Provide SQL for casting a column to string."""

    @abstractmethod
    def md5_to_int(self, s: str) -> str:
        """Provide SQL for computing md5 and returning an int."""

    @abstractmethod
    def _query(self, sql_code: str) -> list:
        """Send a query to the database and return the result."""

    @abstractmethod
    def select_table_schema(self, path: DbPath) -> str:
        """Provide SQL for selecting the table schema.

        The selected columns are (name, type, date_prec, num_prec).
        """

    @abstractmethod
    def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
        """Query the schema of the table at 'path' and return {column: type}."""

    @abstractmethod
    def parse_table_name(self, name: str) -> DbPath:
        """Parse the given table name into a DbPath."""

    @abstractmethod
    def close(self):
        """Close connection(s) to the database instance.

        Querying will stop functioning afterwards.
        """

    @abstractmethod
    def normalize_timestamp(self, value: str, coltype: ColType) -> str:
        """Create an SQL expression that converts 'value' to a normalized timestamp.

        The returned expression must accept any SQL datetime/timestamp, and
        return a string.

        Date format: "YYYY-MM-DD HH:mm:SS.FFFFFF"

        Precision of dates should be rounded up/down according to
        coltype.rounds
        """

    @abstractmethod
    def normalize_number(self, value: str, coltype: ColType) -> str:
        """Create an SQL expression that converts 'value' to a normalized number.

        The returned expression must accept any SQL int/numeric/float, and
        return a string.

        Floats/Decimals are expected in the format "I.P", where:

        - I is the integer part of the number (as many digits as necessary),
          and must be at least one digit (0).
        - P is the fractional digits, the amount of which is specified with
          coltype.precision. Trailing zeroes may be necessary.
          If P is 0, the dot is omitted.

        Note: This precision is different than the one used by databases.
        For decimals, it's the same as ``numeric_scale``, and for floats, who
        use binary precision, it can be calculated as
        ``log10(2**numeric_precision)``.
        """

    def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
        """Create an SQL expression that converts 'value' to a normalized representation.

        The returned expression must accept any SQL value, and return a
        string.

        This default implementation picks the method by ``coltype``:

            TemporalType -> normalize_timestamp()
            NumericType  -> normalize_number()
            anything else -> to_string()
        """
        # Temporal columns are checked first, then numeric ones; every other
        # column type falls through to a plain string cast.
        if isinstance(coltype, TemporalType):
            return self.normalize_timestamp(value, coltype)
        if isinstance(coltype, NumericType):
            return self.normalize_number(value, coltype)
        return self.to_string(f"{value}")


class Database(AbstractDatabase):
"""Base abstract class for databases.
Expand All @@ -261,8 +119,8 @@ class Database(AbstractDatabase):
Instantiated using :meth:`~data_diff.connect_to_uri`
"""

DATETIME_TYPES = {}
default_schema = None
DATETIME_TYPES: Dict[str, type] = {}
default_schema: str = None

@property
def name(self):
Expand Down Expand Up @@ -412,9 +270,6 @@ def _query_in_worker(self, sql_code: str):
raise self._init_error
return _query_conn(self.thread_local.conn, sql_code)

def close(self):
self._queue.shutdown(True)

@abstractmethod
def create_connection(self):
...
Expand Down Expand Up @@ -481,7 +336,7 @@ def md5_to_int(self, s: str) -> str:
def to_string(self, s: str):
return f"{s}::varchar"

def normalize_timestamp(self, value: str, coltype: ColType) -> str:
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
if coltype.rounds:
return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')"

Expand All @@ -490,7 +345,7 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
)

def normalize_number(self, value: str, coltype: ColType) -> str:
def normalize_number(self, value: str, coltype: NumericType) -> str:
return self.to_string(f"{value}::decimal(38, {coltype.precision})")


Expand Down Expand Up @@ -531,7 +386,7 @@ def _query(self, sql_code: str) -> list:
def close(self):
self._conn.close()

def normalize_timestamp(self, value: str, coltype: ColType) -> str:
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
# TODO
if coltype.rounds:
s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"
Expand All @@ -540,7 +395,7 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:

return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')"

def normalize_number(self, value: str, coltype: ColType) -> str:
def normalize_number(self, value: str, coltype: NumericType) -> str:
return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))")

def select_table_schema(self, path: DbPath) -> str:
Expand All @@ -554,11 +409,11 @@ def select_table_schema(self, path: DbPath) -> str:
def _parse_type(
self, col_name: str, type_repr: str, datetime_precision: int = None, numeric_precision: int = None
) -> ColType:
regexps = {
timestamp_regexps = {
r"timestamp\((\d)\)": Timestamp,
r"timestamp\((\d)\) with time zone": TimestampTZ,
}
for regexp, cls in regexps.items():
for regexp, cls in timestamp_regexps.items():
m = re.match(regexp + "$", type_repr)
if m:
datetime_precision = int(m.group(1))
Expand All @@ -567,8 +422,8 @@ def _parse_type(
rounds=False,
)

regexps = {r"decimal\((\d+),(\d+)\)": Decimal}
for regexp, cls in regexps.items():
number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal}
for regexp, cls in number_regexps.items():
m = re.match(regexp + "$", type_repr)
if m:
prec, scale = map(int, m.groups())
Expand Down Expand Up @@ -632,14 +487,14 @@ def md5_to_int(self, s: str) -> str:
def to_string(self, s: str):
return f"cast({s} as char)"

def normalize_timestamp(self, value: str, coltype: ColType) -> str:
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
if coltype.rounds:
return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))")

s = self.to_string(f"cast({value} as datetime(6))")
return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')"

def normalize_number(self, value: str, coltype: ColType) -> str:
def normalize_number(self, value: str, coltype: NumericType) -> str:
return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))")


Expand Down Expand Up @@ -685,10 +540,10 @@ def select_table_schema(self, path: DbPath) -> str:
f" FROM USER_TAB_COLUMNS WHERE table_name = '{table.upper()}'"
)

def normalize_timestamp(self, value: str, coltype: ColType) -> str:
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')"

def normalize_number(self, value: str, coltype: ColType) -> str:
def normalize_number(self, value: str, coltype: NumericType) -> str:
# FM999.9990
format_str = "FM" + "9" * (38 - coltype.precision)
if coltype.precision:
Expand Down Expand Up @@ -749,7 +604,7 @@ class Redshift(PostgreSQL):
def md5_to_int(self, s: str) -> str:
return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)"

def normalize_timestamp(self, value: str, coltype: ColType) -> str:
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
if coltype.rounds:
timestamp = f"{value}::timestamp(6)"
# Get seconds since epoch. Redshift doesn't support milli- or micro-seconds.
Expand All @@ -769,7 +624,7 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
)

def normalize_number(self, value: str, coltype: ColType) -> str:
def normalize_number(self, value: str, coltype: NumericType) -> str:
return self.to_string(f"{value}::decimal(38,{coltype.precision})")

def select_table_schema(self, path: DbPath) -> str:
Expand Down Expand Up @@ -870,7 +725,7 @@ def select_table_schema(self, path: DbPath) -> str:
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
)

def normalize_timestamp(self, value: str, coltype: ColType) -> str:
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
if coltype.rounds:
timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})"
Expand All @@ -885,7 +740,7 @@ def normalize_timestamp(self, value: str, coltype: ColType) -> str:
f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"
)

def normalize_number(self, value: str, coltype: ColType) -> str:
def normalize_number(self, value: str, coltype: NumericType) -> str:
if isinstance(coltype, Integer):
return self.to_string(value)
return f"format('%.{coltype.precision}f', {value})"
Expand Down Expand Up @@ -962,21 +817,21 @@ def select_table_schema(self, path: DbPath) -> str:
schema, table = self._normalize_table_path(path)
return super().select_table_schema((schema, table))

def normalize_timestamp(self, value: str, coltype: ColType) -> str:
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
if coltype.rounds:
timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))"
else:
timestamp = f"cast({value} as timestamp({coltype.precision}))"

return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')"

def normalize_number(self, value: str, coltype: ColType) -> str:
def normalize_number(self, value: str, coltype: NumericType) -> str:
return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))")


@dataclass
class MatchUriPath:
database_cls: type
database_cls: Type[Database]
params: List[str]
kwparams: List[str] = []
help_str: str
Expand Down Expand Up @@ -1027,7 +882,7 @@ def match_path(self, dsn):
"postgresql": MatchUriPath(PostgreSQL, ["database?"], help_str="postgresql://<user>:<pass>@<host>/<database>"),
"mysql": MatchUriPath(MySQL, ["database?"], help_str="mysql://<user>:<pass>@<host>/<database>"),
"oracle": MatchUriPath(Oracle, ["database?"], help_str="oracle://<user>:<pass>@<host>/<database>"),
"mssql": MatchUriPath(MsSQL, ["database?"], help_str="mssql://<user>:<pass>@<host>/<database>"),
# "mssql": MatchUriPath(MsSQL, ["database?"], help_str="mssql://<user>:<pass>@<host>/<database>"),
"redshift": MatchUriPath(Redshift, ["database?"], help_str="redshift://<user>:<pass>@<host>/<database>"),
"snowflake": MatchUriPath(
Snowflake,
Expand Down Expand Up @@ -1055,7 +910,6 @@ def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database:
Supported schemes:
- postgresql
- mysql
- mssql
- oracle
- snowflake
- bigquery
Expand Down
Loading