Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Bugfix: Add brackets around WHERE clause #369

Merged
merged 1 commit into from
Jan 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
metavar="COUNT",
)
@click.option(
"-w", "--where", default=None, help="An additional 'where' expression to restrict the search space.", metavar="EXPR"
"-w", "--where", default=None, help="An additional 'where' expression to restrict the search space. Beware of SQL Injection!", metavar="EXPR"
)
@click.option("-a", "--algorithm", default=Algorithm.AUTO.value, type=click.Choice([i.value for i in Algorithm]))
@click.option(
Expand Down
2 changes: 1 addition & 1 deletion data_diff/joindiff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def _test_null_keys(self, table1, table2):
q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns))
nulls = ts.database.query(q, list)
if nulls:
raise ValueError("NULL values in one or more primary keys")
raise ValueError(f"NULL values in one or more primary keys of {ts.table_path}")

def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree):
logger.debug(f"Collecting stats for table #{i}")
Expand Down
7 changes: 5 additions & 2 deletions data_diff/table_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,11 @@ def __post_init__(self):
f"Error: min_update expected to be smaller than max_update! ({self.min_update} >= {self.max_update})"
)

def _where(self):
return f"({self.where})" if self.where else None

def _with_raw_schema(self, raw_schema: dict) -> "TableSegment":
schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self.where)
schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self._where())
return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive))

def with_schema(self) -> "TableSegment":
Expand Down Expand Up @@ -100,7 +103,7 @@ def source_table(self):

def make_select(self):
return self.source_table.where(
*self._make_key_range(), *self._make_update_range(), Code(self.where) if self.where else SKIP
*self._make_key_range(), *self._make_update_range(), Code(self._where()) if self.where else SKIP
)

def get_values(self) -> list:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_api(self):

# test where
diff_id = diff[0][1][0]
where = f"id != {diff_id}"
where = f"id != {diff_id} OR id = 90000000"

t1 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_src_name, where=where)
t2 = connect_to_table(TEST_MYSQL_CONN_STRING, self.table_dst_name, where=where)
Expand Down