Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Small Fixes #151

Merged
merged 3 commits into from
Jul 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion data_diff/databases/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .base import ThreadedDatabase, import_helper, ConnectError, QueryError
from .base import DEFAULT_DATETIME_PRECISION, DEFAULT_NUMERIC_PRECISION

SESSION_TIME_ZONE = None # Changed by the tests

@import_helper("oracle")
def import_oracle():
Expand Down Expand Up @@ -34,7 +35,10 @@ def __init__(self, *, host, database, thread_count, **kw):
def create_connection(self):
self._oracle = import_oracle()
try:
return self._oracle.connect(**self.kwargs)
c = self._oracle.connect(**self.kwargs)
if SESSION_TIME_ZONE:
c.cursor().execute(f"ALTER SESSION SET TIME_ZONE = '{SESSION_TIME_ZONE}'")
return c
except Exception as e:
raise ConnectError(*e.args) from e

Expand Down
9 changes: 7 additions & 2 deletions data_diff/databases/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .base import ThreadedDatabase, import_helper, ConnectError
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, _CHECKSUM_BITSIZE, TIMESTAMP_PRECISION_POS

SESSION_TIME_ZONE = None # Changed by the tests

@import_helper("postgresql")
def import_postgresql():
Expand Down Expand Up @@ -47,13 +48,17 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
return super()._convert_db_precision_to_digits(p) - 2

def create_connection(self):
if not self._args:
self._args['host'] = None # psycopg2 requires 1+ arguments

pg = import_postgresql()
try:
c = pg.connect(**self._args)
# c.cursor().execute("SET TIME ZONE 'UTC'")
if SESSION_TIME_ZONE:
c.cursor().execute(f"SET TIME ZONE '{SESSION_TIME_ZONE}'")
return c
except pg.OperationalError as e:
raise ConnectError(*e._args) from e
raise ConnectError(*e.args) from e

def quote(self, s: str):
return f'"{s}"'
Expand Down
4 changes: 4 additions & 0 deletions tests/test_database_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from parameterized import parameterized

from data_diff import databases as db
from data_diff.databases import postgresql, oracle
from data_diff.utils import number_to_human
from data_diff.diff_tables import TableDiffer, TableSegment, DEFAULT_BISECTION_THRESHOLD
from .common import CONN_STRINGS, N_SAMPLES, N_THREADS, BENCHMARK, GIT_REVISION, random_table_suffix
Expand All @@ -20,6 +21,7 @@
CONNS = {k: db.connect_to_uri(v, N_THREADS) for k, v in CONN_STRINGS.items()}

CONNS[db.MySQL].query("SET @@session.time_zone='+00:00'", None)
oracle.SESSION_TIME_ZONE = postgresql.SESSION_TIME_ZONE = 'UTC'


class PaginatedTable:
Expand Down Expand Up @@ -434,6 +436,8 @@ def _insert_to_table(conn, table, values, type):
value = str(sample)
elif isinstance(sample, datetime) and isinstance(conn, (db.Presto, db.Oracle)):
value = f"timestamp '{sample}'"
elif isinstance(sample, datetime) and isinstance(conn, db.BigQuery) and type == 'datetime':
value = f"cast(timestamp '{sample}' as datetime)"
elif isinstance(sample, bytearray):
value = f"'{sample.decode()}'"
else:
Expand Down