Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Adds support for Numeric types with arbitrary precision #74

Merged
merged 12 commits into from
Jun 21, 2022
Merged
189 changes: 175 additions & 14 deletions data_diff/database.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from functools import lru_cache
from itertools import zip_longest
import re
Expand Down Expand Up @@ -63,6 +64,7 @@ def import_presto():
class ConnectError(Exception):
pass


class QueryError(Exception):
pass

Expand Down Expand Up @@ -105,6 +107,26 @@ class Datetime(TemporalType):
pass


@dataclass
class NumericType(ColType):
# 'precision' signifies how many fractional digits (after the dot) we want to compare
precision: int


class Float(NumericType):
pass


class Decimal(NumericType):
pass


@dataclass
class Integer(Decimal):
def __post_init__(self):
assert self.precision == 0


@dataclass
class UnknownColType(ColType):
text: str
Expand Down Expand Up @@ -162,6 +184,19 @@ def normalize_value_by_type(value: str, coltype: ColType) -> str:

Rounded up/down according to coltype.rounds

- Floats/Decimals are expected in the format
"I.P"

Where I is the integer part of the number (as many digits as necessary),
and must be at least one digit (0).
P is the fractional digits, the amount of which is specified with
coltype.precision. Trailing zeroes may be necessary.

Note: This precision is different than the one used by databases. For decimals,
it's the same as "numeric_scale", and for floats, who use binary precision,
it can be calculated as log10(2**p)


