Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 3fa2d5f

Browse files
authored
Merge pull request #338 from datafold/dec6
Small refactor in __main__ and tests; Better test coverage for CLI
2 parents dddfd33 + 76a5989 commit 3fa2d5f

File tree

12 files changed: 51 additions and 53 deletions.

12 files changed: 51 additions and 53 deletions.

data_diff/__main__.py

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -307,18 +307,22 @@ def _main(
307307
else:
308308
db2 = connect(database2, threads2 or threads)
309309

310-
now: datetime = db1.query(current_timestamp(), datetime)
311-
now = now.replace(tzinfo=None)
312-
try:
313-
options = dict(
314-
min_update=max_age and parse_time_before(now, max_age),
315-
max_update=min_age and parse_time_before(now, min_age),
316-
case_sensitive=case_sensitive,
317-
where=where,
318-
)
319-
except ParseError as e:
320-
logging.error(f"Error while parsing age expression: {e}")
321-
return
310+
options = dict(
311+
case_sensitive=case_sensitive,
312+
where=where,
313+
)
314+
315+
if min_age or max_age:
316+
now: datetime = db1.query(current_timestamp(), datetime)
317+
now = now.replace(tzinfo=None)
318+
try:
319+
if max_age:
320+
options["min_update"] = parse_time_before(now, max_age)
321+
if min_age:
322+
options["max_update"] = parse_time_before(now, min_age)
323+
except ParseError as e:
324+
logging.error(f"Error while parsing age expression: {e}")
325+
return
322326

323327
dbs = db1, db2
324328

data_diff/sqeleton/databases/base.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -124,7 +124,7 @@ class BaseDialect(AbstractDialect):
124124
SUPPORTS_INDEXES = False
125125
TYPE_CLASSES: Dict[str, type] = {}
126126

127-
PLACEHOLDER_TABLE = None # Used for Oracle
127+
PLACEHOLDER_TABLE = None # Used for Oracle
128128

129129
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
130130
if offset:

data_diff/sqeleton/databases/duckdb.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -114,6 +114,7 @@ def set_timezone_to_utc(self) -> str:
114114
def current_timestamp(self) -> str:
115115
return "current_timestamp"
116116

117+
117118
class DuckDB(Database):
118119
dialect = Dialect()
119120
SUPPORTS_UNIQUE_CONSTAINT = False # Temporary, until we implement it

data_diff/sqeleton/databases/presto.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -140,6 +140,7 @@ def set_timezone_to_utc(self) -> str:
140140
def current_timestamp(self) -> str:
141141
return "current_timestamp"
142142

143+
143144
class Presto(Database):
144145
dialect = Dialect()
145146
CONNECT_URI_HELP = "presto://<user>@<host>/<catalog>/<schema>"

data_diff/sqeleton/queries/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
or_,
1919
leftjoin,
2020
rightjoin,
21-
current_timestamp
21+
current_timestamp,
2222
)
2323
from .ast_classes import Expr, ExprNode, Select, Count, BinOp, Explain, In, Code, Column
2424
from .extras import Checksum, NormalizeAsString, ApplyFuncAndNormalizeAsString

data_diff/sqeleton/queries/ast_classes.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -604,7 +604,6 @@ def compile(self, parent_c: Compiler) -> str:
604604
elif c.dialect.PLACEHOLDER_TABLE:
605605
select += f" FROM {c.dialect.PLACEHOLDER_TABLE}"
606606

607-
608607
if self.where_exprs:
609608
select += " WHERE " + " AND ".join(map(c.compile, self.where_exprs))
610609

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,6 @@ clickhouse-driver = {version="*", optional=true}
3939
duckdb = {version="^0.6.0", optional=true}
4040

4141
[tool.poetry.dev-dependencies]
42-
arrow = "^1.2.3"
4342
parameterized = "*"
4443
unittest-parallel = "*"
4544
preql = "^0.2.19"

tests/sqeleton/test_database.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,6 @@ def test_table_list(self):
6868

