-
Notifications
You must be signed in to change notification settings - Fork 294
Adds support for Numeric types with arbitrary precision #74
Changes from 8 commits
616fd63
06b1b55
902542f
a91dbab
c2e8697
ff1a6d6
315c244
c7367ba
2c65adf
79670f8
154f2de
6ef0a07
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import math | ||
from functools import lru_cache | ||
from itertools import zip_longest | ||
import re | ||
|
@@ -63,6 +64,7 @@ def import_presto(): | |
class ConnectError(Exception): | ||
pass | ||
|
||
|
||
class QueryError(Exception): | ||
pass | ||
|
||
|
@@ -105,6 +107,26 @@ class Datetime(TemporalType): | |
pass | ||
|
||
|
||
@dataclass | ||
class NumericType(ColType): | ||
# 'precision' signifies how many fractional digits (after the dot) we want to compare | ||
precision: int | ||
|
||
|
||
class Float(NumericType): | ||
pass | ||
|
||
|
||
class Decimal(NumericType): | ||
pass | ||
|
||
|
||
@dataclass | ||
class Integer(Decimal): | ||
def __post_init__(self): | ||
assert self.precision == 0 | ||
|
||
|
||
@dataclass | ||
class UnknownColType(ColType): | ||
text: str | ||
|
@@ -162,6 +184,19 @@ def normalize_value_by_type(value: str, coltype: ColType) -> str: | |
|
||
Rounded up/down according to coltype.rounds | ||
|
||
- Floats/Decimals are expected in the format | ||
"I.P" | ||
|
||
Where I is the integer part of the number (as many digits as necessary), | ||
and must be at least one digit (0). | ||
P is the fractional digits, the amount of which is specified with | ||
coltype.precision. Trailing zeroes may be necessary. | ||
|
||
Note: This precision is different than the one used by databases. For decimals, | ||
it's the same as "numeric_scale", and for floats, who use binary precision, | ||
it can be calculated as log10(2**p) | ||
|
||
|
||
""" | ||
... | ||
|
||
|
@@ -212,23 +247,48 @@ def query(self, sql_ast: SqlOrStr, res_type: type): | |
def enable_interactive(self): | ||
self._interactive = True | ||
|
||
def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType: | ||
def _convert_db_precision_to_digits(self, p: int) -> int: | ||
"""Convert from binary precision, used by floats, to decimal precision.""" | ||
# See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format | ||
return math.floor(math.log(2**p, 10)) | ||
|
||
def _parse_type( | ||
self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None | ||
) -> ColType: | ||
""" """ | ||
|
||
cls = self.DATETIME_TYPES.get(type_repr) | ||
if cls: | ||
return cls( | ||
precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION, | ||
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, | ||
rounds=self.ROUNDS_ON_PREC_LOSS, | ||
) | ||
|
||
cls = self.NUMERIC_TYPES.get(type_repr) | ||
if cls: | ||
if issubclass(cls, Integer): | ||
# Some DBs have a constant numeric_scale, so they don't report it. | ||
# We fill in the constant, so we need to ignore it for integers. | ||
return cls(precision=0) | ||
|
||
elif issubclass(cls, Decimal): | ||
return cls(precision=numeric_scale) | ||
|
||
assert issubclass(cls, Float) | ||
# assert numeric_scale is None | ||
return cls( | ||
precision=self._convert_db_precision_to_digits( | ||
numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is 4-bytes float the safe default? Because even if it's 8-bytes, it is "at least" 4-bytes? Isn't this only a safe default if it's a float4 and Should we maybe introduce There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes.
See the refactor PR, it addresses this concern.
It isn't necessary for most DBs, because they provide numeric_scale. For those that don't, yes, it can give us more accuracy. It just doesn't seem like high priority. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍🏻 I think it's worth a comment why this is a safe default, but yeah, I agree with you |
||
) | ||
) | ||
|
||
return UnknownColType(type_repr) | ||
|
||
def select_table_schema(self, path: DbPath) -> str: | ||
schema, table = self._normalize_table_path(path) | ||
|
||
return ( | ||
"SELECT column_name, data_type, datetime_precision, numeric_precision FROM information_schema.columns " | ||
"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM information_schema.columns " | ||
f"WHERE table_name = '{table}' AND table_schema = '{schema}'" | ||
) | ||
|
||
|
@@ -250,7 +310,9 @@ def _normalize_table_path(self, path: DbPath) -> DbPath: | |
elif len(path) == 2: | ||
return path | ||
|
||
raise ValueError(f"{self.__class__.__name__}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table") | ||
raise ValueError( | ||
f"{self.__class__.__name__}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table" | ||
) | ||
|
||
def parse_table_name(self, name: str) -> DbPath: | ||
return parse_table_name(name) | ||
|
@@ -295,7 +357,8 @@ def close(self): | |
_CHECKSUM_BITSIZE = CHECKSUM_HEXDIGITS << 2 | ||
CHECKSUM_MASK = (2**_CHECKSUM_BITSIZE) - 1 | ||
|
||
DEFAULT_PRECISION = 6 | ||
DEFAULT_DATETIME_PRECISION = 6 | ||
DEFAULT_NUMERIC_PRECISION = 24 | ||
|
||
TIMESTAMP_PRECISION_POS = 20 # len("2022-06-03 12:24:35.") == 20 | ||
|
||
|
@@ -307,6 +370,13 @@ class Postgres(ThreadedDatabase): | |
"timestamp": Timestamp, | ||
# "datetime": Datetime, | ||
} | ||
NUMERIC_TYPES = { | ||
"double precision": Float, | ||
"real": Float, | ||
"decimal": Decimal, | ||
"integer": Integer, | ||
"numeric": Decimal, | ||
} | ||
ROUNDS_ON_PREC_LOSS = True | ||
|
||
default_schema = "public" | ||
|
@@ -351,6 +421,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: | |
timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" | ||
return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" | ||
|
||
elif isinstance(coltype, NumericType): | ||
value = f"{value}::decimal(38, {coltype.precision})" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why can you assume 38 digits in front of the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's the maximum. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Comment would be sweet. |
||
|
||
return self.to_string(f"{value}") | ||
|
||
|
||
|
@@ -362,6 +435,10 @@ class Presto(Database): | |
"timestamp": Timestamp, | ||
# "datetime": Datetime, | ||
} | ||
NUMERIC_TYPES = { | ||
"integer": Integer, | ||
"real": Float, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you're missing double-precision here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch! |
||
} | ||
ROUNDS_ON_PREC_LOSS = True | ||
|
||
def __init__(self, host, port, user, password, *, catalog, schema=None, **kw): | ||
|
@@ -401,6 +478,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: | |
f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" | ||
) | ||
|
||
elif isinstance(coltype, NumericType): | ||
value = f"cast({value} as decimal(38,{coltype.precision}))" | ||
|
||
return self.to_string(value) | ||
|
||
def select_table_schema(self, path: DbPath) -> str: | ||
|
@@ -422,8 +502,24 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr | |
if m: | ||
datetime_precision = int(m.group(1)) | ||
return cls( | ||
precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION, rounds=False | ||
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, | ||
rounds=False, | ||
) | ||
|
||
cls = self.NUMERIC_TYPES.get(type_repr) | ||
if cls: | ||
if issubclass(cls, Integer): | ||
assert numeric_precision is not None | ||
return cls(0) | ||
elif issubclass(cls, Decimal): | ||
return cls(6) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this hardcoded for Decimal? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Presto doesn't provide the numeric_scale. 6 is just a safe value. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changed the way it's handled. |
||
|
||
assert issubclass(cls, Float) | ||
return cls( | ||
precision=self._convert_db_precision_to_digits( | ||
numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION | ||
) | ||
) | ||
|
||
return UnknownColType(type_repr) | ||
|
||
|
@@ -433,6 +529,12 @@ class MySQL(ThreadedDatabase): | |
"datetime": Datetime, | ||
"timestamp": Timestamp, | ||
} | ||
NUMERIC_TYPES = { | ||
"double": Float, | ||
"float": Float, | ||
"decimal": Decimal, | ||
"int": Decimal, | ||
erezsh marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
ROUNDS_ON_PREC_LOSS = True | ||
|
||
def __init__(self, host, port, user, password, *, database, thread_count, **kw): | ||
|
@@ -472,6 +574,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: | |
s = self.to_string(f"cast({value} as datetime(6))") | ||
return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" | ||
|
||
elif isinstance(coltype, NumericType): | ||
value = f"cast({value} as decimal(38,{coltype.precision}))" | ||
|
||
return self.to_string(f"{value}") | ||
|
||
|
||
|
@@ -513,16 +618,24 @@ def select_table_schema(self, path: DbPath) -> str: | |
(table,) = path | ||
|
||
return ( | ||
f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision" | ||
f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale" | ||
f" FROM USER_TAB_COLUMNS WHERE table_name = '{table.upper()}'" | ||
) | ||
|
||
def normalize_value_by_type(self, value: str, coltype: ColType) -> str: | ||
if isinstance(coltype, PrecisionType): | ||
if isinstance(coltype, TemporalType): | ||
return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')" | ||
elif isinstance(coltype, NumericType): | ||
# FM999.9990 | ||
format_str = "FM" + "9" * (38 - coltype.precision) | ||
if coltype.precision: | ||
format_str += "0." + "9" * (coltype.precision - 1) + "0" | ||
return f"to_char({value}, '{format_str}')" | ||
return self.to_string(f"{value}") | ||
|
||
def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType: | ||
def _parse_type( | ||
self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None, numeric_scale: int = None | ||
) -> ColType: | ||
""" """ | ||
regexps = { | ||
r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp, | ||
|
@@ -532,14 +645,40 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr | |
m = re.match(regexp + "$", type_repr) | ||
if m: | ||
datetime_precision = int(m.group(1)) | ||
return cls(precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION, | ||
rounds=self.ROUNDS_ON_PREC_LOSS | ||
return cls( | ||
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, | ||
rounds=self.ROUNDS_ON_PREC_LOSS, | ||
) | ||
|
||
cls = { | ||
"NUMBER": Decimal, | ||
"FLOAT": Float, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about |
||
}.get(type_repr, None) | ||
if cls: | ||
if issubclass(cls, Decimal): | ||
assert numeric_scale is not None, (type_repr, numeric_precision, numeric_scale) | ||
return cls(precision=numeric_scale) | ||
|
||
assert issubclass(cls, Float) | ||
return cls( | ||
precision=self._convert_db_precision_to_digits( | ||
numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION | ||
) | ||
) | ||
|
||
return UnknownColType(type_repr) | ||
|
||
|
||
class Redshift(Postgres): | ||
NUMERIC_TYPES = { | ||
**Postgres.NUMERIC_TYPES, | ||
"double": Float, | ||
"real": Float, | ||
} | ||
|
||
def _convert_db_precision_to_digits(self, p: int) -> int: | ||
return super()._convert_db_precision_to_digits(p // 2) | ||
|
||
def md5_to_int(self, s: str) -> str: | ||
return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)" | ||
|
||
|
@@ -560,6 +699,9 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: | |
timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" | ||
return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" | ||
|
||
elif isinstance(coltype, NumericType): | ||
value = f"{value}::decimal(38,{coltype.precision})" | ||
|
||
return self.to_string(f"{value}") | ||
|
||
|
||
|
@@ -595,6 +737,14 @@ class BigQuery(Database): | |
"TIMESTAMP": Timestamp, | ||
"DATETIME": Datetime, | ||
} | ||
NUMERIC_TYPES = { | ||
"INT64": Integer, | ||
"INT32": Integer, | ||
"NUMERIC": Decimal, | ||
"BIGNUMERIC": Decimal, | ||
"FLOAT64": Float, | ||
"FLOAT32": Float, | ||
} | ||
ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation | ||
|
||
def __init__(self, project, *, dataset, **kw): | ||
|
@@ -640,12 +790,12 @@ def select_table_schema(self, path: DbPath) -> str: | |
schema, table = self._normalize_table_path(path) | ||
|
||
return ( | ||
f"SELECT column_name, data_type, 6 as datetime_precision, 6 as numeric_precision FROM {schema}.INFORMATION_SCHEMA.COLUMNS " | ||
f"SELECT column_name, data_type, 6 as datetime_precision, 38 as numeric_precision, 9 as numeric_scale FROM {schema}.INFORMATION_SCHEMA.COLUMNS " | ||
f"WHERE table_name = '{table}' AND table_schema = '{schema}'" | ||
) | ||
|
||
def normalize_value_by_type(self, value: str, coltype: ColType) -> str: | ||
if isinstance(coltype, PrecisionType): | ||
if isinstance(coltype, TemporalType): | ||
if coltype.rounds: | ||
timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" | ||
return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})" | ||
|
@@ -658,6 +808,10 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: | |
timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" | ||
return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" | ||
|
||
elif isinstance(coltype, NumericType): | ||
# value = f"cast({value} as decimal)" | ||
return f"format('%.{coltype.precision}f', cast({value} as decimal))" | ||
|
||
return self.to_string(f"{value}") | ||
|
||
def parse_table_name(self, name: str) -> DbPath: | ||
|
@@ -671,6 +825,10 @@ class Snowflake(Database): | |
"TIMESTAMP_LTZ": Timestamp, | ||
"TIMESTAMP_TZ": TimestampTZ, | ||
} | ||
NUMERIC_TYPES = { | ||
"NUMBER": Decimal, | ||
"FLOAT": Float, | ||
erezsh marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
ROUNDS_ON_PREC_LOSS = False | ||
|
||
def __init__( | ||
|
@@ -729,14 +887,17 @@ def select_table_schema(self, path: DbPath) -> str: | |
return super().select_table_schema((schema, table)) | ||
|
||
def normalize_value_by_type(self, value: str, coltype: ColType) -> str: | ||
if isinstance(coltype, PrecisionType): | ||
if isinstance(coltype, TemporalType): | ||
if coltype.rounds: | ||
timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))" | ||
else: | ||
timestamp = f"cast({value} as timestamp({coltype.precision}))" | ||
|
||
return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')" | ||
|
||
elif isinstance(coltype, NumericType): | ||
value = f"cast({value} as decimal(38, {coltype.precision}))" | ||
|
||
return self.to_string(f"{value}") | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For Postgres,
numeric_precision
inINFORMATION_SCHEMA.columns
is 24, not 23 as I'd normally expect for a 4-byte float. Do you know why?This causes a potential problem... For Postgres at least, the precision this function gives, and the documentation, aren't congruent for 4-byte, but is for 8-byte:
https://www.postgresql.org/docs/current/datatype-numeric.html
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No idea why.. Double still gives 53..