Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Refactor database.py -> databases/*.py, each db gets a file. #101

Merged
merged 2 commits into from
Jun 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions data_diff/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Tuple, Iterator, Optional, Union

from .database import connect_to_uri
from .databases.connect import connect_to_uri
from .diff_tables import (
TableSegment,
TableDiffer,
Expand All @@ -9,7 +9,6 @@
DbKey,
DbTime,
DbPath,
parse_table_name,
)


Expand All @@ -18,10 +17,11 @@ def connect_to_table(
):
"""Connects to a URI and creates a TableSegment instance"""

db = connect_to_uri(db_uri, thread_count=thread_count)

if isinstance(table_name, str):
table_name = parse_table_name(table_name)
table_name = db.parse_table_name(table_name)

db = connect_to_uri(db_uri, thread_count=thread_count)
return TableSegment(db, table_name, key_column, **kwargs)


Expand Down
11 changes: 5 additions & 6 deletions data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
TableDiffer,
DEFAULT_BISECTION_THRESHOLD,
DEFAULT_BISECTION_FACTOR,
parse_table_name,
)
from .database import connect_to_uri, parse_table_name
from .databases.connect import connect_to_uri
from .parse_time import parse_time_before_now, UNITS_STR, ParseError

import rich
Expand Down Expand Up @@ -51,7 +50,7 @@
@click.option("--max-age", default=None, help="Considers only rows younger than specified. See --min-age.")
@click.option("-s", "--stats", is_flag=True, help="Print stats instead of a detailed diff")
@click.option("-d", "--debug", is_flag=True, help="Print debug info")
@click.option("--json", 'json_output', is_flag=True, help="Print JSONL output for machine readability")
@click.option("--json", "json_output", is_flag=True, help="Print JSONL output for machine readability")
@click.option("-v", "--verbose", is_flag=True, help="Print extra info")
@click.option("-i", "--interactive", is_flag=True, help="Confirm queries, implies --debug")
@click.option("--keep-column-case", is_flag=True, help="Don't use the schema to fix the case of given column names.")
Expand Down Expand Up @@ -104,7 +103,7 @@ def main(
try:
threads = int(threads)
except ValueError:
logger.error("Error: threads must be a number, 'auto', or 'serial'.")
logging.error("Error: threads must be a number, 'auto', or 'serial'.")
return
if threads < 1:
logging.error("Error: threads must be >= 1")
Expand All @@ -129,8 +128,8 @@ def main(
logging.error("Error while parsing age expression: %s" % e)
return

table1 = TableSegment(db1, parse_table_name(table1_name), key_column, update_column, columns, **options)
table2 = TableSegment(db2, parse_table_name(table2_name), key_column, update_column, columns, **options)
table1 = TableSegment(db1, db1.parse_table_name(table1_name), key_column, update_column, columns, **options)
table2 = TableSegment(db2, db2.parse_table_name(table2_name), key_column, update_column, columns, **options)

differ = TableDiffer(
bisection_factor=bisection_factor,
Expand Down
Loading