Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Follow-up the sqeleton-to-datadiff embedding #543

Merged
merged 15 commits into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions data_diff/sqeleton/abcs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@
PrecisionType,
StringType,
Boolean,
JSONType,
)
from .compiler import AbstractCompiler, Compilable
16 changes: 16 additions & 0 deletions data_diff/sqeleton/abcs/database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,22 @@ class Text(StringType):
supported = False


class JSONType(ColType):
    """Base column type for JSON-like columns.

    Declares no fields or behavior of its own; the dialect-specific JSON
    types derive from it so they can be recognized with a single
    ``isinstance(coltype, JSONType)`` check.
    """

    pass


class RedShiftSuper(JSONType):
    """JSON column type for the type name ``super`` (see the Redshift dialect's type table)."""

    pass


class PostgresqlJSON(JSONType):
    """JSON column type for the type name ``json`` (see the PostgreSQL dialect's type table)."""

    pass


class PostgresqlJSONB(JSONType):
    """JSON column type for the type name ``jsonb`` (see the PostgreSQL dialect's type table)."""

    pass


@dataclass
class Integer(NumericType, IKey):
precision: int = 0
Expand Down
8 changes: 7 additions & 1 deletion data_diff/sqeleton/abcs/mixins.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from .database_types import TemporalType, FractionalType, ColType_UUID, Boolean, ColType, String_UUID
from .database_types import TemporalType, FractionalType, ColType_UUID, Boolean, ColType, String_UUID, JSONType
from .compiler import Compilable


Expand Down Expand Up @@ -49,6 +49,10 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
return f"TRIM({value})"
return self.to_string(value)

def normalize_json(self, value: str, _coltype: JSONType) -> str:
    """Creates an SQL expression, that converts 'value' to its minified json string representation.

    This default implementation produces nothing: it unconditionally raises,
    so a dialect that handles JSON columns has to provide its own override.

    Parameters:
        value: the SQL expression (as text) to be normalized.
        _coltype: the JSON column type; not used by this implementation.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError

def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
"""Creates an SQL expression, that converts 'value' to a normalized representation.

Expand All @@ -73,6 +77,8 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
return self.normalize_uuid(value, coltype)
elif isinstance(coltype, Boolean):
return self.normalize_boolean(value, coltype)
elif isinstance(coltype, JSONType):
return self.normalize_json(value, coltype)
return self.to_string(value)


Expand Down
4 changes: 4 additions & 0 deletions data_diff/sqeleton/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
DbTime,
DbPath,
Boolean,
JSONType
)
from ..abcs.mixins import Compilable
from ..abcs.mixins import (
Expand Down Expand Up @@ -259,6 +260,9 @@ def parse_type(
elif issubclass(cls, (Text, Native_UUID)):
return cls()

elif issubclass(cls, JSONType):
return cls()

raise TypeError(f"Parsing {type_repr} returned an unknown type '{cls}'.")

def _convert_db_precision_to_digits(self, p: int) -> int:
Expand Down
1 change: 1 addition & 0 deletions data_diff/sqeleton/databases/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints):
"NCHAR": Text,
"NVARCHAR2": Text,
"VARCHAR2": Text,
"DATE": Timestamp,
}
ROUNDS_ON_PREC_LOSS = True
PLACEHOLDER_TABLE = "DUAL"
Expand Down
8 changes: 8 additions & 0 deletions data_diff/sqeleton/databases/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
FractionalType,
Boolean,
Date,
PostgresqlJSON,
PostgresqlJSONB
)
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue
from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, Mixin_Schema
Expand Down Expand Up @@ -49,6 +51,9 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str:
def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
    """Creates an SQL expression, that converts a boolean 'value' to a string.

    'value' is first cast with ``::int`` and the resulting expression is then
    passed through ``self.to_string``. '_coltype' is not used.
    """
    as_integer = "{}::int".format(value)
    return self.to_string(as_integer)

def normalize_json(self, value: str, _coltype: PostgresqlJSON) -> str:
    """Creates an SQL expression, that casts the JSON 'value' to text with ``::text``.

    '_coltype' is not used.
    """
    return "{}::text".format(value)