6969
@test_each_database
7070
class TestQueries(unittest.TestCase):
71-
7271
def test_current_timestamp(self):
7372
db = get_conn(self.db_cls)
7473
res = db.query(current_timestamp(), datetime)

tests/test_api.py

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
1-
import arrow
2-
from datetime import datetime
1+
from datetime import datetime, timedelta
32

43
from data_diff import diff_tables, connect_to_table, Algorithm
54
from data_diff.databases import MySQL
@@ -17,20 +16,20 @@ def setUp(self) -> None:
1716

1817
self.conn = self.connection
1918

20-
self.now = now = arrow.get()
19+
self.now = now = datetime.now()
2120

2221
rows = [
2322
(now, "now"),
24-
(self.now.shift(seconds=-10), "a"),
25-
(self.now.shift(seconds=-7), "b"),
26-
(self.now.shift(seconds=-6), "c"),
23+
(self.now - timedelta(seconds=10), "a"),
24+
(self.now - timedelta(seconds=7), "b"),
25+
(self.now - timedelta(seconds=6), "c"),
2726
]
2827

2928
self.conn.query(
3029
[
31-
self.src_table.insert_rows((i, ts.datetime, s) for i, (ts, s) in enumerate(rows)),
30+
self.src_table.insert_rows((i, ts, s) for i, (ts, s) in enumerate(rows)),
3231
self.dst_table.create(self.src_table),
33-
self.src_table.insert_row(len(rows), self.now.shift(seconds=-3).datetime, "3 seconds ago"),
32+
self.src_table.insert_row(len(rows), self.now - timedelta(seconds=3), "3 seconds ago"),
3433
commit,
3534
]
3635
)

tests/test_cli.py

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -3,10 +3,10 @@
33
import sys
44
from datetime import datetime, timedelta
55

6-
from data_diff.databases import MySQL
7-
from data_diff.sqeleton.queries import commit
6+
from data_diff.sqeleton.queries import commit, current_timestamp
87

9-
from .common import TEST_MYSQL_CONN_STRING, DiffTestCase
8+
from .common import DiffTestCase, CONN_STRINGS
9+
from .test_diff_tables import test_each_database
1010

1111

1212
def run_datadiff_cli(*args):
@@ -20,14 +20,14 @@ def run_datadiff_cli(*args):
2020
return stdout.splitlines()
2121

2222

23+
@test_each_database
2324
class TestCLI(DiffTestCase):
24-
db_cls = MySQL
2525
src_schema = {"id": int, "datetime": datetime, "text_comment": str}
2626

2727
def setUp(self) -> None:
2828
super().setUp()
2929

30-
now = self.connection.query("select now()", datetime)
30+
now = self.connection.query(current_timestamp(), datetime)
3131

3232
rows = [
3333
(now, "now"),
@@ -46,16 +46,16 @@ def setUp(self) -> None:
4646
)
4747

4848
def test_basic(self):
49-
diff = run_datadiff_cli(
50-
TEST_MYSQL_CONN_STRING, self.table_src_name, TEST_MYSQL_CONN_STRING, self.table_dst_name
51-
)
49+
conn_str = CONN_STRINGS[self.db_cls]
50+
diff = run_datadiff_cli(conn_str, self.table_src_name, conn_str, self.table_dst_name)
5251
assert len(diff) == 1
5352

5453
def test_options(self):
54+
conn_str = CONN_STRINGS[self.db_cls]
5555
diff = run_datadiff_cli(
56-
TEST_MYSQL_CONN_STRING,
56+
conn_str,
5757
self.table_src_name,
58-
TEST_MYSQL_CONN_STRING,
58+
conn_str,
5959
self.table_dst_name,
6060
"--bisection-factor",
6161
"16",
@@ -68,4 +68,4 @@ def test_options(self):
6868
"--max-age",
6969
"1h",
7070
)
71-
assert len(diff) == 1
71+
assert len(diff) == 1, diff

0 commit comments

Comments
 (0)