Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: avoid filters containing null value #17168

Merged
merged 1 commit into from
Oct 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion superset/connectors/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,9 @@ def query(self, query_obj: QueryObjectDict) -> QueryResult:
"""
raise NotImplementedError()

def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
def values_for_column(
self, column_name: str, limit: int = 10000, contain_null: bool = True,
) -> List[Any]:
"""Given a column, returns an iterable of distinct values
This is used to populate the dropdown showing a list of
Expand Down
4 changes: 3 additions & 1 deletion superset/connectors/druid/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,9 @@ def metrics_and_post_aggs(
)
return aggs, post_aggs

def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
def values_for_column(
self, column_name: str, limit: int = 10000, contain_null: bool = True,
) -> List[Any]:
"""Retrieve some values for the given column"""
logger.info(
"Getting values for columns [{}] limited to [{}]".format(column_name, limit)
Expand Down
11 changes: 9 additions & 2 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ def query(self, query_obj: QueryObjectDict) -> QueryResult:
def get_query_str(self, query_obj: QueryObjectDict) -> str:
raise NotImplementedError()

def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
def values_for_column(
self, column_name: str, limit: int = 10000, contain_null: bool = True,
) -> List[Any]:
raise NotImplementedError()


Expand Down Expand Up @@ -712,7 +714,9 @@ def get_fetch_values_predicate(self) -> TextClause:
)
) from ex

def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
def values_for_column(
self, column_name: str, limit: int = 10000, contain_null: bool = True,
) -> List[Any]:
"""Runs query against sqla to retrieve some
sample values for the given column.
"""
Expand All @@ -728,6 +732,9 @@ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
if limit:
qry = qry.limit(limit)

if not contain_null:
qry = qry.where(target_col.get_sqla_col().isnot(None))

if self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate())

Expand Down
4 changes: 3 additions & 1 deletion superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,9 @@ def filter( # pylint: disable=no-self-use
datasource.raise_for_access()
row_limit = apply_max_row_limit(config["FILTER_SELECT_ROW_LIMIT"])
payload = json.dumps(
datasource.values_for_column(column, row_limit),
datasource.values_for_column(
column_name=column, limit=row_limit, contain_null=False,
),
default=utils.json_int_dttm_ser,
ignore_nan=True,
)
Expand Down
20 changes: 20 additions & 0 deletions tests/integration_tests/sqla_models_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,3 +463,23 @@ def test_labels_expected_on_mutated_query(self):
db.session.delete(table)
db.session.delete(database)
db.session.commit()

def test_values_for_column(self):
table = SqlaTable(
table_name="test_null_in_column",
sql="SELECT 'foo' as foo UNION SELECT 'bar' UNION SELECT NULL",
database=get_example_database(),
)
TableColumn(column_name="foo", type="VARCHAR(255)", table=table)

with_null = table.values_for_column(
column_name="foo", limit=10000, contain_null=True
)
assert None in with_null
assert len(with_null) == 3

without_null = table.values_for_column(
column_name="foo", limit=10000, contain_null=False
)
assert None not in without_null
assert len(without_null) == 2