"""
...

Expand Down Expand Up @@ -212,23 +247,48 @@ def query(self, sql_ast: SqlOrStr, res_type: type):
def enable_interactive(self):
self._interactive = True

def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType:
def _convert_db_precision_to_digits(self, p: int) -> int:
"""Convert from binary precision, used by floats, to decimal precision."""
# See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
return math.floor(math.log(2**p, 10))

def _parse_type(
self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None
) -> ColType:
""" """

cls = self.DATETIME_TYPES.get(type_repr)
if cls:
return cls(
precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION,
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
rounds=self.ROUNDS_ON_PREC_LOSS,
)

cls = self.NUMERIC_TYPES.get(type_repr)
if cls:
if issubclass(cls, Integer):
# Some DBs have a constant numeric_scale, so they don't report it.
# We fill in the constant, so we need to ignore it for integers.
return cls(precision=0)

elif issubclass(cls, Decimal):
return cls(precision=numeric_scale)

assert issubclass(cls, Float)
# assert numeric_scale is None
return cls(
precision=self._convert_db_precision_to_digits(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For Postgres, numeric_precision in INFORMATION_SCHEMA.columns is 24, not 23 as I'd normally expect for a 4-byte float. Do you know why?

This causes a potential problem... For Postgres at least, the precision this function gives, and the documentation, aren't congruent for 4-byte, but is for 8-byte:

math.floor(math.log(2**24, 10))
# => 7

>>> math.floor(math.log(2**53, 10))
15

CleanShot 2022-06-21 at 09 49 21@2x

https://www.postgresql.org/docs/current/datatype-numeric.html

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No idea why.. Double still gives 53..

numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is 4-bytes float the safe default? Because even if it's 8-bytes, it is "at least" 4-bytes?

Isn't this only a safe default if it's a float4 and numeric_precision isn't in the schema? This seems like code that should be in the driver it's relevant for, not the abstract class, no?

Should we maybe introduce Double in addition, so they can be typed out in NUMERIC_TYPES for each class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

not the abstract class

See the refactor PR, it addresses this concern.

Should we maybe introduce Double in addition

It isn't necessary for most DBs, because they provide numeric_scale.

For those that don't, yes, it can give us more accuracy. It just doesn't seem like high priority.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍🏻 I think it's worth a comment why this is a safe default, but yeah, I agree with you

)
)

return UnknownColType(type_repr)

def select_table_schema(self, path: DbPath) -> str:
schema, table = self._normalize_table_path(path)

return (
"SELECT column_name, data_type, datetime_precision, numeric_precision FROM information_schema.columns "
"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM information_schema.columns "
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
)

Expand All @@ -250,7 +310,9 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
elif len(path) == 2:
return path

raise ValueError(f"{self.__class__.__name__}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table")
raise ValueError(
f"{self.__class__.__name__}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table"
)

def parse_table_name(self, name: str) -> DbPath:
return parse_table_name(name)
Expand Down Expand Up @@ -295,7 +357,8 @@ def close(self):
_CHECKSUM_BITSIZE = CHECKSUM_HEXDIGITS << 2
CHECKSUM_MASK = (2**_CHECKSUM_BITSIZE) - 1

DEFAULT_PRECISION = 6
DEFAULT_DATETIME_PRECISION = 6
DEFAULT_NUMERIC_PRECISION = 24

TIMESTAMP_PRECISION_POS = 20 # len("2022-06-03 12:24:35.") == 20

Expand All @@ -307,6 +370,13 @@ class Postgres(ThreadedDatabase):
"timestamp": Timestamp,
# "datetime": Datetime,
}
NUMERIC_TYPES = {
"double precision": Float,
"real": Float,
"decimal": Decimal,
"integer": Integer,
"numeric": Decimal,
}
ROUNDS_ON_PREC_LOSS = True

default_schema = "public"
Expand Down Expand Up @@ -351,6 +421,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"

elif isinstance(coltype, NumericType):
value = f"{value}::decimal(38, {coltype.precision})"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can you assume 38 digits in front of the .?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's the maximum.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment would be sweet.


return self.to_string(f"{value}")


Expand All @@ -362,6 +435,10 @@ class Presto(Database):
"timestamp": Timestamp,
# "datetime": Datetime,
}
NUMERIC_TYPES = {
"integer": Integer,
"real": Float,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you're missing double-precision here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

}
ROUNDS_ON_PREC_LOSS = True

def __init__(self, host, port, user, password, *, catalog, schema=None, **kw):
Expand Down Expand Up @@ -401,6 +478,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')"
)

elif isinstance(coltype, NumericType):
value = f"cast({value} as decimal(38,{coltype.precision}))"

return self.to_string(value)

def select_table_schema(self, path: DbPath) -> str:
Expand All @@ -422,8 +502,24 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr
if m:
datetime_precision = int(m.group(1))
return cls(
precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION, rounds=False
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
rounds=False,
)

cls = self.NUMERIC_TYPES.get(type_repr)
if cls:
if issubclass(cls, Integer):
assert numeric_precision is not None
return cls(0)
elif issubclass(cls, Decimal):
return cls(6)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this hardcoded for Decimal?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Presto doesn't provide the numeric_scale. 6 is just a safe value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the way it's handled.


assert issubclass(cls, Float)
return cls(
precision=self._convert_db_precision_to_digits(
numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
)
)

return UnknownColType(type_repr)

Expand All @@ -433,6 +529,12 @@ class MySQL(ThreadedDatabase):
"datetime": Datetime,
"timestamp": Timestamp,
}
NUMERIC_TYPES = {
"double": Float,
"float": Float,
"decimal": Decimal,
"int": Decimal,
}
ROUNDS_ON_PREC_LOSS = True

def __init__(self, host, port, user, password, *, database, thread_count, **kw):
Expand Down Expand Up @@ -472,6 +574,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
s = self.to_string(f"cast({value} as datetime(6))")
return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')"

elif isinstance(coltype, NumericType):
value = f"cast({value} as decimal(38,{coltype.precision}))"

return self.to_string(f"{value}")


Expand Down Expand Up @@ -513,16 +618,24 @@ def select_table_schema(self, path: DbPath) -> str:
(table,) = path

return (
f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision"
f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale"
f" FROM USER_TAB_COLUMNS WHERE table_name = '{table.upper()}'"
)

def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
if isinstance(coltype, PrecisionType):
if isinstance(coltype, TemporalType):
return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')"
elif isinstance(coltype, NumericType):
# FM999.9990
format_str = "FM" + "9" * (38 - coltype.precision)
if coltype.precision:
format_str += "0." + "9" * (coltype.precision - 1) + "0"
return f"to_char({value}, '{format_str}')"
return self.to_string(f"{value}")

def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType:
def _parse_type(
self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None
) -> ColType:
""" """
regexps = {
r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp,
Expand All @@ -532,14 +645,40 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr
m = re.match(regexp + "$", type_repr)
if m:
datetime_precision = int(m.group(1))
return cls(precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION,
rounds=self.ROUNDS_ON_PREC_LOSS
return cls(
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
rounds=self.ROUNDS_ON_PREC_LOSS,
)

cls = {
"NUMBER": Decimal,
"FLOAT": Float,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about LONG?

}.get(type_repr, None)
if cls:
if issubclass(cls, Decimal):
assert numeric_scale is not None, (type_repr, numeric_precision, numeric_scale)
return cls(precision=numeric_scale)

assert issubclass(cls, Float)
return cls(
precision=self._convert_db_precision_to_digits(
numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION
)
)

return UnknownColType(type_repr)


class Redshift(Postgres):
NUMERIC_TYPES = {
**Postgres.NUMERIC_TYPES,
"double": Float,
"real": Float,
}

def _convert_db_precision_to_digits(self, p: int) -> int:
return super()._convert_db_precision_to_digits(p // 2)

def md5_to_int(self, s: str) -> str:
return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)"

Expand All @@ -560,6 +699,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"

elif isinstance(coltype, NumericType):
value = f"{value}::decimal(38,{coltype.precision})"

return self.to_string(f"{value}")


Expand Down Expand Up @@ -595,6 +737,14 @@ class BigQuery(Database):
"TIMESTAMP": Timestamp,
"DATETIME": Datetime,
}
NUMERIC_TYPES = {
"INT64": Integer,
"INT32": Integer,
"NUMERIC": Decimal,
"BIGNUMERIC": Decimal,
"FLOAT64": Float,
"FLOAT32": Float,
}
ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation

def __init__(self, project, *, dataset, **kw):
Expand Down Expand Up @@ -640,12 +790,12 @@ def select_table_schema(self, path: DbPath) -> str:
schema, table = self._normalize_table_path(path)

return (
f"SELECT column_name, data_type, 6 as datetime_precision, 6 as numeric_precision FROM {schema}.INFORMATION_SCHEMA.COLUMNS "
f"SELECT column_name, data_type, 6 as datetime_precision, 38 as numeric_precision, 9 as numeric_scale FROM {schema}.INFORMATION_SCHEMA.COLUMNS "
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
)

def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
if isinstance(coltype, PrecisionType):
if isinstance(coltype, TemporalType):
if coltype.rounds:
timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))"
return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})"
Expand All @@ -658,6 +808,10 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})"
return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"

elif isinstance(coltype, NumericType):
# value = f"cast({value} as decimal)"
return f"format('%.{coltype.precision}f', cast({value} as decimal))"

return self.to_string(f"{value}")

def parse_table_name(self, name: str) -> DbPath:
Expand All @@ -671,6 +825,10 @@ class Snowflake(Database):
"TIMESTAMP_LTZ": Timestamp,
"TIMESTAMP_TZ": TimestampTZ,
}
NUMERIC_TYPES = {
"NUMBER": Decimal,
"FLOAT": Float,
}
ROUNDS_ON_PREC_LOSS = False

def __init__(
Expand Down Expand Up @@ -729,14 +887,17 @@ def select_table_schema(self, path: DbPath) -> str:
return super().select_table_schema((schema, table))

def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
if isinstance(coltype, PrecisionType):
if isinstance(coltype, TemporalType):
if coltype.rounds:
timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))"
else:
timestamp = f"cast({value} as timestamp({coltype.precision}))"

return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')"

elif isinstance(coltype, NumericType):
value = f"cast({value} as decimal(38, {coltype.precision}))"

return self.to_string(f"{value}")


Expand Down
Loading