Skip to content

Commit

Permalink
Field names in big query can contain only alphanumeric and underscore (
Browse files Browse the repository at this point in the history
…apache#5641)

* Field names in big query can contain only alphanumeric and underscore

* bad quote

* better place for mutating labels

* lint

* bug fix thanks to mistercrunch

* lint

* lint again

(cherry picked from commit 80e7778)
  • Loading branch information
sumedhsakdeo authored and betodealmeida committed Aug 22, 2018
1 parent fe6af3b commit 5369d6a
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
*.pyc
*.swp
yarn-error.log
_modules
superset/assets/coverage/*
Expand Down
16 changes: 16 additions & 0 deletions superset/db_engine_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,10 @@ def align_df_col_names_with_form_data(df, fd):

return df.rename(index=str, columns=rename_cols)

@staticmethod
def mutate_expression_label(label):
return label


class PostgresBaseEngineSpec(BaseEngineSpec):
""" Abstract class for Postgres 'like' databases """
Expand Down Expand Up @@ -1414,6 +1418,18 @@ def fetch_data(cls, cursor, limit):
data = [r.values() for r in data]
return data

@staticmethod
def mutate_expression_label(label):
mutated_label = re.sub('[^\w]+', '_', label)
if not re.match('^[a-zA-Z_]+.*', mutated_label):
raise SupersetTemplateException('BigQuery field_name used is invalid {}, '
'should start with a letter or '
'underscore'.format(mutated_label))
if len(mutated_label) > 128:
raise SupersetTemplateException('BigQuery field_name {}, should be atmost '
'128 characters'.format(mutated_label))
return mutated_label

@classmethod
def _get_fields(cls, cols):
"""
Expand Down
8 changes: 6 additions & 2 deletions superset/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,10 @@ def process_metrics(self):
if not isinstance(val, list):
val = [val]
for o in val:
self.metric_dict[self.get_metric_label(o)] = o
label = self.get_metric_label(o)
if isinstance(o, dict):
o['label'] = label
self.metric_dict[label] = o

# Cast to list needed to return serializable object in py3
self.all_metrics = list(self.metric_dict.values())
Expand All @@ -120,7 +123,8 @@ def get_metric_label(self, metric):
if isinstance(metric, string_types):
return metric
if isinstance(metric, dict):
return metric.get('label')
return self.datasource.database.db_engine_spec.mutate_expression_label(
metric.get('label'))

@staticmethod
def handle_js_int_overflow(data):
Expand Down
16 changes: 15 additions & 1 deletion tests/viz_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,20 @@ def test_cache_timeout(self):


class TableVizTestCase(unittest.TestCase):

class DBEngineSpecMock:
@staticmethod
def mutate_expression_label(label):
return label

class DatabaseMock:
def __init__(self):
self.db_engine_spec = TableVizTestCase.DBEngineSpecMock()

class DatasourceMock:
def __init__(self):
self.database = TableVizTestCase.DatabaseMock()

def test_get_data_applies_percentage(self):
form_data = {
'percent_metrics': [{
Expand All @@ -137,7 +151,7 @@ def test_get_data_applies_percentage(self):
'column': {'column_name': 'value1', 'type': 'DOUBLE'},
}, 'count', 'avg__C'],
}
datasource = Mock()
datasource = TableVizTestCase.DatasourceMock()
raw = {}
raw['SUM(value1)'] = [15, 20, 25, 40]
raw['avg__B'] = [10, 20, 5, 15]
Expand Down

0 comments on commit 5369d6a

Please sign in to comment.