Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit ba494a0

Browse files
author
Marco Baringer
committed
allow joindiff across bigquery projects
1 parent 7abe5a5 commit ba494a0

File tree

4 files changed, with 27 additions and 9 deletions.

data_diff/info_tree.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class SegmentInfo:
1313
is_diff: bool = None
1414
diff_count: int = None
1515

16-
rowcounts: Dict[int, int] = {}
16+
rowcounts: Dict[int, int] = {1: 0, 2: 0}
1717
max_rows: int = None
1818

1919
def set_diff(self, diff):

data_diff/joindiff_tables.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ class JoinDiffer(TableDiffer):
144144
def _diff_tables_root(self, table1: TableSegment, table2: TableSegment, info_tree: InfoTree) -> DiffResult:
145145
db = table1.database
146146

147-
if table1.database is not table2.database:
147+
if not table1.database.can_join_with(table2.database):
148148
raise ValueError("Join-diff only works when both tables are in the same database")
149149

150150
table1, table2 = self._threaded_call("with_schema", [table1, table2])
@@ -174,7 +174,7 @@ def _diff_segments(
174174
segment_index=None,
175175
segment_count=None,
176176
):
177-
assert table1.database is table2.database
177+
assert table1.database.can_join_with(table2.database)
178178

179179
if segment_index or table1.min_key or max_rows:
180180
logger.info(
@@ -295,8 +295,7 @@ def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree):
295295
logger.debug("Done collecting stats for table #%s", i)
296296

297297
def _create_outer_join(self, table1, table2):
298-
db = table1.database
299-
if db is not table2.database:
298+
if not table1.database.can_join_with(table2.database):
300299
raise ValueError("Joindiff only applies to tables within the same database")
301300

302301
keys1 = table1.key_columns
@@ -319,7 +318,7 @@ def _create_outer_join(self, table1, table2):
319318
# Order columns as col1_a, col1_b, col2_a, col2_b, etc.
320319
cols = {k: v for k, v in chain(*zip(a_cols.items(), b_cols.items()))}
321320

322-
all_rows = _outerjoin(db, a, b, keys1, keys2, {**is_diff_cols, **cols})
321+
all_rows = _outerjoin(table1.database, a, b, keys1, keys2, {**is_diff_cols, **cols})
323322
diff_rows = all_rows.where(or_(this[c] == 1 for c in is_diff_cols))
324323
return diff_rows, a_cols, b_cols, is_diff_cols, all_rows
325324

data_diff/sqeleton/databases/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,8 @@ def close(self):
453453
self.is_closed = True
454454
return super().close()
455455

456+
def can_join_with(self, other):
457+
return self is other
456458

457459
class ThreadedDatabase(Database):
458460
"""Access the database through singleton threads.

data_diff/sqeleton/databases/bigquery.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,21 +145,38 @@ def close(self):
145145
self._client.close()
146146

147147
def select_table_schema(self, path: DbPath) -> str:
148-
schema, name = self._normalize_table_path(path)
149-
148+
project, schema, name = self._normalize_table_path(path)
150149
return (
151150
"SELECT column_name, data_type, 6 as datetime_precision, 38 as numeric_precision, 9 as numeric_scale "
152-
f"FROM {schema}.INFORMATION_SCHEMA.COLUMNS "
151+
f"FROM `{project}`.`{schema}`.INFORMATION_SCHEMA.COLUMNS "
153152
f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
154153
)
155154

156155
def query_table_unique_columns(self, path: DbPath) -> List[str]:
157156
return []
158157

158+
def _normalize_table_path(self, path: DbPath) -> DbPath:
159+
if len(path) == 0:
160+
raise ValueError(f"{self.name}: Bad table path for {self}: ()")
161+
elif len(path) == 1:
162+
if self.default_schema:
163+
return [self.project, self.default_schema, path[0]]
164+
else:
165+
return path
166+
elif len(path) == 2:
167+
return [self.project] + path
168+
elif len(path) == 3:
169+
return path
170+
else:
171+
raise ValueError(f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: [project.]schema.table")
172+
159173
def parse_table_name(self, name: str) -> DbPath:
160174
path = parse_table_name(name)
161175
return tuple(i for i in self._normalize_table_path(path) if i is not None)
162176

163177
@property
164178
def is_autocommit(self) -> bool:
165179
return True
180+
181+
def can_join_with(self, other):
182+
return True

0 commit comments

Comments
 (0)