class PostgresqlDialect(BaseDialect, Mixin_Schema):
name = "PostgreSQL"
Expand Down Expand Up @@ -76,6 +81,9 @@ class PostgresqlDialect(BaseDialect, Mixin_Schema):
"character varying": Text,
"varchar": Text,
"text": Text,
# JSON
"json": PostgresqlJSON,
"jsonb": PostgresqlJSONB,
# UUID
"uuid": Native_UUID,
# Boolean
Expand Down
53 changes: 51 additions & 2 deletions data_diff/sqeleton/databases/redshift.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
from typing import List, Dict
from ..abcs.database_types import Float, TemporalType, FractionalType, DbPath, TimestampTZ
from ..abcs.database_types import (
Float,
TemporalType,
FractionalType,
DbPath,
TimestampTZ,
RedShiftSuper
)
from ..abcs.mixins import AbstractMixin_MD5
from .postgresql import (
PostgreSQL,
Expand Down Expand Up @@ -40,13 +47,18 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
def normalize_number(self, value: str, coltype: FractionalType) -> str:
    """Creates an SQL expression, that converts a fractional 'value' to a string.

    'value' is cast to ``decimal(38, <coltype.precision>)`` and the resulting
    expression is passed through ``self.to_string``.
    """
    scale = coltype.precision
    return self.to_string(f"{value}::decimal(38,{scale})")

def normalize_json(self, value: str, _coltype: RedShiftSuper) -> str:
    """Creates an SQL expression, that serializes the JSON 'value' to a string.

    The emitted expression has the form
    ``nvl2(value, json_serialize(value), NULL)``; 'value' is substituted in
    both positions. '_coltype' is not used.
    """
    template = "nvl2({0}, json_serialize({0}), NULL)"
    return template.format(value)


class Dialect(PostgresqlDialect):
name = "Redshift"
TYPE_CLASSES = {
**PostgresqlDialect.TYPE_CLASSES,
"double": Float,
"real": Float,
# JSON
"super": RedShiftSuper
}
SUPPORTS_INDEXES = False

Expand Down Expand Up @@ -109,11 +121,48 @@ def query_external_table_schema(self, path: DbPath) -> Dict[str, tuple]:
assert len(d) == len(rows)
return d

def select_view_columns(self, path: DbPath) -> str:
    """Build the SQL text that lists the columns of a view through ``pg_get_cols``.

    'path' is normalized with ``self._normalize_table_path``; only the schema
    and table components of the result are used, the first one is ignored.
    The names are interpolated into the query text as-is.
    """
    _database, schema, table = self._normalize_table_path(path)

    return (
        f"select * from pg_get_cols('{schema}.{table}')\n"
        "cols(view_schema name, view_name name, col_name name, col_type varchar, col_num int)\n"
    )

def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]:
    """Describe the columns of a view using the ``pg_get_cols`` query.

    Runs the SQL produced by ``self.select_view_columns(path)`` and turns each
    returned row into a ``(name, base_type, None, precision, scale)`` tuple,
    keyed by column name. The column name is read from index 2 of the row and
    the type text from index 3. Precision and scale are only parsed for
    ``numeric(p,s)`` types; for every other type they are None.

    Raises:
        RuntimeError: when the query returns no rows.
    """
    column_rows = self.query(self.select_view_columns(path), list)
    if not column_rows:
        raise RuntimeError(f"{self.name}: View '{'.'.join(path)}' does not exist, or has no columns")

    schema_by_column = {}
    for row in column_rows:
        col_name, col_type = row[2], row[3]
        # "numeric(18,2)" -> base "numeric", args ["18,2)"]; "integer" -> no args.
        base_type, *type_args = col_type.split("(")
        precision = scale = None

        if type_args and base_type == "numeric":
            # Drop the closing parenthesis, then read "precision,scale".
            precision, scale = (int(part) for part in type_args[0][:-1].split(","))

        schema_by_column[col_name] = (col_name, base_type, None, precision, scale)

    return schema_by_column

def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
    """Return the column schema for 'path', trying several sources in turn.

    The lookups are attempted in this order, each one only when the previous
    one raised RuntimeError:

    1. the parent class's ``query_table_schema``;
    2. ``self.query_external_table_schema``;
    3. ``self.query_pg_get_cols``.

    Raises:
        RuntimeError: propagated from the last lookup when it fails as well.
    """
    try:
        return super().query_table_schema(path)
    except RuntimeError:
        try:
            return self.query_external_table_schema(path)
        except RuntimeError:
            # 'path' must be forwarded: query_pg_get_cols requires it, and
            # calling it without arguments fails with TypeError instead of
            # performing the final lookup.
            return self.query_pg_get_cols(path)

def _normalize_table_path(self, path: DbPath) -> DbPath:
if len(path) == 1:
Expand Down