From a619cb4ea98342a2fdf7f77587d8aa078c7dccef Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Tue, 29 Mar 2022 20:03:09 +0300 Subject: [PATCH] chore: upgrade black (#19410) --- .pre-commit-config.yaml | 2 +- RELEASING/changelog.py | 8 +- RELEASING/send_email.py | 7 +- scripts/cancel_github_workflows.py | 9 +- .../annotations/commands/update.py | 4 +- .../annotation_layers/annotations/schemas.py | 8 +- superset/cachekeys/api.py | 6 +- superset/cachekeys/schemas.py | 15 +- superset/charts/schemas.py | 151 +++++++++++++----- superset/cli/examples.py | 10 +- superset/cli/importexport.py | 31 +++- superset/cli/main.py | 3 +- superset/cli/thumbnails.py | 6 +- superset/columns/models.py | 5 +- superset/commands/base.py | 2 +- superset/commands/exceptions.py | 4 +- superset/commands/importers/v1/utils.py | 4 +- superset/commands/utils.py | 4 +- superset/common/query_actions.py | 4 +- superset/common/query_context.py | 19 ++- superset/common/query_context_factory.py | 2 +- superset/common/query_context_processor.py | 20 ++- superset/common/query_object.py | 6 +- superset/common/query_object_factory.py | 3 +- superset/common/utils/dataframe_utils.py | 4 +- superset/connectors/druid/models.py | 1 - superset/connectors/druid/views.py | 18 ++- superset/connectors/sqla/models.py | 44 +++-- superset/connectors/sqla/utils.py | 10 +- superset/dashboards/commands/importers/v0.py | 3 +- .../dashboards/filter_sets/commands/base.py | 3 +- superset/dashboards/filter_sets/schemas.py | 6 +- superset/dashboards/filters.py | 16 +- superset/dashboards/permalink/api.py | 9 +- .../dashboards/permalink/commands/create.py | 9 +- superset/dashboards/permalink/schemas.py | 4 +- superset/dashboards/schemas.py | 3 +- superset/databases/api.py | 5 +- superset/databases/commands/exceptions.py | 6 +- superset/databases/commands/validate.py | 6 +- superset/databases/dao.py | 3 +- superset/databases/filters.py | 3 +- superset/databases/schemas.py | 23 ++- superset/datasets/commands/importers/v0.py | 5 +- superset/datasets/commands/update.py | 3 +- superset/db_engine_specs/base.py | 68 ++++++-- superset/db_engine_specs/bigquery.py | 3 +- superset/db_engine_specs/gsheets.py | 17 +- superset/db_engine_specs/hive.py | 5 +- superset/db_engine_specs/mysql.py | 48 +++++- superset/db_engine_specs/postgres.py | 12 +- superset/db_engine_specs/presto.py | 11 +- superset/db_engine_specs/trino.py | 5 +- superset/examples/birth_names.py | 19 ++- superset/examples/world_bank.py | 4 +- superset/exceptions.py | 10 +- superset/explore/form_data/commands/update.py | 3 +- superset/explore/permalink/api.py | 5 +- superset/explore/permalink/commands/create.py | 4 +- superset/explore/permalink/commands/get.py | 3 +- superset/explore/permalink/schemas.py | 4 +- superset/jinja_context.py | 5 +- .../migrations/shared/security_converge.py | 4 +- ...add_type_to_native_filter_configuration.py | 3 +- ...6dca87d1a2_security_converge_dashboards.py | 69 ++++++-- .../2e5a0ee25ed4_refractor_alerting.py | 45 ++++-- .../versions/2f1d15e8a6af_add_alerts.py | 25 ++- ...80_add_creation_method_to_reports_model.py | 4 +- .../40f16acf1ba7_security_converge_reports.py | 22 ++- ...2b4c9e01447_security_converge_databases.py | 42 ++++- ...45731db65d9c_security_converge_datasets.py | 37 ++++- .../49b5a32daba5_add_report_schedules.py | 10 +- .../4b84f97828aa_security_converge_logs.py | 12 +- ...14_add_on_saved_query_delete_tab_state_.py | 5 +- .../73fd22e742ab_add_dynamic_plugins_py.py | 10 +- 
...9739cf9_security_converge_css_templates.py | 37 ++++- ...b176a0_add_import_mixing_to_saved_query.py | 5 +- ...5563a02_migrate_iframe_to_dash_markdown.py | 5 +- ...f93db_add_extra_config_column_to_alerts.py | 7 +- ...collapse_alerting_models_into_a_single_.py | 34 +++- ...6560d4f3_change_table_unique_constraint.py | 2 +- ...0de1855_add_uuid_column_to_import_mixin.py | 5 +- .../b5998378c225_add_certificate_to_dbs.py | 3 +- .../b8d3a24d9131_new_dataset_models.py | 49 +++++- ...cb2c78727_security_converge_annotations.py | 52 ++++-- .../c501b7c653a3_add_missing_uuid_column.py | 5 +- .../c82ee8a39623_add_implicit_tags.py | 5 +- ...81977c6_alert_reports_shared_uniqueness.py | 3 +- .../ccb74baaa89b_security_converge_charts.py | 67 ++++++-- ...cc_add_limiting_factor_column_to_query_.py | 6 +- ...7dbf641_security_converge_saved_queries.py | 57 +++++-- ...4e_add_rls_filter_type_and_grouping_key.py | 4 +- superset/models/core.py | 4 +- superset/models/dashboard.py | 3 +- superset/queries/saved_queries/schemas.py | 6 +- superset/reports/commands/base.py | 2 +- superset/reports/commands/execute.py | 13 +- superset/reports/dao.py | 9 +- superset/reports/schemas.py | 4 +- superset/security/manager.py | 5 +- superset/sqllab/command.py | 5 +- superset/sqllab/query_render.py | 7 +- superset/stats_logger.py | 1 - superset/tasks/async_queries.py | 11 +- superset/tasks/scheduler.py | 12 +- superset/tasks/slack_util.py | 3 +- superset/tasks/thumbnails.py | 5 +- superset/temporary_cache/api.py | 6 +- superset/temporary_cache/commands/update.py | 3 +- superset/utils/cache.py | 3 +- superset/utils/core.py | 28 ++-- superset/utils/date_parser.py | 9 +- superset/utils/encrypt.py | 3 +- superset/utils/log.py | 5 +- superset/utils/machine_auth.py | 10 +- superset/utils/mock_data.py | 4 +- superset/utils/pandas_postprocessing/cum.py | 6 +- .../utils/pandas_postprocessing/flatten.py | 5 +- .../utils/pandas_postprocessing/geography.py | 8 +- .../utils/pandas_postprocessing/prophet.py | 5 +- superset/utils/pandas_postprocessing/utils.py | 13 +- superset/utils/profiler.py | 4 +- superset/views/base.py | 2 +- superset/views/core.py | 34 ++-- superset/views/dashboard/views.py | 3 +- superset/views/database/views.py | 12 +- superset/views/datasource/schemas.py | 4 +- superset/views/datasource/views.py | 4 +- superset/views/users/api.py | 2 +- superset/views/utils.py | 4 +- superset/viz.py | 9 +- tests/common/query_context_generator.py | 28 +++- tests/conftest.py | 4 +- tests/fixtures/birth_names.py | 3 +- .../annotation_layers/fixtures.py | 5 +- tests/integration_tests/celery_tests.py | 6 +- tests/integration_tests/charts/api_tests.py | 28 +++- .../charts/data/api_tests.py | 37 +++-- tests/integration_tests/cli_tests.py | 12 +- tests/integration_tests/core_tests.py | 12 +- .../css_templates/api_tests.py | 5 +- tests/integration_tests/dashboard_utils.py | 4 +- .../integration_tests/dashboards/api_tests.py | 32 +++- .../dashboards/filter_sets/get_api_tests.py | 4 +- .../dashboards/permalink/api_tests.py | 3 +- .../security/security_dataset_tests.py | 7 +- .../security/security_rbac_tests.py | 13 +- .../integration_tests/databases/api_tests.py | 14 +- tests/integration_tests/datasets/api_tests.py | 17 +- tests/integration_tests/datasource_tests.py | 12 +- .../db_engine_specs/base_engine_spec_tests.py | 10 +- .../db_engine_specs/bigquery_tests.py | 17 +- .../db_engine_specs/hive_tests.py | 5 +- .../db_engine_specs/pinot_tests.py | 5 +- .../db_engine_specs/postgres_tests.py | 26 ++- .../db_engine_specs/presto_tests.py | 15 +- 
.../explore/permalink/api_tests.py | 3 +- .../extensions/metastore_cache_test.py | 3 +- .../fixtures/birth_names_dashboard.py | 4 +- .../fixtures/importexport.py | 4 +- tests/integration_tests/form_tests.py | 6 +- .../integration_tests/import_export_tests.py | 6 +- .../key_value/commands/delete_test.py | 17 +- .../key_value/commands/fixtures.py | 5 +- .../key_value/commands/get_test.py | 3 +- .../key_value/commands/update_test.py | 23 ++- .../key_value/commands/upsert_test.py | 23 ++- tests/integration_tests/log_api_tests.py | 10 +- .../migrations/f1410ed7ec95_tests.py | 10 +- tests/integration_tests/model_tests.py | 10 +- .../integration_tests/query_context_tests.py | 16 +- .../reports/commands_tests.py | 76 +++++++-- .../security/guest_token_security_tests.py | 6 +- .../security/migrate_roles_tests.py | 61 +++++-- .../security/row_level_security_tests.py | 3 +- .../integration_tests/sql_validator_tests.py | 5 +- tests/integration_tests/sqla_models_tests.py | 73 +++++++-- tests/integration_tests/sqllab_tests.py | 14 +- .../tasks/async_queries_tests.py | 6 +- tests/integration_tests/utils_tests.py | 10 +- tests/integration_tests/viz_tests.py | 38 ++++- tests/unit_tests/columns/test_models.py | 6 +- tests/unit_tests/core_tests.py | 3 +- .../commands/importers/v1/utils_test.py | 10 +- .../datasets/commands/export_test.py | 6 +- .../commands/importers/v1/import_test.py | 12 +- tests/unit_tests/datasets/test_models.py | 28 +++- tests/unit_tests/db_engine_specs/test_base.py | 5 +- .../db_engine_specs/test_gsheets.py | 32 +++- .../unit_tests/db_engine_specs/test_kusto.py | 10 +- .../unit_tests/db_engine_specs/test_mssql.py | 29 +++- .../unit_tests/db_engine_specs/test_presto.py | 5 +- .../db_engine_specs/test_teradata.py | 5 +- .../unit_tests/db_engine_specs/test_trino.py | 5 +- tests/unit_tests/fixtures/dataframes.py | 18 ++- .../test_contribution.py | 3 +- .../pandas_postprocessing/test_cum.py | 26 ++- .../pandas_postprocessing/test_diff.py | 3 +- .../pandas_postprocessing/test_flatten.py | 15 +- .../pandas_postprocessing/test_pivot.py | 38 +++-- .../pandas_postprocessing/test_prophet.py | 25 ++- .../pandas_postprocessing/test_resample.py | 19 ++- .../pandas_postprocessing/test_rolling.py | 10 +- tests/unit_tests/tables/test_models.py | 8 +- 204 files changed, 2125 insertions(+), 608 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1f29891dfddc9..2429a0153f009 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,7 +41,7 @@ repos: - id: trailing-whitespace args: ["--markdown-linebreak-ext=md"] - repo: https://github.com/psf/black - rev: 19.10b0 + rev: 22.3.0 hooks: - id: black language_version: python3 diff --git a/RELEASING/changelog.py b/RELEASING/changelog.py index 441e3092d047e..8e329b5fe0f6c 100644 --- a/RELEASING/changelog.py +++ b/RELEASING/changelog.py @@ -167,7 +167,10 @@ def _get_changelog_version_head(self) -> str: return f"### {self._version} ({self._logs[0].time})" def _parse_change_log( - self, changelog: Dict[str, str], pr_info: Dict[str, str], github_login: str, + self, + changelog: Dict[str, str], + pr_info: Dict[str, str], + github_login: str, ) -> None: formatted_pr = ( f"- [#{pr_info.get('id')}]" @@ -355,7 +358,8 @@ def compare(base_parameters: BaseParameters) -> None: @cli.command("changelog") @click.option( - "--csv", help="The csv filename to export the changelog to", + "--csv", + help="The csv filename to export the changelog to", ) @click.option( "--access_token", diff --git a/RELEASING/send_email.py 
b/RELEASING/send_email.py index ddf823f1c92b5..a4b4a449665f9 100755 --- a/RELEASING/send_email.py +++ b/RELEASING/send_email.py @@ -106,7 +106,12 @@ def inter_send_email( class BaseParameters(object): def __init__( - self, email: str, username: str, password: str, version: str, version_rc: str, + self, + email: str, + username: str, + password: str, + version: str, + version_rc: str, ) -> None: self.email = email self.username = username diff --git a/scripts/cancel_github_workflows.py b/scripts/cancel_github_workflows.py index 90087fa4f7366..720dc05cbef22 100755 --- a/scripts/cancel_github_workflows.py +++ b/scripts/cancel_github_workflows.py @@ -60,7 +60,8 @@ def request( def list_runs( - repo: str, params: Optional[Dict[str, str]] = None, + repo: str, + params: Optional[Dict[str, str]] = None, ) -> Iterator[Dict[str, Any]]: """List all github workflow runs. Returns: @@ -193,7 +194,11 @@ def cancel_github_workflows( if branch and ":" in branch: [user, branch] = branch.split(":", 2) runs = get_runs( - repo, branch=branch, user=user, statuses=statuses, events=events, + repo, + branch=branch, + user=user, + statuses=statuses, + events=events, ) # sort old jobs to the front, so to cancel older jobs first diff --git a/superset/annotation_layers/annotations/commands/update.py b/superset/annotation_layers/annotations/commands/update.py index ccf11b6536582..9e3012acb69bf 100644 --- a/superset/annotation_layers/annotations/commands/update.py +++ b/superset/annotation_layers/annotations/commands/update.py @@ -73,7 +73,9 @@ def validate(self) -> None: # Validate short descr uniqueness on this layer if not AnnotationDAO.validate_update_uniqueness( - layer_id, short_descr, annotation_id=self._model_id, + layer_id, + short_descr, + annotation_id=self._model_id, ): exceptions.append(AnnotationUniquenessValidationError()) else: diff --git a/superset/annotation_layers/annotations/schemas.py b/superset/annotation_layers/annotations/schemas.py index fd03b14f51b4f..5e0bac56f73e0 100644 --- a/superset/annotation_layers/annotations/schemas.py +++ b/superset/annotation_layers/annotations/schemas.py @@ -64,13 +64,17 @@ class AnnotationPostSchema(Schema): ) long_descr = fields.String(description=annotation_long_descr, allow_none=True) start_dttm = fields.DateTime( - description=annotation_start_dttm, required=True, allow_none=False, + description=annotation_start_dttm, + required=True, + allow_none=False, ) end_dttm = fields.DateTime( description=annotation_end_dttm, required=True, allow_none=False ) json_metadata = fields.String( - description=annotation_json_metadata, validate=validate_json, allow_none=True, + description=annotation_json_metadata, + validate=validate_json, + allow_none=True, ) diff --git a/superset/cachekeys/api.py b/superset/cachekeys/api.py index ff19f3a0fe834..6eb0d54d9eef0 100644 --- a/superset/cachekeys/api.py +++ b/superset/cachekeys/api.py @@ -110,8 +110,10 @@ def invalidate(self) -> Response: ) try: - delete_stmt = CacheKey.__table__.delete().where( # pylint: disable=no-member - CacheKey.cache_key.in_(cache_keys) + delete_stmt = ( + CacheKey.__table__.delete().where( # pylint: disable=no-member + CacheKey.cache_key.in_(cache_keys) + ) ) db.session.execute(delete_stmt) db.session.commit() diff --git a/superset/cachekeys/schemas.py b/superset/cachekeys/schemas.py index a97aebdf2c66d..a44a7c545add4 100644 --- a/superset/cachekeys/schemas.py +++ b/superset/cachekeys/schemas.py @@ -25,9 +25,15 @@ class Datasource(Schema): - database_name = fields.String(description="Datasource name",) - 
datasource_name = fields.String(description=datasource_name_description,) - schema = fields.String(description="Datasource schema",) + database_name = fields.String( + description="Datasource name", + ) + datasource_name = fields.String( + description=datasource_name_description, + ) + schema = fields.String( + description="Datasource schema", + ) datasource_type = fields.String( description=datasource_type_description, validate=validate.OneOf(choices=("druid", "table", "view")), @@ -37,7 +43,8 @@ class Datasource(Schema): class CacheInvalidationRequestSchema(Schema): datasource_uids = fields.List( - fields.String(), description=datasource_uid_description, + fields.String(), + description=datasource_uid_description, ) datasources = fields.List( fields.Nested(Datasource), diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index a887ffd13c45d..2a967eda27f9d 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -279,7 +279,8 @@ class ChartCacheScreenshotResponseSchema(Schema): class ChartDataColumnSchema(Schema): column_name = fields.String( - description="The name of the target column", example="mycol", + description="The name of the target column", + example="mycol", ) type = fields.String(description="Type of target column", example="BIGINT") @@ -325,7 +326,8 @@ class ChartDataAdhocMetricSchema(Schema): example="metric_aec60732-fac0-4b17-b736-93f1a5c93e30", ) timeGrain = fields.String( - description="Optional time grain for temporal filters", example="PT1M", + description="Optional time grain for temporal filters", + example="PT1M", ) isExtra = fields.Boolean( description="Indicates if the filter has been added by a filter component as " @@ -370,7 +372,8 @@ class ChartDataAggregateOptionsSchema(ChartDataPostProcessingOperationOptionsSch groupby = ( fields.List( fields.String( - allow_none=False, description="Columns by which to group by", + allow_none=False, + description="Columns by which to group by", ), minLength=1, required=True, @@ -425,7 +428,9 @@ class ChartDataRollingOptionsSchema(ChartDataPostProcessingOperationOptionsSchem example="percentile", ) window = fields.Integer( - description="Size of the rolling window in days.", required=True, example=7, + description="Size of the rolling window in days.", + required=True, + example=7, ) rolling_type_options = fields.Dict( desctiption="Optional options to pass to rolling method. 
Needed for " @@ -592,7 +597,9 @@ class ChartDataBoxplotOptionsSchema(ChartDataPostProcessingOperationOptionsSchem """ groupby = fields.List( - fields.String(description="Columns by which to group the query.",), + fields.String( + description="Columns by which to group the query.", + ), allow_none=True, ) @@ -699,13 +706,16 @@ class ChartDataGeohashDecodeOptionsSchema( """ geohash = fields.String( - description="Name of source column containing geohash string", required=True, + description="Name of source column containing geohash string", + required=True, ) latitude = fields.String( - description="Name of target column for decoded latitude", required=True, + description="Name of target column for decoded latitude", + required=True, ) longitude = fields.String( - description="Name of target column for decoded longitude", required=True, + description="Name of target column for decoded longitude", + required=True, ) @@ -717,13 +727,16 @@ class ChartDataGeohashEncodeOptionsSchema( """ latitude = fields.String( - description="Name of source latitude column", required=True, + description="Name of source latitude column", + required=True, ) longitude = fields.String( - description="Name of source longitude column", required=True, + description="Name of source longitude column", + required=True, ) geohash = fields.String( - description="Name of target column for encoded geohash string", required=True, + description="Name of target column for encoded geohash string", + required=True, ) @@ -739,10 +752,12 @@ class ChartDataGeodeticParseOptionsSchema( required=True, ) latitude = fields.String( - description="Name of target column for decoded latitude", required=True, + description="Name of target column for decoded latitude", + required=True, ) longitude = fields.String( - description="Name of target column for decoded longitude", required=True, + description="Name of target column for decoded longitude", + required=True, ) altitude = fields.String( description="Name of target column for decoded altitude. If omitted, " @@ -789,7 +804,10 @@ class ChartDataPostProcessingOperationSchema(Schema): "column": "age", "options": {"q": 0.25}, }, - "age_mean": {"operator": "mean", "column": "age",}, + "age_mean": { + "operator": "mean", + "column": "age", + }, }, }, ) @@ -816,7 +834,8 @@ class ChartDataFilterSchema(Schema): example=["China", "France", "Japan"], ) grain = fields.String( - description="Optional time grain for temporal filters", example="PT1M", + description="Optional time grain for temporal filters", + example="PT1M", ) isExtra = fields.Boolean( description="Indicates if the filter has been added by a filter component as " @@ -873,7 +892,10 @@ class AnnotationLayerSchema(Schema): description="Type of annotation layer", validate=validate.OneOf(choices=[ann.value for ann in AnnotationType]), ) - color = fields.String(description="Layer color", allow_none=True,) + color = fields.String( + description="Layer color", + allow_none=True, + ) descriptionColumns = fields.List( fields.String(), description="Columns to use as the description. If none are provided, " @@ -911,7 +933,8 @@ class AnnotationLayerSchema(Schema): ) show = fields.Boolean(description="Should the layer be shown", required=True) showLabel = fields.Boolean( - description="Should the label always be shown", allow_none=True, + description="Should the label always be shown", + allow_none=True, ) showMarkers = fields.Boolean( description="Should markers be shown. 
Only applies to line annotations.", @@ -919,16 +942,34 @@ class AnnotationLayerSchema(Schema): ) sourceType = fields.String( description="Type of source for annotation data", - validate=validate.OneOf(choices=("", "line", "NATIVE", "table",)), + validate=validate.OneOf( + choices=( + "", + "line", + "NATIVE", + "table", + ) + ), ) style = fields.String( description="Line style. Only applies to time-series annotations", - validate=validate.OneOf(choices=("dashed", "dotted", "solid", "longDashed",)), + validate=validate.OneOf( + choices=( + "dashed", + "dotted", + "solid", + "longDashed", + ) + ), ) timeColumn = fields.String( - description="Column with event date or interval start date", allow_none=True, + description="Column with event date or interval start date", + allow_none=True, + ) + titleColumn = fields.String( + description="Column with title", + allow_none=True, ) - titleColumn = fields.String(description="Column with title", allow_none=True,) width = fields.Float( description="Width of annotation line", validate=[ @@ -948,7 +989,10 @@ class AnnotationLayerSchema(Schema): class ChartDataDatasourceSchema(Schema): description = "Chart datasource" - id = fields.Integer(description="Datasource id", required=True,) + id = fields.Integer( + description="Datasource id", + required=True, + ) type = fields.String( description="Datasource type", validate=validate.OneOf(choices=("druid", "table")), @@ -1039,7 +1083,8 @@ class Meta: # pylint: disable=too-few-public-methods allow_none=True, ) is_timeseries = fields.Boolean( - description="Is the `query_object` a timeseries.", allow_none=True, + description="Is the `query_object` a timeseries.", + allow_none=True, ) series_columns = fields.List( fields.Raw(), @@ -1084,7 +1129,8 @@ class Meta: # pylint: disable=too-few-public-methods ], ) order_desc = fields.Boolean( - description="Reverse order. Default: `false`", allow_none=True, + description="Reverse order. 
Default: `false`", + allow_none=True, ) extras = fields.Nested( ChartDataExtrasSchema, @@ -1151,7 +1197,10 @@ class Meta: # pylint: disable=too-few-public-methods description="Should the rowcount of the actual query be returned", allow_none=True, ) - time_offsets = fields.List(fields.String(), allow_none=True,) + time_offsets = fields.List( + fields.String(), + allow_none=True, + ) class ChartDataQueryContextSchema(Schema): @@ -1190,7 +1239,9 @@ class AnnotationDataSchema(Schema): required=True, ) records = fields.List( - fields.Dict(keys=fields.String(),), + fields.Dict( + keys=fields.String(), + ), description="records mapping the column name to it's value", required=True, ) @@ -1206,10 +1257,14 @@ class ChartDataResponseResult(Schema): allow_none=True, ) cache_key = fields.String( - description="Unique cache key for query object", required=True, allow_none=True, + description="Unique cache key for query object", + required=True, + allow_none=True, ) cached_dttm = fields.String( - description="Cache timestamp", required=True, allow_none=True, + description="Cache timestamp", + required=True, + allow_none=True, ) cache_timeout = fields.Integer( description="Cache timeout in following order: custom timeout, datasource " @@ -1217,12 +1272,19 @@ class ChartDataResponseResult(Schema): required=True, allow_none=True, ) - error = fields.String(description="Error", allow_none=True,) + error = fields.String( + description="Error", + allow_none=True, + ) is_cached = fields.Boolean( - description="Is the result cached", required=True, allow_none=None, + description="Is the result cached", + required=True, + allow_none=None, ) query = fields.String( - description="The executed query statement", required=True, allow_none=False, + description="The executed query statement", + required=True, + allow_none=False, ) status = fields.String( description="Status of the query", @@ -1240,10 +1302,12 @@ class ChartDataResponseResult(Schema): allow_none=False, ) stacktrace = fields.String( - desciption="Stacktrace if there was an error", allow_none=True, + desciption="Stacktrace if there was an error", + allow_none=True, ) rowcount = fields.Integer( - description="Amount of rows in result set", allow_none=False, + description="Amount of rows in result set", + allow_none=False, ) data = fields.List(fields.Dict(), description="A list with results") colnames = fields.List(fields.String(), description="A list of column names") @@ -1273,13 +1337,24 @@ class ChartDataResponseSchema(Schema): class ChartDataAsyncResponseSchema(Schema): channel_id = fields.String( - description="Unique session async channel ID", allow_none=False, + description="Unique session async channel ID", + allow_none=False, + ) + job_id = fields.String( + description="Unique async job ID", + allow_none=False, + ) + user_id = fields.String( + description="Requesting user ID", + allow_none=True, + ) + status = fields.String( + description="Status value for async job", + allow_none=False, ) - job_id = fields.String(description="Unique async job ID", allow_none=False,) - user_id = fields.String(description="Requesting user ID", allow_none=True,) - status = fields.String(description="Status value for async job", allow_none=False,) result_url = fields.String( - description="Unique result URL for fetching async query data", allow_none=False, + description="Unique result URL for fetching async query data", + allow_none=False, ) diff --git a/superset/cli/examples.py b/superset/cli/examples.py index 97394633ba047..26ed9451f0765 100755 --- 
a/superset/cli/examples.py +++ b/superset/cli/examples.py @@ -93,10 +93,16 @@ def load_examples_run( @click.option("--load-test-data", "-t", is_flag=True, help="Load additional test data") @click.option("--load-big-data", "-b", is_flag=True, help="Load additional big data") @click.option( - "--only-metadata", "-m", is_flag=True, help="Only load metadata, skip actual data", + "--only-metadata", + "-m", + is_flag=True, + help="Only load metadata, skip actual data", ) @click.option( - "--force", "-f", is_flag=True, help="Force load data even if table already exists", + "--force", + "-f", + is_flag=True, + help="Force load data even if table already exists", ) def load_examples( load_test_data: bool, diff --git a/superset/cli/importexport.py b/superset/cli/importexport.py index 4bc3ee4a2e3c3..d6bbeb67fe287 100755 --- a/superset/cli/importexport.py +++ b/superset/cli/importexport.py @@ -36,10 +36,16 @@ @click.command() @click.argument("directory") @click.option( - "--overwrite", "-o", is_flag=True, help="Overwriting existing metadata definitions", + "--overwrite", + "-o", + is_flag=True, + help="Overwriting existing metadata definitions", ) @click.option( - "--force", "-f", is_flag=True, help="Force load data even if table already exists", + "--force", + "-f", + is_flag=True, + help="Force load data even if table already exists", ) def import_directory(directory: str, overwrite: bool, force: bool) -> None: """Imports configs from a given directory""" @@ -47,7 +53,9 @@ def import_directory(directory: str, overwrite: bool, force: bool) -> None: from superset.examples.utils import load_configs_from_directory load_configs_from_directory( - root=Path(directory), overwrite=overwrite, force_data=force, + root=Path(directory), + overwrite=overwrite, + force_data=force, ) @@ -56,7 +64,9 @@ def import_directory(directory: str, overwrite: bool, force: bool) -> None: @click.command() @with_appcontext @click.option( - "--dashboard-file", "-f", help="Specify the the file to export to", + "--dashboard-file", + "-f", + help="Specify the the file to export to", ) def export_dashboards(dashboard_file: Optional[str] = None) -> None: """Export dashboards to ZIP file""" @@ -90,7 +100,9 @@ def export_dashboards(dashboard_file: Optional[str] = None) -> None: @click.command() @with_appcontext @click.option( - "--datasource-file", "-f", help="Specify the the file to export to", + "--datasource-file", + "-f", + help="Specify the the file to export to", ) def export_datasources(datasource_file: Optional[str] = None) -> None: """Export datasources to ZIP file""" @@ -122,7 +134,9 @@ def export_datasources(datasource_file: Optional[str] = None) -> None: @click.command() @with_appcontext @click.option( - "--path", "-p", help="Path to a single ZIP file", + "--path", + "-p", + help="Path to a single ZIP file", ) @click.option( "--username", @@ -160,7 +174,9 @@ def import_dashboards(path: str, username: Optional[str]) -> None: @click.command() @with_appcontext @click.option( - "--path", "-p", help="Path to a single ZIP file", + "--path", + "-p", + help="Path to a single ZIP file", ) def import_datasources(path: str) -> None: """Import datasources from ZIP file""" @@ -185,7 +201,6 @@ def import_datasources(path: str) -> None: ) sys.exit(1) - else: @click.command() diff --git a/superset/cli/main.py b/superset/cli/main.py index a1a03e9de26d0..aaad7be42e864 100755 --- a/superset/cli/main.py +++ b/superset/cli/main.py @@ -32,7 +32,8 @@ @click.group( - cls=FlaskGroup, context_settings={"token_normalize_func": normalize_token}, + 
cls=FlaskGroup, + context_settings={"token_normalize_func": normalize_token}, ) @with_appcontext def superset() -> None: diff --git a/superset/cli/thumbnails.py b/superset/cli/thumbnails.py index 5615d947bf30f..5556cff92f620 100755 --- a/superset/cli/thumbnails.py +++ b/superset/cli/thumbnails.py @@ -44,7 +44,11 @@ help="Only process dashboards", ) @click.option( - "--charts_only", "-c", is_flag=True, default=False, help="Only process charts", + "--charts_only", + "-c", + is_flag=True, + default=False, + help="Only process charts", ) @click.option( "--force", diff --git a/superset/columns/models.py b/superset/columns/models.py index 039f73ff57579..fbe045e3d3925 100644 --- a/superset/columns/models.py +++ b/superset/columns/models.py @@ -35,7 +35,10 @@ class Column( - Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, + Model, + AuditMixinNullable, + ExtraJSONMixin, + ImportExportMixin, ): """ A "column". diff --git a/superset/commands/base.py b/superset/commands/base.py index 2592a03fa77e2..552b95feb2e1e 100644 --- a/superset/commands/base.py +++ b/superset/commands/base.py @@ -24,7 +24,7 @@ class BaseCommand(ABC): """ - Base class for all Command like Superset Logic objects + Base class for all Command like Superset Logic objects """ @abstractmethod diff --git a/superset/commands/exceptions.py b/superset/commands/exceptions.py index 8b4b717f31d72..2a60318b46e05 100644 --- a/superset/commands/exceptions.py +++ b/superset/commands/exceptions.py @@ -23,7 +23,7 @@ class CommandException(SupersetException): - """ Common base class for Command exceptions. """ + """Common base class for Command exceptions.""" def __repr__(self) -> str: if self._exception: @@ -52,7 +52,7 @@ def __init__( class CommandInvalidError(CommandException): - """ Common base class for Command Invalid errors. 
""" + """Common base class for Command Invalid errors.""" status = 422 diff --git a/superset/commands/importers/v1/utils.py b/superset/commands/importers/v1/utils.py index de86e3f3cc6ab..3999669356ae9 100644 --- a/superset/commands/importers/v1/utils.py +++ b/superset/commands/importers/v1/utils.py @@ -79,7 +79,9 @@ def load_metadata(contents: Dict[str, str]) -> Dict[str, str]: def validate_metadata_type( - metadata: Optional[Dict[str, str]], type_: str, exceptions: List[ValidationError], + metadata: Optional[Dict[str, str]], + type_: str, + exceptions: List[ValidationError], ) -> None: """Validate that the type declared in METADATA_FILE_NAME is correct""" if metadata and "type" in metadata: diff --git a/superset/commands/utils.py b/superset/commands/utils.py index c68dde3d5a758..f7564b3de7689 100644 --- a/superset/commands/utils.py +++ b/superset/commands/utils.py @@ -34,7 +34,9 @@ def populate_owners( - user: User, owner_ids: Optional[List[int]], default_to_user: bool, + user: User, + owner_ids: Optional[List[int]], + default_to_user: bool, ) -> List[User]: """ Helper function for commands, will fetch all users from owners id's diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py index 6664db9fe5c1a..2b85125b0e98a 100644 --- a/superset/common/query_actions.py +++ b/superset/common/query_actions.py @@ -79,7 +79,9 @@ def _get_timegrains( def _get_query( - query_context: "QueryContext", query_obj: "QueryObject", _: bool, + query_context: "QueryContext", + query_obj: "QueryObject", + _: bool, ) -> Dict[str, Any]: datasource = _get_datasource(query_context, query_obj) result = {"language": datasource.query_language} diff --git a/superset/common/query_context.py b/superset/common/query_context.py index bc906a92c7321..4a91c6ad6db17 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -69,7 +69,7 @@ def __init__( result_format: ChartDataResultFormat, force: bool = False, custom_cache_timeout: Optional[int] = None, - cache_values: Dict[str, Any] + cache_values: Dict[str, Any], ) -> None: self.datasource = datasource self.result_type = result_type @@ -81,11 +81,16 @@ def __init__( self.cache_values = cache_values self._processor = QueryContextProcessor(self) - def get_data(self, df: pd.DataFrame,) -> Union[str, List[Dict[str, Any]]]: + def get_data( + self, + df: pd.DataFrame, + ) -> Union[str, List[Dict[str, Any]]]: return self._processor.get_data(df) def get_payload( - self, cache_query_context: Optional[bool] = False, force_cached: bool = False, + self, + cache_query_context: Optional[bool] = False, + force_cached: bool = False, ) -> Dict[str, Any]: """Returns the query results with both metadata and data""" return self._processor.get_payload(cache_query_context, force_cached) @@ -103,7 +108,9 @@ def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str return self._processor.query_cache_key(query_obj, **kwargs) def get_df_payload( - self, query_obj: QueryObject, force_cached: Optional[bool] = False, + self, + query_obj: QueryObject, + force_cached: Optional[bool] = False, ) -> Dict[str, Any]: return self._processor.get_df_payload(query_obj, force_cached) @@ -111,7 +118,9 @@ def get_query_result(self, query_object: QueryObject) -> QueryResult: return self._processor.get_query_result(query_object) def processing_time_offsets( - self, df: pd.DataFrame, query_object: QueryObject, + self, + df: pd.DataFrame, + query_object: QueryObject, ) -> CachedTimeOffset: return self._processor.processing_time_offsets(df, 
query_object) diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py index 50eb6b02f3bea..cb40b9540818c 100644 --- a/superset/common/query_context_factory.py +++ b/superset/common/query_context_factory.py @@ -50,7 +50,7 @@ def create( result_type: Optional[ChartDataResultType] = None, result_format: Optional[ChartDataResultFormat] = None, force: bool = False, - custom_cache_timeout: Optional[int] = None + custom_cache_timeout: Optional[int] = None, ) -> QueryContext: datasource_model_instance = None if datasource: diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 7954f86cdba0c..202088c853ed7 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -99,7 +99,10 @@ def get_df_payload( """Handles caching around the df payload retrieval""" cache_key = self.query_cache_key(query_obj) cache = QueryCacheManager.get( - cache_key, CacheRegion.DATA, self._query_context.force, force_cached, + cache_key, + CacheRegion.DATA, + self._query_context.force, + force_cached, ) if query_obj and cache_key and not cache.is_loaded: @@ -235,7 +238,9 @@ def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFr return df def processing_time_offsets( # pylint: disable=too-many-locals - self, df: pd.DataFrame, query_object: QueryObject, + self, + df: pd.DataFrame, + query_object: QueryObject, ) -> CachedTimeOffset: query_context = self._query_context # ensure query_object is immutable @@ -250,7 +255,8 @@ def processing_time_offsets( # pylint: disable=too-many-locals for offset in time_offsets: try: query_object_clone.from_dttm = get_past_or_future( - offset, outer_from_dttm, + offset, + outer_from_dttm, ) query_object_clone.to_dttm = get_past_or_future(offset, outer_to_dttm) except ValueError as ex: @@ -322,7 +328,9 @@ def processing_time_offsets( # pylint: disable=too-many-locals # df left join `offset_metrics_df` offset_df = df_utils.left_join_df( - left_df=df, right_df=offset_metrics_df, join_keys=join_keys, + left_df=df, + right_df=offset_metrics_df, + join_keys=join_keys, ) offset_slice = offset_df[metrics_mapping.values()] @@ -358,7 +366,9 @@ def get_data(self, df: pd.DataFrame) -> Union[str, List[Dict[str, Any]]]: return df.to_dict(orient="records") def get_payload( - self, cache_query_context: Optional[bool] = False, force_cached: bool = False, + self, + cache_query_context: Optional[bool] = False, + force_cached: bool = False, ) -> Dict[str, Any]: """Returns the query results with both metadata and data""" diff --git a/superset/common/query_object.py b/superset/common/query_object.py index b2e6fe1bb330d..78a76fc3cdee7 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -341,7 +341,11 @@ def to_dict(self) -> Dict[str, Any]: def __repr__(self) -> str: # we use `print` or `logging` output QueryObject - return json.dumps(self.to_dict(), sort_keys=True, default=str,) + return json.dumps( + self.to_dict(), + sort_keys=True, + default=str, + ) def cache_key(self, **extra: Any) -> str: """ diff --git a/superset/common/query_object_factory.py b/superset/common/query_object_factory.py index 2d051e36c729e..64ae99deebabc 100644 --- a/superset/common/query_object_factory.py +++ b/superset/common/query_object_factory.py @@ -80,7 +80,8 @@ def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource: ) def _process_extras( # pylint: disable=no-self-use - self, extras: Optional[Dict[str, Any]], + self, + 
extras: Optional[Dict[str, Any]], ) -> Dict[str, Any]: extras = extras or {} return extras diff --git a/superset/common/utils/dataframe_utils.py b/superset/common/utils/dataframe_utils.py index 55d03e6343410..a0216ad54e839 100644 --- a/superset/common/utils/dataframe_utils.py +++ b/superset/common/utils/dataframe_utils.py @@ -26,7 +26,9 @@ def left_join_df( - left_df: pd.DataFrame, right_df: pd.DataFrame, join_keys: List[str], + left_df: pd.DataFrame, + right_df: pd.DataFrame, + join_keys: List[str], ) -> pd.DataFrame: df = left_df.set_index(join_keys).join(right_df.set_index(join_keys)) df.reset_index(inplace=True) diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 3a17ec5319374..bc4d0ba3817a4 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -128,7 +128,6 @@ def __init__(self, name: str, post_aggregator: Dict[str, Any]) -> None: self.name = name self.post_aggregator = post_aggregator - except NameError: pass diff --git a/superset/connectors/druid/views.py b/superset/connectors/druid/views.py index cd7e5d279ba25..b387aff6962e8 100644 --- a/superset/connectors/druid/views.py +++ b/superset/connectors/druid/views.py @@ -62,7 +62,9 @@ def ensure_enabled(self) -> None: class DruidColumnInlineView( # pylint: disable=too-many-ancestors - CompactCRUDMixin, EnsureEnabledMixin, SupersetModelView, + CompactCRUDMixin, + EnsureEnabledMixin, + SupersetModelView, ): datamodel = SQLAInterface(models.DruidColumn) include_route_methods = RouteMethod.RELATED_VIEW_SET @@ -151,7 +153,9 @@ def post_add(self, item: "DruidColumnInlineView") -> None: class DruidMetricInlineView( # pylint: disable=too-many-ancestors - CompactCRUDMixin, EnsureEnabledMixin, SupersetModelView, + CompactCRUDMixin, + EnsureEnabledMixin, + SupersetModelView, ): datamodel = SQLAInterface(models.DruidMetric) include_route_methods = RouteMethod.RELATED_VIEW_SET @@ -206,7 +210,10 @@ class DruidMetricInlineView( # pylint: disable=too-many-ancestors class DruidClusterModelView( # pylint: disable=too-many-ancestors - EnsureEnabledMixin, SupersetModelView, DeleteMixin, YamlExportMixin, + EnsureEnabledMixin, + SupersetModelView, + DeleteMixin, + YamlExportMixin, ): datamodel = SQLAInterface(models.DruidCluster) include_route_methods = RouteMethod.CRUD_SET @@ -270,7 +277,10 @@ def _delete(self, pk: int) -> None: class DruidDatasourceModelView( # pylint: disable=too-many-ancestors - EnsureEnabledMixin, DatasourceModelView, DeleteMixin, YamlExportMixin, + EnsureEnabledMixin, + DatasourceModelView, + DeleteMixin, + YamlExportMixin, ): datamodel = SQLAInterface(models.DruidDatasource) include_route_methods = RouteMethod.CRUD_SET diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index bbd1b5d84dad7..3f6fee0043703 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -311,7 +311,9 @@ def datasource(self) -> RelationshipProperty: return self.table def get_time_filter( - self, start_dttm: DateTime, end_dttm: DateTime, + self, + start_dttm: DateTime, + end_dttm: DateTime, ) -> ColumnElement: col = self.get_sqla_col(label="__time") l = [] @@ -687,7 +689,9 @@ def external_metadata(self) -> List[Dict[str, str]]: if self.sql: return get_virtual_table_metadata(dataset=self) return get_physical_table_metadata( - database=self.database, table_name=self.table_name, schema_name=self.schema, + database=self.database, + table_name=self.table_name, + schema_name=self.schema, ) @property @@ -1013,7 +1017,10 @@ 
def _get_sqla_row_level_filters( return all_filters except TemplateError as ex: raise QueryObjectValidationError( - _("Error in jinja expression in RLS filters: %(msg)s", msg=ex.message,) + _( + "Error in jinja expression in RLS filters: %(msg)s", + msg=ex.message, + ) ) from ex def text(self, clause: str) -> TextClause: @@ -1233,7 +1240,8 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma ): time_filters.append( columns_by_name[self.main_dttm_col].get_time_filter( - from_dttm, to_dttm, + from_dttm, + to_dttm, ) ) time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm)) @@ -1444,7 +1452,8 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma if dttm_col and not db_engine_spec.time_groupby_inline: inner_time_filter = [ dttm_col.get_time_filter( - inner_from_dttm or from_dttm, inner_to_dttm or to_dttm, + inner_from_dttm or from_dttm, + inner_to_dttm or to_dttm, ) ] subq = subq.where(and_(*(where_clause_and + inner_time_filter))) @@ -1473,7 +1482,9 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma orderby = [ ( self._get_series_orderby( - series_limit_metric, metrics_by_name, columns_by_name, + series_limit_metric, + metrics_by_name, + columns_by_name, ), not order_desc, ) @@ -1549,7 +1560,10 @@ def _get_series_orderby( return ob def _normalize_prequery_result_type( - self, row: pd.Series, dimension: str, columns_by_name: Dict[str, TableColumn], + self, + row: pd.Series, + dimension: str, + columns_by_name: Dict[str, TableColumn], ) -> Union[str, int, float, bool, Text]: """ Convert a prequery result type to its equivalent Python type. @@ -1594,7 +1608,9 @@ def _get_top_groups( group = [] for dimension in dimensions: value = self._normalize_prequery_result_type( - row, dimension, columns_by_name, + row, + dimension, + columns_by_name, ) group.append(groupby_exprs[dimension] == value) @@ -1933,7 +1949,9 @@ def update_table( # pylint: disable=unused-argument @staticmethod def after_insert( - mapper: Mapper, connection: Connection, target: "SqlaTable", + mapper: Mapper, + connection: Connection, + target: "SqlaTable", ) -> None: """ Shadow write the dataset to new models. @@ -1962,7 +1980,9 @@ def after_insert( @staticmethod def after_delete( # pylint: disable=unused-argument - mapper: Mapper, connection: Connection, target: "SqlaTable", + mapper: Mapper, + connection: Connection, + target: "SqlaTable", ) -> None: """ Shadow write the dataset to new models. @@ -1985,7 +2005,9 @@ def after_delete( # pylint: disable=unused-argument @staticmethod def after_update( # pylint: disable=too-many-branches, too-many-locals, too-many-statements - mapper: Mapper, connection: Connection, target: "SqlaTable", + mapper: Mapper, + connection: Connection, + target: "SqlaTable", ) -> None: """ Shadow write the dataset to new models. 
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index 984eef78f4b76..389c5b9012a3b 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -36,7 +36,9 @@ def get_physical_table_metadata( - database: Database, table_name: str, schema_name: Optional[str] = None, + database: Database, + table_name: str, + schema_name: Optional[str] = None, ) -> List[Dict[str, str]]: """Use SQLAlchemy inspector to get table metadata""" db_engine_spec = database.db_engine_spec @@ -72,7 +74,11 @@ def get_physical_table_metadata( # from different drivers that fall outside CompileError except Exception: # pylint: disable=broad-except col.update( - {"type": "UNKNOWN", "generic_type": None, "is_dttm": None,} + { + "type": "UNKNOWN", + "generic_type": None, + "is_dttm": None, + } ) return cols diff --git a/superset/dashboards/commands/importers/v0.py b/superset/dashboards/commands/importers/v0.py index f317d51086c9d..a7fbb51c057d5 100644 --- a/superset/dashboards/commands/importers/v0.py +++ b/superset/dashboards/commands/importers/v0.py @@ -151,7 +151,8 @@ def alter_native_filters(dashboard: Dashboard) -> None: old_dataset_id = target.get("datasetId") if dataset_id_mapping and old_dataset_id is not None: target["datasetId"] = dataset_id_mapping.get( - old_dataset_id, old_dataset_id, + old_dataset_id, + old_dataset_id, ) dashboard.json_metadata = json.dumps(json_metadata) diff --git a/superset/dashboards/filter_sets/commands/base.py b/superset/dashboards/filter_sets/commands/base.py index af31bbd7a1a94..0e902e5e687ca 100644 --- a/superset/dashboards/filter_sets/commands/base.py +++ b/superset/dashboards/filter_sets/commands/base.py @@ -85,7 +85,8 @@ def check_ownership(self) -> None: ) except NotAuthorizedException as err: raise FilterSetForbiddenError( - str(self._filter_set_id), "user not authorized to access the filterset", + str(self._filter_set_id), + "user not authorized to access the filterset", ) from err except FilterSetForbiddenError as err: raise err diff --git a/superset/dashboards/filter_sets/schemas.py b/superset/dashboards/filter_sets/schemas.py index 3c0436d697b23..c1a13b424e815 100644 --- a/superset/dashboards/filter_sets/schemas.py +++ b/superset/dashboards/filter_sets/schemas.py @@ -46,7 +46,11 @@ def _validate_json_meta_data(self, json_meta_data: str) -> None: class FilterSetPostSchema(FilterSetSchema): json_metadata_schema: JsonMetadataSchema = JsonMetadataSchema() # pylint: disable=W0613 - name = fields.String(required=True, allow_none=False, validate=Length(0, 500),) + name = fields.String( + required=True, + allow_none=False, + validate=Length(0, 500), + ) description = fields.String( required=False, allow_none=True, validate=[Length(1, 1000)] ) diff --git a/superset/dashboards/filters.py b/superset/dashboards/filters.py index e398af97b744a..5f79392e71ecd 100644 --- a/superset/dashboards/filters.py +++ b/superset/dashboards/filters.py @@ -170,7 +170,9 @@ class FilterRelatedRoles(BaseFilter): # pylint: disable=too-few-public-methods def apply(self, query: Query, value: Optional[Any]) -> Query: role_model = security_manager.role_model if value: - return query.filter(role_model.name.ilike(f"%{value}%"),) + return query.filter( + role_model.name.ilike(f"%{value}%"), + ) return query @@ -184,7 +186,15 @@ class DashboardCertifiedFilter(BaseFilter): # pylint: disable=too-few-public-me def apply(self, query: Query, value: Any) -> Query: if value is True: - return query.filter(and_(Dashboard.certified_by.isnot(None),)) + return 
query.filter( + and_( + Dashboard.certified_by.isnot(None), + ) + ) if value is False: - return query.filter(and_(Dashboard.certified_by.is_(None),)) + return query.filter( + and_( + Dashboard.certified_by.is_(None), + ) + ) return query diff --git a/superset/dashboards/permalink/api.py b/superset/dashboards/permalink/api.py index 978a63cbcf9e8..ca536af8f7400 100644 --- a/superset/dashboards/permalink/api.py +++ b/superset/dashboards/permalink/api.py @@ -104,14 +104,19 @@ def post(self, pk: str) -> Response: try: state = self.add_model_schema.load(request.json) key = CreateDashboardPermalinkCommand( - actor=g.user, dashboard_id=pk, state=state, + actor=g.user, + dashboard_id=pk, + state=state, ).run() http_origin = request.headers.environ.get("HTTP_ORIGIN") url = f"{http_origin}/superset/dashboard/p/{key}/" return self.response(201, key=key, url=url) except (ValidationError, DashboardPermalinkInvalidStateError) as ex: return self.response(400, message=str(ex)) - except (DashboardAccessDeniedError, KeyValueAccessDeniedError,) as ex: + except ( + DashboardAccessDeniedError, + KeyValueAccessDeniedError, + ) as ex: return self.response(403, message=str(ex)) except DashboardNotFoundError as ex: return self.response(404, message=str(ex)) diff --git a/superset/dashboards/permalink/commands/create.py b/superset/dashboards/permalink/commands/create.py index 8a0f6d5973a3d..27ddf0534da88 100644 --- a/superset/dashboards/permalink/commands/create.py +++ b/superset/dashboards/permalink/commands/create.py @@ -31,7 +31,10 @@ class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand): def __init__( - self, actor: User, dashboard_id: str, state: DashboardPermalinkState, + self, + actor: User, + dashboard_id: str, + state: DashboardPermalinkState, ): self.actor = actor self.dashboard_id = dashboard_id @@ -46,7 +49,9 @@ def run(self) -> str: "state": self.state, } key = CreateKeyValueCommand( - actor=self.actor, resource=self.resource, value=value, + actor=self.actor, + resource=self.resource, + value=value, ).run() return encode_permalink_key(key=key.id, salt=self.salt) except SQLAlchemyError as ex: diff --git a/superset/dashboards/permalink/schemas.py b/superset/dashboards/permalink/schemas.py index 0e373ce85bd0c..a0fc1cbc5598f 100644 --- a/superset/dashboards/permalink/schemas.py +++ b/superset/dashboards/permalink/schemas.py @@ -19,7 +19,9 @@ class DashboardPermalinkPostSchema(Schema): filterState = fields.Dict( - required=False, allow_none=True, description="Native filter state", + required=False, + allow_none=True, + description="Native filter state", ) urlParams = fields.List( fields.Tuple( diff --git a/superset/dashboards/schemas.py b/superset/dashboards/schemas.py index 6cb1a3caeee90..9b668df7212d0 100644 --- a/superset/dashboards/schemas.py +++ b/superset/dashboards/schemas.py @@ -243,7 +243,8 @@ class DashboardPostSchema(BaseDashboardSchema): ) css = fields.String() json_metadata = fields.String( - description=json_metadata_description, validate=validate_json_metadata, + description=json_metadata_description, + validate=validate_json_metadata, ) published = fields.Boolean(description=published_description) certified_by = fields.String(description=certified_by_description, allow_none=True) diff --git a/superset/databases/api.py b/superset/databases/api.py index 44581ff4c8f54..0de8bcf83e9ba 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -881,7 +881,10 @@ def function_names(self, pk: int) -> Response: database = DatabaseDAO.find_by_id(pk) if not database: return 
self.response_404() - return self.response(200, function_names=database.function_names,) + return self.response( + 200, + function_names=database.function_names, + ) @expose("/available/", methods=["GET"]) @protect() diff --git a/superset/databases/commands/exceptions.py b/superset/databases/commands/exceptions.py index 9ba58373e197a..bde76c021c88a 100644 --- a/superset/databases/commands/exceptions.py +++ b/superset/databases/commands/exceptions.py @@ -47,7 +47,8 @@ def __init__(self) -> None: class DatabaseRequiredFieldValidationError(ValidationError): def __init__(self, field_name: str) -> None: super().__init__( - [_("Field is required")], field_name=field_name, + [_("Field is required")], + field_name=field_name, ) @@ -100,7 +101,8 @@ class DatabaseUpdateFailedError(UpdateFailedError): class DatabaseConnectionFailedError( # pylint: disable=too-many-ancestors - DatabaseCreateFailedError, DatabaseUpdateFailedError, + DatabaseCreateFailedError, + DatabaseUpdateFailedError, ): message = _("Connection failed, please check your connection settings") diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py index e2dcc581d7fbf..91e76d8d55efb 100644 --- a/superset/databases/commands/validate.py +++ b/superset/databases/commands/validate.py @@ -57,7 +57,8 @@ def run(self) -> None: raise InvalidEngineError( SupersetError( message=__( - 'Engine "%(engine)s" is not a valid engine.', engine=engine, + 'Engine "%(engine)s" is not a valid engine.', + engine=engine, ), error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, level=ErrorLevel.ERROR, @@ -101,7 +102,8 @@ def run(self) -> None: # try to connect sqlalchemy_uri = engine_spec.build_sqlalchemy_uri( # type: ignore - self._properties.get("parameters"), encrypted_extra, + self._properties.get("parameters"), + encrypted_extra, ) if self._model and sqlalchemy_uri == self._model.safe_sqlalchemy_uri(): sqlalchemy_uri = self._model.sqlalchemy_uri_decrypted diff --git a/superset/databases/dao.py b/superset/databases/dao.py index 3d2cdf6d4ffa0..d8813dc8eba3e 100644 --- a/superset/databases/dao.py +++ b/superset/databases/dao.py @@ -42,7 +42,8 @@ def validate_uniqueness(database_name: str) -> bool: @staticmethod def validate_update_uniqueness(database_id: int, database_name: str) -> bool: database_query = db.session.query(Database).filter( - Database.database_name == database_name, Database.id != database_id, + Database.database_name == database_name, + Database.id != database_id, ) return not db.session.query(database_query.exists()).scalar() diff --git a/superset/databases/filters.py b/superset/databases/filters.py index bee7d2c7b2134..bd0729767ee4e 100644 --- a/superset/databases/filters.py +++ b/superset/databases/filters.py @@ -27,7 +27,8 @@ class DatabaseFilter(BaseFilter): # TODO(bogdan): consider caching. 
def can_access_databases( # noqa pylint: disable=no-self-use - self, view_menu_name: str, + self, + view_menu_name: str, ) -> Set[str]: return { security_manager.unpack_database_and_schema(vm).database diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 554a0f97cf3df..4fa38415ef2fc 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -308,7 +308,12 @@ def get_engine_spec(engine: Optional[str]) -> Type[BaseEngineSpec]: engine_specs = get_engine_specs() if engine not in engine_specs: raise ValidationError( - [_('Engine "%(engine)s" is not a valid engine.', engine=engine,)] + [ + _( + 'Engine "%(engine)s" is not a valid engine.', + engine=engine, + ) + ] ) return engine_specs[engine] @@ -324,7 +329,9 @@ class Meta: # pylint: disable=too-few-public-methods description="DB-specific parameters for configuration", ) database_name = fields.String( - description=database_name_description, allow_none=True, validate=Length(1, 250), + description=database_name_description, + allow_none=True, + validate=Length(1, 250), ) impersonate_user = fields.Boolean(description=impersonate_user_description) extra = fields.String(description=extra_description, validate=extra_validator) @@ -351,7 +358,9 @@ class Meta: # pylint: disable=too-few-public-methods unknown = EXCLUDE database_name = fields.String( - description=database_name_description, required=True, validate=Length(1, 250), + description=database_name_description, + required=True, + validate=Length(1, 250), ) cache_timeout = fields.Integer( description=cache_timeout_description, allow_none=True @@ -395,7 +404,9 @@ class Meta: # pylint: disable=too-few-public-methods unknown = EXCLUDE database_name = fields.String( - description=database_name_description, allow_none=True, validate=Length(1, 250), + description=database_name_description, + allow_none=True, + validate=Length(1, 250), ) cache_timeout = fields.Integer( description=cache_timeout_description, allow_none=True @@ -436,7 +447,9 @@ class Meta: # pylint: disable=too-few-public-methods class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin): database_name = fields.String( - description=database_name_description, allow_none=True, validate=Length(1, 250), + description=database_name_description, + allow_none=True, + validate=Length(1, 250), ) impersonate_user = fields.Boolean(description=impersonate_user_description) extra = fields.String(description=extra_description, validate=extra_validator) diff --git a/superset/datasets/commands/importers/v0.py b/superset/datasets/commands/importers/v0.py index 1508298a23e88..7f13261edd3d4 100644 --- a/superset/datasets/commands/importers/v0.py +++ b/superset/datasets/commands/importers/v0.py @@ -285,7 +285,10 @@ class ImportDatasetsCommand(BaseCommand): # pylint: disable=unused-argument def __init__( - self, contents: Dict[str, str], *args: Any, **kwargs: Any, + self, + contents: Dict[str, str], + *args: Any, + **kwargs: Any, ): self.contents = contents self._configs: Dict[str, Any] = {} diff --git a/superset/datasets/commands/update.py b/superset/datasets/commands/update.py index 9ae2bd4a189c9..9d448a6c19392 100644 --- a/superset/datasets/commands/update.py +++ b/superset/datasets/commands/update.py @@ -65,7 +65,8 @@ def run(self) -> Model: if self._model: try: dataset = DatasetDAO.update( - model=self._model, properties=self._properties, + model=self._model, + properties=self._properties, ) return dataset except DAOUpdateFailedError as ex: diff --git a/superset/db_engine_specs/base.py 
b/superset/db_engine_specs/base.py index 5c73e2f666e26..200c7c8eac83c 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -205,7 +205,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods types.BigInteger(), GenericDataType.NUMERIC, ), - (re.compile(r"^long", re.IGNORECASE), types.Float(), GenericDataType.NUMERIC,), + ( + re.compile(r"^long", re.IGNORECASE), + types.Float(), + GenericDataType.NUMERIC, + ), ( re.compile(r"^decimal", re.IGNORECASE), types.Numeric(), @@ -216,13 +220,21 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods types.Numeric(), GenericDataType.NUMERIC, ), - (re.compile(r"^float", re.IGNORECASE), types.Float(), GenericDataType.NUMERIC,), + ( + re.compile(r"^float", re.IGNORECASE), + types.Float(), + GenericDataType.NUMERIC, + ), ( re.compile(r"^double", re.IGNORECASE), types.Float(), GenericDataType.NUMERIC, ), - (re.compile(r"^real", re.IGNORECASE), types.REAL, GenericDataType.NUMERIC,), + ( + re.compile(r"^real", re.IGNORECASE), + types.REAL, + GenericDataType.NUMERIC, + ), ( re.compile(r"^smallserial", re.IGNORECASE), types.SmallInteger(), @@ -258,7 +270,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods types.DateTime(), GenericDataType.TEMPORAL, ), - (re.compile(r"^time", re.IGNORECASE), types.Time(), GenericDataType.TEMPORAL,), + ( + re.compile(r"^time", re.IGNORECASE), + types.Time(), + GenericDataType.TEMPORAL, + ), ( re.compile(r"^interval", re.IGNORECASE), types.Interval(), @@ -351,7 +367,8 @@ def get_dbapi_mapped_exception(cls, exception: Exception) -> Exception: @classmethod def get_allow_cost_estimate( # pylint: disable=unused-argument - cls, extra: Dict[str, Any], + cls, + extra: Dict[str, Any], ) -> bool: return False @@ -618,7 +635,10 @@ def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any] @classmethod def extra_table_metadata( # pylint: disable=unused-argument - cls, database: "Database", table_name: str, schema_name: str, + cls, + database: "Database", + table_name: str, + schema_name: str, ) -> Dict[str, Any]: """ Returns engine-specific table metadata @@ -944,7 +964,10 @@ def get_schema_names(cls, inspector: Inspector) -> List[str]: @classmethod def get_table_names( # pylint: disable=unused-argument - cls, database: "Database", inspector: Inspector, schema: Optional[str], + cls, + database: "Database", + inspector: Inspector, + schema: Optional[str], ) -> List[str]: """ Get all tables from schema @@ -961,7 +984,10 @@ def get_table_names( # pylint: disable=unused-argument @classmethod def get_view_names( # pylint: disable=unused-argument - cls, database: "Database", inspector: Inspector, schema: Optional[str], + cls, + database: "Database", + inspector: Inspector, + schema: Optional[str], ) -> List[str]: """ Get all views from schema @@ -1193,7 +1219,10 @@ def modify_url_for_impersonation( @classmethod def update_impersonation_config( - cls, connect_args: Dict[str, Any], uri: str, username: Optional[str], + cls, + connect_args: Dict[str, Any], + uri: str, + username: Optional[str], ) -> None: """ Update a configuration dictionary @@ -1207,7 +1236,10 @@ def update_impersonation_config( @classmethod def execute( # pylint: disable=unused-argument - cls, cursor: Any, query: str, **kwargs: Any, + cls, + cursor: Any, + query: str, + **kwargs: Any, ) -> None: """ Execute a SQL query @@ -1333,7 +1365,8 @@ def column_datatype_to_string( @classmethod def get_function_names( # pylint: disable=unused-argument - cls, database: "Database", + cls, + 
database: "Database", ) -> List[str]: """ Get a list of function names that are able to be called on the database. @@ -1471,7 +1504,9 @@ def has_implicit_cancel(cls) -> bool: @classmethod def get_cancel_query_id( # pylint: disable=unused-argument - cls, cursor: Any, query: Query, + cls, + cursor: Any, + query: Query, ) -> Optional[str]: """ Select identifiers from the database engine that uniquely identifies the @@ -1487,7 +1522,10 @@ def get_cancel_query_id( # pylint: disable=unused-argument @classmethod def cancel_query( # pylint: disable=unused-argument - cls, cursor: Any, query: Query, cancel_query_id: str, + cls, + cursor: Any, + query: Query, + cancel_query_id: str, ) -> bool: """ Cancel query in the underlying database. @@ -1515,7 +1553,7 @@ class BasicParametersSchema(Schema): port = fields.Integer( required=True, description=__("Database port"), - validate=Range(min=0, max=2 ** 16, max_inclusive=False), + validate=Range(min=0, max=2**16, max_inclusive=False), ) database = fields.String(required=True, description=__("Database name")) query = fields.Dict( @@ -1665,7 +1703,7 @@ def validate_parameters( extra={"invalid": ["port"]}, ), ) - if not (isinstance(port, int) and 0 <= port < 2 ** 16): + if not (isinstance(port, int) and 0 <= port < 2**16): errors.append( SupersetError( message=( diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 30e04c4f2fe9b..2c9f81b1bdde0 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -72,7 +72,8 @@ class BigQueryParametersSchema(Schema): credentials_info = EncryptedString( - required=False, description="Contents of BigQuery JSON credentials.", + required=False, + description="Contents of BigQuery JSON credentials.", ) query = fields.Dict(required=False) diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index 3889e990795d9..888513f518482 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -82,7 +82,10 @@ class GSheetsEngineSpec(SqliteEngineSpec): @classmethod def modify_url_for_impersonation( - cls, url: URL, impersonate_user: bool, username: Optional[str], + cls, + url: URL, + impersonate_user: bool, + username: Optional[str], ) -> None: if impersonate_user and username is not None: user = security_manager.find_user(username=username) @@ -91,7 +94,10 @@ def modify_url_for_impersonation( @classmethod def extra_table_metadata( - cls, database: "Database", table_name: str, schema_name: str, + cls, + database: "Database", + table_name: str, + schema_name: str, ) -> Dict[str, Any]: engine = cls.get_engine(database, schema=schema_name) with closing(engine.raw_connection()) as conn: @@ -150,7 +156,8 @@ def parameters_json_schema(cls) -> Any: @classmethod def validate_parameters( - cls, parameters: GSheetsParametersType, + cls, + parameters: GSheetsParametersType, ) -> List[SupersetError]: errors: List[SupersetError] = [] encrypted_credentials = parameters.get("service_account_info") or "{}" @@ -173,7 +180,9 @@ def validate_parameters( subject = g.user.email if g.user else None engine = create_engine( - "gsheets://", service_account_info=encrypted_credentials, subject=subject, + "gsheets://", + service_account_info=encrypted_credentials, + subject=subject, ) conn = engine.connect() idx = 0 diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 611a98448b2ce..5f7e01d50271f 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ 
-496,7 +496,10 @@ def modify_url_for_impersonation( @classmethod def update_impersonation_config( - cls, connect_args: Dict[str, Any], uri: str, username: Optional[str], + cls, + connect_args: Dict[str, Any], + uri: str, + username: Optional[str], ) -> None: """ Update a configuration dictionary diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 90cbe621faf56..9aa3c85e0fe62 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -74,24 +74,56 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): encryption_parameters = {"ssl": "1"} column_type_mappings = ( - (re.compile(r"^int.*", re.IGNORECASE), INTEGER(), GenericDataType.NUMERIC,), - (re.compile(r"^tinyint", re.IGNORECASE), TINYINT(), GenericDataType.NUMERIC,), + ( + re.compile(r"^int.*", re.IGNORECASE), + INTEGER(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^tinyint", re.IGNORECASE), + TINYINT(), + GenericDataType.NUMERIC, + ), ( re.compile(r"^mediumint", re.IGNORECASE), MEDIUMINT(), GenericDataType.NUMERIC, ), - (re.compile(r"^decimal", re.IGNORECASE), DECIMAL(), GenericDataType.NUMERIC,), - (re.compile(r"^float", re.IGNORECASE), FLOAT(), GenericDataType.NUMERIC,), - (re.compile(r"^double", re.IGNORECASE), DOUBLE(), GenericDataType.NUMERIC,), - (re.compile(r"^bit", re.IGNORECASE), BIT(), GenericDataType.NUMERIC,), - (re.compile(r"^tinytext", re.IGNORECASE), TINYTEXT(), GenericDataType.STRING,), + ( + re.compile(r"^decimal", re.IGNORECASE), + DECIMAL(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^float", re.IGNORECASE), + FLOAT(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^double", re.IGNORECASE), + DOUBLE(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^bit", re.IGNORECASE), + BIT(), + GenericDataType.NUMERIC, + ), + ( + re.compile(r"^tinytext", re.IGNORECASE), + TINYTEXT(), + GenericDataType.STRING, + ), ( re.compile(r"^mediumtext", re.IGNORECASE), MEDIUMTEXT(), GenericDataType.STRING, ), - (re.compile(r"^longtext", re.IGNORECASE), LONGTEXT(), GenericDataType.STRING,), + ( + re.compile(r"^longtext", re.IGNORECASE), + LONGTEXT(), + GenericDataType.STRING, + ), ) _time_grain_expressions = { diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index f6c6888ee97bb..f81e2f7b3e4a6 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -188,8 +188,16 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin): lambda match: ARRAY(int(match[2])) if match[2] else String(), GenericDataType.STRING, ), - (re.compile(r"^json.*", re.IGNORECASE), JSON(), GenericDataType.STRING,), - (re.compile(r"^enum.*", re.IGNORECASE), ENUM(), GenericDataType.STRING,), + ( + re.compile(r"^json.*", re.IGNORECASE), + JSON(), + GenericDataType.STRING, + ), + ( + re.compile(r"^enum.*", re.IGNORECASE), + ENUM(), + GenericDataType.STRING, + ), ) @classmethod diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 62e2f349f1e4e..60cb9c7acaca6 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -214,7 +214,10 @@ def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: @classmethod def update_impersonation_config( - cls, connect_args: Dict[str, Any], uri: str, username: Optional[str], + cls, + connect_args: Dict[str, Any], + uri: str, + username: Optional[str], ) -> None: """ Update a configuration dictionary @@ -487,7 +490,11 @@ def _show_columns( types.VARBINARY(), GenericDataType.STRING, 
), - (re.compile(r"^json.*", re.IGNORECASE), types.JSON(), GenericDataType.STRING,), + ( + re.compile(r"^json.*", re.IGNORECASE), + types.JSON(), + GenericDataType.STRING, + ), ( re.compile(r"^date.*", re.IGNORECASE), types.DATETIME(), diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index d902a917a6a11..31e9a0aa7b3a3 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -94,7 +94,10 @@ def adjust_database_uri( @classmethod def update_impersonation_config( - cls, connect_args: Dict[str, Any], uri: str, username: Optional[str], + cls, + connect_args: Dict[str, Any], + uri: str, + username: Optional[str], ) -> None: """ Update a configuration dictionary diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py index 830d39801124d..1380958b2ad4a 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ -186,8 +186,16 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[ default_query_context = { "result_format": "json", "result_type": "full", - "datasource": {"id": tbl.id, "type": "table",}, - "queries": [{"columns": [], "metrics": [],},], + "datasource": { + "id": tbl.id, + "type": "table", + }, + "queries": [ + { + "columns": [], + "metrics": [], + }, + ], } admin = get_admin_user() @@ -381,7 +389,12 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[ ), query_context=get_slice_json( default_query_context, - queries=[{"columns": ["name", "state"], "metrics": [metric],}], + queries=[ + { + "columns": ["name", "state"], + "metrics": [metric], + } + ], ), ), ] diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py index 0b0473f93232e..015a8f9297978 100644 --- a/superset/examples/world_bank.py +++ b/superset/examples/world_bank.py @@ -43,7 +43,9 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-statements - only_metadata: bool = False, force: bool = False, sample: bool = False, + only_metadata: bool = False, + force: bool = False, + sample: bool = False, ) -> None: """Loads the world bank health dataset, slices and a dashboard""" tbl_name = "wb_health_population" diff --git a/superset/exceptions.py b/superset/exceptions.py index 6b259044862ba..2ae75f122970f 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -129,7 +129,10 @@ def __init__( extra: Optional[Dict[str, Any]] = None, ) -> None: super().__init__( - SupersetErrorType.GENERIC_DB_ENGINE_ERROR, message, level, extra, + SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + message, + level, + extra, ) @@ -144,7 +147,10 @@ def __init__( extra: Optional[Dict[str, Any]] = None, ) -> None: super().__init__( - error, message, level, extra, + error, + message, + level, + extra, ) diff --git a/superset/explore/form_data/commands/update.py b/superset/explore/form_data/commands/update.py index 279722971f5fb..596c5f6e27ef2 100644 --- a/superset/explore/form_data/commands/update.py +++ b/superset/explore/form_data/commands/update.py @@ -39,7 +39,8 @@ class UpdateFormDataCommand(BaseCommand, ABC): def __init__( - self, cmd_params: CommandParameters, + self, + cmd_params: CommandParameters, ): self._cmd_params = cmd_params diff --git a/superset/explore/permalink/api.py b/superset/explore/permalink/api.py index 9a03b71150e35..1d78e4354ae19 100644 --- a/superset/explore/permalink/api.py +++ b/superset/explore/permalink/api.py @@ -162,7 +162,10 @@ def get(self, key: str) -> Response: return self.response(200, 
**value) except ExplorePermalinkInvalidStateError as ex: return self.response(400, message=str(ex)) - except (ChartAccessDeniedError, DatasetAccessDeniedError,) as ex: + except ( + ChartAccessDeniedError, + DatasetAccessDeniedError, + ) as ex: return self.response(403, message=str(ex)) except (ChartNotFoundError, DatasetNotFoundError) as ex: return self.response(404, message=str(ex)) diff --git a/superset/explore/permalink/commands/create.py b/superset/explore/permalink/commands/create.py index 38e91fb105b27..55fb0820cda0b 100644 --- a/superset/explore/permalink/commands/create.py +++ b/superset/explore/permalink/commands/create.py @@ -48,7 +48,9 @@ def run(self) -> str: "state": self.state, } command = CreateKeyValueCommand( - actor=self.actor, resource=self.resource, value=value, + actor=self.actor, + resource=self.resource, + value=value, ) key = command.run() return encode_permalink_key(key=key.id, salt=self.salt) diff --git a/superset/explore/permalink/commands/get.py b/superset/explore/permalink/commands/get.py index 15a2d495cd6d7..1e3ea1fdc6f92 100644 --- a/superset/explore/permalink/commands/get.py +++ b/superset/explore/permalink/commands/get.py @@ -42,7 +42,8 @@ def run(self) -> Optional[ExplorePermalinkValue]: try: key = decode_permalink_id(self.key, salt=self.salt) value: Optional[ExplorePermalinkValue] = GetKeyValueCommand( - resource=self.resource, key=key, + resource=self.resource, + key=key, ).run() if value: chart_id: Optional[int] = value.get("chartId") diff --git a/superset/explore/permalink/schemas.py b/superset/explore/permalink/schemas.py index 7392c2deda250..e1f9d069b853f 100644 --- a/superset/explore/permalink/schemas.py +++ b/superset/explore/permalink/schemas.py @@ -19,7 +19,9 @@ class ExplorePermalinkPostSchema(Schema): formData = fields.Dict( - required=True, allow_none=False, description="Chart form data", + required=True, + allow_none=False, + description="Chart form data", ) urlParams = fields.List( fields.Tuple( diff --git a/superset/jinja_context.py b/superset/jinja_context.py index f21fbbb1b745a..ab3aa5070ca66 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -355,7 +355,10 @@ def safe_proxy(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: return_value = json.loads(json.dumps(return_value)) except TypeError as ex: raise SupersetTemplateException( - _("Unsupported return value for method %(name)s", name=func.__name__,) + _( + "Unsupported return value for method %(name)s", + name=func.__name__, + ) ) from ex return return_value diff --git a/superset/migrations/shared/security_converge.py b/superset/migrations/shared/security_converge.py index 856efc27bc417..19caa3932b874 100644 --- a/superset/migrations/shared/security_converge.py +++ b/superset/migrations/shared/security_converge.py @@ -214,7 +214,9 @@ def _delete_old_permissions( def migrate_roles( - session: Session, pvm_key_map: PvmMigrationMapType, commit: bool = False, + session: Session, + pvm_key_map: PvmMigrationMapType, + commit: bool = False, ) -> None: """ Migrates all existing roles that have the permissions to be migrated diff --git a/superset/migrations/versions/021b81fe4fbb_add_type_to_native_filter_configuration.py b/superset/migrations/versions/021b81fe4fbb_add_type_to_native_filter_configuration.py index 8238e8f6391ef..9c26159ba0a89 100644 --- a/superset/migrations/versions/021b81fe4fbb_add_type_to_native_filter_configuration.py +++ b/superset/migrations/versions/021b81fe4fbb_add_type_to_native_filter_configuration.py @@ -91,7 +91,8 @@ def downgrade(): 
for dashboard in session.query(Dashboard).all(): logger.info( - "[RemoveTypeToNativeFilter] Updating Dashobard", dashboard.id, + "[RemoveTypeToNativeFilter] Updating Dashobard", + dashboard.id, ) if not dashboard.json_metadata: logger.info( diff --git a/superset/migrations/versions/1f6dca87d1a2_security_converge_dashboards.py b/superset/migrations/versions/1f6dca87d1a2_security_converge_dashboards.py index f9743873b9b72..ae350848c25d4 100644 --- a/superset/migrations/versions/1f6dca87d1a2_security_converge_dashboards.py +++ b/superset/migrations/versions/1f6dca87d1a2_security_converge_dashboards.py @@ -39,24 +39,63 @@ Pvm, ) -NEW_PVMS = {"Dashboard": ("can_read", "can_write",)} +NEW_PVMS = { + "Dashboard": ( + "can_read", + "can_write", + ) +} PVM_MAP = { Pvm("DashboardModelView", "can_add"): (Pvm("Dashboard", "can_write"),), Pvm("DashboardModelView", "can_delete"): (Pvm("Dashboard", "can_write"),), - Pvm("DashboardModelView", "can_download_dashboards",): ( - Pvm("Dashboard", "can_read"), - ), - Pvm("DashboardModelView", "can_edit",): (Pvm("Dashboard", "can_write"),), - Pvm("DashboardModelView", "can_favorite_status",): (Pvm("Dashboard", "can_read"),), - Pvm("DashboardModelView", "can_list",): (Pvm("Dashboard", "can_read"),), - Pvm("DashboardModelView", "can_mulexport",): (Pvm("Dashboard", "can_read"),), - Pvm("DashboardModelView", "can_show",): (Pvm("Dashboard", "can_read"),), - Pvm("DashboardModelView", "muldelete",): (Pvm("Dashboard", "can_write"),), - Pvm("DashboardModelView", "mulexport",): (Pvm("Dashboard", "can_read"),), - Pvm("DashboardModelViewAsync", "can_list",): (Pvm("Dashboard", "can_read"),), - Pvm("DashboardModelViewAsync", "muldelete",): (Pvm("Dashboard", "can_write"),), - Pvm("DashboardModelViewAsync", "mulexport",): (Pvm("Dashboard", "can_read"),), - Pvm("Dashboard", "can_new",): (Pvm("Dashboard", "can_write"),), + Pvm( + "DashboardModelView", + "can_download_dashboards", + ): (Pvm("Dashboard", "can_read"),), + Pvm( + "DashboardModelView", + "can_edit", + ): (Pvm("Dashboard", "can_write"),), + Pvm( + "DashboardModelView", + "can_favorite_status", + ): (Pvm("Dashboard", "can_read"),), + Pvm( + "DashboardModelView", + "can_list", + ): (Pvm("Dashboard", "can_read"),), + Pvm( + "DashboardModelView", + "can_mulexport", + ): (Pvm("Dashboard", "can_read"),), + Pvm( + "DashboardModelView", + "can_show", + ): (Pvm("Dashboard", "can_read"),), + Pvm( + "DashboardModelView", + "muldelete", + ): (Pvm("Dashboard", "can_write"),), + Pvm( + "DashboardModelView", + "mulexport", + ): (Pvm("Dashboard", "can_read"),), + Pvm( + "DashboardModelViewAsync", + "can_list", + ): (Pvm("Dashboard", "can_read"),), + Pvm( + "DashboardModelViewAsync", + "muldelete", + ): (Pvm("Dashboard", "can_write"),), + Pvm( + "DashboardModelViewAsync", + "mulexport", + ): (Pvm("Dashboard", "can_read"),), + Pvm( + "Dashboard", + "can_new", + ): (Pvm("Dashboard", "can_write"),), } diff --git a/superset/migrations/versions/2e5a0ee25ed4_refractor_alerting.py b/superset/migrations/versions/2e5a0ee25ed4_refractor_alerting.py index 98bd8a4f54443..4eca5f147bd60 100644 --- a/superset/migrations/versions/2e5a0ee25ed4_refractor_alerting.py +++ b/superset/migrations/versions/2e5a0ee25ed4_refractor_alerting.py @@ -43,9 +43,18 @@ def upgrade(): sa.Column("created_by_fk", sa.Integer(), nullable=True), sa.Column("changed_by_fk", sa.Integer(), nullable=True), sa.Column("alert_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(["alert_id"], ["alerts.id"],), - sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"],), - 
sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"],), + sa.ForeignKeyConstraint( + ["alert_id"], + ["alerts.id"], + ), + sa.ForeignKeyConstraint( + ["changed_by_fk"], + ["ab_user.id"], + ), + sa.ForeignKeyConstraint( + ["created_by_fk"], + ["ab_user.id"], + ), sa.PrimaryKeyConstraint("id"), ) op.create_table( @@ -58,10 +67,22 @@ def upgrade(): sa.Column("changed_by_fk", sa.Integer(), nullable=True), sa.Column("alert_id", sa.Integer(), nullable=False), sa.Column("database_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(["alert_id"], ["alerts.id"],), - sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"],), - sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"],), - sa.ForeignKeyConstraint(["database_id"], ["dbs.id"],), + sa.ForeignKeyConstraint( + ["alert_id"], + ["alerts.id"], + ), + sa.ForeignKeyConstraint( + ["changed_by_fk"], + ["ab_user.id"], + ), + sa.ForeignKeyConstraint( + ["created_by_fk"], + ["ab_user.id"], + ), + sa.ForeignKeyConstraint( + ["database_id"], + ["dbs.id"], + ), sa.PrimaryKeyConstraint("id"), ) op.create_table( @@ -72,8 +93,14 @@ def upgrade(): sa.Column("alert_id", sa.Integer(), nullable=True), sa.Column("value", sa.Float(), nullable=True), sa.Column("error_msg", sa.String(length=500), nullable=True), - sa.ForeignKeyConstraint(["alert_id"], ["alerts.id"],), - sa.ForeignKeyConstraint(["observer_id"], ["sql_observers.id"],), + sa.ForeignKeyConstraint( + ["alert_id"], + ["alerts.id"], + ), + sa.ForeignKeyConstraint( + ["observer_id"], + ["sql_observers.id"], + ), sa.PrimaryKeyConstraint("id"), ) op.create_index( diff --git a/superset/migrations/versions/2f1d15e8a6af_add_alerts.py b/superset/migrations/versions/2f1d15e8a6af_add_alerts.py index a817f558e8ace..bb85d51983763 100644 --- a/superset/migrations/versions/2f1d15e8a6af_add_alerts.py +++ b/superset/migrations/versions/2f1d15e8a6af_add_alerts.py @@ -49,8 +49,14 @@ def upgrade(): sa.Column("dashboard_id", sa.Integer(), nullable=True), sa.Column("last_eval_dttm", sa.DateTime(), nullable=True), sa.Column("last_state", sa.String(length=10), nullable=True), - sa.ForeignKeyConstraint(["dashboard_id"], ["dashboards.id"],), - sa.ForeignKeyConstraint(["slice_id"], ["slices.id"],), + sa.ForeignKeyConstraint( + ["dashboard_id"], + ["dashboards.id"], + ), + sa.ForeignKeyConstraint( + ["slice_id"], + ["slices.id"], + ), sa.PrimaryKeyConstraint("id"), ) op.create_index(op.f("ix_alerts_active"), "alerts", ["active"], unique=False) @@ -62,7 +68,10 @@ def upgrade(): sa.Column("dttm_end", sa.DateTime(), nullable=True), sa.Column("alert_id", sa.Integer(), nullable=True), sa.Column("state", sa.String(length=10), nullable=True), - sa.ForeignKeyConstraint(["alert_id"], ["alerts.id"],), + sa.ForeignKeyConstraint( + ["alert_id"], + ["alerts.id"], + ), sa.PrimaryKeyConstraint("id"), ) op.create_table( @@ -70,8 +79,14 @@ def upgrade(): sa.Column("id", sa.Integer(), nullable=False), sa.Column("user_id", sa.Integer(), nullable=True), sa.Column("alert_id", sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(["alert_id"], ["alerts.id"],), - sa.ForeignKeyConstraint(["user_id"], ["ab_user.id"],), + sa.ForeignKeyConstraint( + ["alert_id"], + ["alerts.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["ab_user.id"], + ), sa.PrimaryKeyConstraint("id"), ) diff --git a/superset/migrations/versions/3317e9248280_add_creation_method_to_reports_model.py b/superset/migrations/versions/3317e9248280_add_creation_method_to_reports_model.py index 68b4383ca9da7..0a5608ae77087 100644 --- 
a/superset/migrations/versions/3317e9248280_add_creation_method_to_reports_model.py +++ b/superset/migrations/versions/3317e9248280_add_creation_method_to_reports_model.py @@ -34,7 +34,9 @@ def upgrade(): with op.batch_alter_table("report_schedule") as batch_op: batch_op.add_column( sa.Column( - "creation_method", sa.VARCHAR(255), server_default="alerts_reports", + "creation_method", + sa.VARCHAR(255), + server_default="alerts_reports", ) ) batch_op.create_index( diff --git a/superset/migrations/versions/40f16acf1ba7_security_converge_reports.py b/superset/migrations/versions/40f16acf1ba7_security_converge_reports.py index 227c421944a53..2886bfbb15d7a 100644 --- a/superset/migrations/versions/40f16acf1ba7_security_converge_reports.py +++ b/superset/migrations/versions/40f16acf1ba7_security_converge_reports.py @@ -39,13 +39,27 @@ Pvm, ) -NEW_PVMS = {"ReportSchedule": ("can_read", "can_write",)} +NEW_PVMS = { + "ReportSchedule": ( + "can_read", + "can_write", + ) +} PVM_MAP = { Pvm("ReportSchedule", "can_list"): (Pvm("ReportSchedule", "can_read"),), Pvm("ReportSchedule", "can_show"): (Pvm("ReportSchedule", "can_read"),), - Pvm("ReportSchedule", "can_add",): (Pvm("ReportSchedule", "can_write"),), - Pvm("ReportSchedule", "can_edit",): (Pvm("ReportSchedule", "can_write"),), - Pvm("ReportSchedule", "can_delete",): (Pvm("ReportSchedule", "can_write"),), + Pvm( + "ReportSchedule", + "can_add", + ): (Pvm("ReportSchedule", "can_write"),), + Pvm( + "ReportSchedule", + "can_edit", + ): (Pvm("ReportSchedule", "can_write"),), + Pvm( + "ReportSchedule", + "can_delete", + ): (Pvm("ReportSchedule", "can_write"),), } diff --git a/superset/migrations/versions/42b4c9e01447_security_converge_databases.py b/superset/migrations/versions/42b4c9e01447_security_converge_databases.py index 3c3c31fb663fc..d8d6a2a3315bd 100644 --- a/superset/migrations/versions/42b4c9e01447_security_converge_databases.py +++ b/superset/migrations/versions/42b4c9e01447_security_converge_databases.py @@ -39,17 +39,43 @@ Pvm, ) -NEW_PVMS = {"Database": ("can_read", "can_write",)} +NEW_PVMS = { + "Database": ( + "can_read", + "can_write", + ) +} PVM_MAP = { Pvm("DatabaseView", "can_add"): (Pvm("Database", "can_write"),), Pvm("DatabaseView", "can_delete"): (Pvm("Database", "can_write"),), - Pvm("DatabaseView", "can_edit",): (Pvm("Database", "can_write"),), - Pvm("DatabaseView", "can_list",): (Pvm("Database", "can_read"),), - Pvm("DatabaseView", "can_mulexport",): (Pvm("Database", "can_read"),), - Pvm("DatabaseView", "can_post",): (Pvm("Database", "can_write"),), - Pvm("DatabaseView", "can_show",): (Pvm("Database", "can_read"),), - Pvm("DatabaseView", "muldelete",): (Pvm("Database", "can_write"),), - Pvm("DatabaseView", "yaml_export",): (Pvm("Database", "can_read"),), + Pvm( + "DatabaseView", + "can_edit", + ): (Pvm("Database", "can_write"),), + Pvm( + "DatabaseView", + "can_list", + ): (Pvm("Database", "can_read"),), + Pvm( + "DatabaseView", + "can_mulexport", + ): (Pvm("Database", "can_read"),), + Pvm( + "DatabaseView", + "can_post", + ): (Pvm("Database", "can_write"),), + Pvm( + "DatabaseView", + "can_show", + ): (Pvm("Database", "can_read"),), + Pvm( + "DatabaseView", + "muldelete", + ): (Pvm("Database", "can_write"),), + Pvm( + "DatabaseView", + "yaml_export", + ): (Pvm("Database", "can_read"),), } diff --git a/superset/migrations/versions/45731db65d9c_security_converge_datasets.py b/superset/migrations/versions/45731db65d9c_security_converge_datasets.py index 5b4670857faf4..c7a1c81629ebf 100644 --- 
a/superset/migrations/versions/45731db65d9c_security_converge_datasets.py +++ b/superset/migrations/versions/45731db65d9c_security_converge_datasets.py @@ -38,7 +38,12 @@ Pvm, ) -NEW_PVMS = {"Dataset": ("can_read", "can_write",)} +NEW_PVMS = { + "Dataset": ( + "can_read", + "can_write", + ) +} PVM_MAP = { Pvm("SqlMetricInlineView", "can_add"): (Pvm("Dataset", "can_write"),), Pvm("SqlMetricInlineView", "can_delete"): (Pvm("Dataset", "can_write"),), @@ -50,15 +55,33 @@ Pvm("TableColumnInlineView", "can_edit"): (Pvm("Dataset", "can_write"),), Pvm("TableColumnInlineView", "can_list"): (Pvm("Dataset", "can_read"),), Pvm("TableColumnInlineView", "can_show"): (Pvm("Dataset", "can_read"),), - Pvm("TableModelView", "can_add",): (Pvm("Dataset", "can_write"),), - Pvm("TableModelView", "can_delete",): (Pvm("Dataset", "can_write"),), - Pvm("TableModelView", "can_edit",): (Pvm("Dataset", "can_write"),), + Pvm( + "TableModelView", + "can_add", + ): (Pvm("Dataset", "can_write"),), + Pvm( + "TableModelView", + "can_delete", + ): (Pvm("Dataset", "can_write"),), + Pvm( + "TableModelView", + "can_edit", + ): (Pvm("Dataset", "can_write"),), Pvm("TableModelView", "can_list"): (Pvm("Dataset", "can_read"),), Pvm("TableModelView", "can_mulexport"): (Pvm("Dataset", "can_read"),), Pvm("TableModelView", "can_show"): (Pvm("Dataset", "can_read"),), - Pvm("TableModelView", "muldelete",): (Pvm("Dataset", "can_write"),), - Pvm("TableModelView", "refresh",): (Pvm("Dataset", "can_write"),), - Pvm("TableModelView", "yaml_export",): (Pvm("Dataset", "can_read"),), + Pvm( + "TableModelView", + "muldelete", + ): (Pvm("Dataset", "can_write"),), + Pvm( + "TableModelView", + "refresh", + ): (Pvm("Dataset", "can_write"),), + Pvm( + "TableModelView", + "yaml_export", + ): (Pvm("Dataset", "can_read"),), } diff --git a/superset/migrations/versions/49b5a32daba5_add_report_schedules.py b/superset/migrations/versions/49b5a32daba5_add_report_schedules.py index edf65728b577c..3a3b172bfe65f 100644 --- a/superset/migrations/versions/49b5a32daba5_add_report_schedules.py +++ b/superset/migrations/versions/49b5a32daba5_add_report_schedules.py @@ -114,8 +114,14 @@ def upgrade(): sa.Column("id", sa.Integer(), nullable=False), sa.Column("user_id", sa.Integer(), nullable=False), sa.Column("report_schedule_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(["report_schedule_id"], ["report_schedule.id"],), - sa.ForeignKeyConstraint(["user_id"], ["ab_user.id"],), + sa.ForeignKeyConstraint( + ["report_schedule_id"], + ["report_schedule.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["ab_user.id"], + ), sa.PrimaryKeyConstraint("id"), ) diff --git a/superset/migrations/versions/4b84f97828aa_security_converge_logs.py b/superset/migrations/versions/4b84f97828aa_security_converge_logs.py index 51862e3430513..284b7f7525e3e 100644 --- a/superset/migrations/versions/4b84f97828aa_security_converge_logs.py +++ b/superset/migrations/versions/4b84f97828aa_security_converge_logs.py @@ -37,10 +37,18 @@ revision = "4b84f97828aa" down_revision = "45731db65d9c" -NEW_PVMS = {"Log": ("can_read", "can_write",)} +NEW_PVMS = { + "Log": ( + "can_read", + "can_write", + ) +} PVM_MAP = { Pvm("LogModelView", "can_show"): (Pvm("Log", "can_read"),), - Pvm("LogModelView", "can_add",): (Pvm("Log", "can_write"),), + Pvm( + "LogModelView", + "can_add", + ): (Pvm("Log", "can_write"),), Pvm("LogModelView", "can_list"): (Pvm("Log", "can_read"),), } diff --git a/superset/migrations/versions/58df9d617f14_add_on_saved_query_delete_tab_state_.py 
b/superset/migrations/versions/58df9d617f14_add_on_saved_query_delete_tab_state_.py index 220370f828049..57e13cf1488e6 100644 --- a/superset/migrations/versions/58df9d617f14_add_on_saved_query_delete_tab_state_.py +++ b/superset/migrations/versions/58df9d617f14_add_on_saved_query_delete_tab_state_.py @@ -62,5 +62,8 @@ def downgrade(): ) batch_op.create_foreign_key( - "saved_query_id", "saved_query", ["saved_query_id"], ["id"], + "saved_query_id", + "saved_query", + ["saved_query_id"], + ["id"], ) diff --git a/superset/migrations/versions/73fd22e742ab_add_dynamic_plugins_py.py b/superset/migrations/versions/73fd22e742ab_add_dynamic_plugins_py.py index 0a9e37aac2ec6..e4c2d0bc519ff 100644 --- a/superset/migrations/versions/73fd22e742ab_add_dynamic_plugins_py.py +++ b/superset/migrations/versions/73fd22e742ab_add_dynamic_plugins_py.py @@ -42,8 +42,14 @@ def upgrade(): sa.Column("bundle_url", sa.String(length=1000), nullable=False), sa.Column("created_by_fk", sa.Integer(), nullable=True), sa.Column("changed_by_fk", sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"],), - sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"],), + sa.ForeignKeyConstraint( + ["changed_by_fk"], + ["ab_user.id"], + ), + sa.ForeignKeyConstraint( + ["created_by_fk"], + ["ab_user.id"], + ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("key"), sa.UniqueConstraint("name"), diff --git a/superset/migrations/versions/8ee129739cf9_security_converge_css_templates.py b/superset/migrations/versions/8ee129739cf9_security_converge_css_templates.py index cc641006d6afb..401dc5c4b143f 100644 --- a/superset/migrations/versions/8ee129739cf9_security_converge_css_templates.py +++ b/superset/migrations/versions/8ee129739cf9_security_converge_css_templates.py @@ -39,16 +39,39 @@ Pvm, ) -NEW_PVMS = {"CssTemplate": ("can_read", "can_write",)} +NEW_PVMS = { + "CssTemplate": ( + "can_read", + "can_write", + ) +} PVM_MAP = { Pvm("CssTemplateModelView", "can_list"): (Pvm("CssTemplate", "can_read"),), Pvm("CssTemplateModelView", "can_show"): (Pvm("CssTemplate", "can_read"),), - Pvm("CssTemplateModelView", "can_add",): (Pvm("CssTemplate", "can_write"),), - Pvm("CssTemplateModelView", "can_edit",): (Pvm("CssTemplate", "can_write"),), - Pvm("CssTemplateModelView", "can_delete",): (Pvm("CssTemplate", "can_write"),), - Pvm("CssTemplateModelView", "muldelete",): (Pvm("CssTemplate", "can_write"),), - Pvm("CssTemplateAsyncModelView", "can_list",): (Pvm("CssTemplate", "can_read"),), - Pvm("CssTemplateAsyncModelView", "muldelete",): (Pvm("CssTemplate", "can_write"),), + Pvm( + "CssTemplateModelView", + "can_add", + ): (Pvm("CssTemplate", "can_write"),), + Pvm( + "CssTemplateModelView", + "can_edit", + ): (Pvm("CssTemplate", "can_write"),), + Pvm( + "CssTemplateModelView", + "can_delete", + ): (Pvm("CssTemplate", "can_write"),), + Pvm( + "CssTemplateModelView", + "muldelete", + ): (Pvm("CssTemplate", "can_write"),), + Pvm( + "CssTemplateAsyncModelView", + "can_list", + ): (Pvm("CssTemplate", "can_read"),), + Pvm( + "CssTemplateAsyncModelView", + "muldelete", + ): (Pvm("CssTemplate", "can_write"),), } diff --git a/superset/migrations/versions/96e99fb176a0_add_import_mixing_to_saved_query.py b/superset/migrations/versions/96e99fb176a0_add_import_mixing_to_saved_query.py index 2dc38e214a633..57d22aa089aa2 100644 --- a/superset/migrations/versions/96e99fb176a0_add_import_mixing_to_saved_query.py +++ b/superset/migrations/versions/96e99fb176a0_add_import_mixing_to_saved_query.py @@ -65,7 +65,10 @@ def upgrade(): 
with op.batch_alter_table("saved_query") as batch_op: batch_op.add_column( sa.Column( - "uuid", UUIDType(binary=True), primary_key=False, default=uuid4, + "uuid", + UUIDType(binary=True), + primary_key=False, + default=uuid4, ), ) except OperationalError: diff --git a/superset/migrations/versions/978245563a02_migrate_iframe_to_dash_markdown.py b/superset/migrations/versions/978245563a02_migrate_iframe_to_dash_markdown.py index ccc845cdc0041..6b63c468eca0c 100644 --- a/superset/migrations/versions/978245563a02_migrate_iframe_to_dash_markdown.py +++ b/superset/migrations/versions/978245563a02_migrate_iframe_to_dash_markdown.py @@ -159,7 +159,10 @@ def upgrade(): for key_to_remove in keys_to_remove: del position_dict[key_to_remove] dashboard.position_json = json.dumps( - position_dict, indent=None, separators=(",", ":"), sort_keys=True, + position_dict, + indent=None, + separators=(",", ":"), + sort_keys=True, ) session.merge(dashboard) diff --git a/superset/migrations/versions/abe27eaf93db_add_extra_config_column_to_alerts.py b/superset/migrations/versions/abe27eaf93db_add_extra_config_column_to_alerts.py index a956058434118..5a20fc894a639 100644 --- a/superset/migrations/versions/abe27eaf93db_add_extra_config_column_to_alerts.py +++ b/superset/migrations/versions/abe27eaf93db_add_extra_config_column_to_alerts.py @@ -39,7 +39,12 @@ def upgrade(): with op.batch_alter_table("report_schedule") as batch_op: batch_op.add_column( - sa.Column("extra", sa.Text(), nullable=True, default="{}",), + sa.Column( + "extra", + sa.Text(), + nullable=True, + default="{}", + ), ) bind.execute(report_schedule.update().values({"extra": "{}"})) with op.batch_alter_table("report_schedule") as batch_op: diff --git a/superset/migrations/versions/af30ca79208f_collapse_alerting_models_into_a_single_.py b/superset/migrations/versions/af30ca79208f_collapse_alerting_models_into_a_single_.py index 4d3c5983e3fb1..9502a66f4db32 100644 --- a/superset/migrations/versions/af30ca79208f_collapse_alerting_models_into_a_single_.py +++ b/superset/migrations/versions/af30ca79208f_collapse_alerting_models_into_a_single_.py @@ -118,7 +118,8 @@ def upgrade(): sa.Column("validator_config", sa.Text(), default="", nullable=True), ) op.add_column( - "alerts", sa.Column("database_id", sa.Integer(), default=0, nullable=False), + "alerts", + sa.Column("database_id", sa.Integer(), default=0, nullable=False), ) op.add_column("alerts", sa.Column("sql", sa.Text(), default="", nullable=False)) op.add_column( @@ -159,7 +160,10 @@ def upgrade(): sa.Column("alert_id", sa.Integer(), nullable=True), sa.Column("value", sa.Float(), nullable=True), sa.Column("error_msg", sa.String(length=500), nullable=True), - sa.ForeignKeyConstraint(["alert_id"], ["alerts.id"],), + sa.ForeignKeyConstraint( + ["alert_id"], + ["alerts.id"], + ), sa.PrimaryKeyConstraint("id"), ) else: @@ -192,7 +196,11 @@ def downgrade(): sa.Column("created_on", sa.DateTime(), nullable=True), sa.Column("changed_on", sa.DateTime(), nullable=True), sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), - sa.Column("validator_type", sa.String(length=100), nullable=False,), + sa.Column( + "validator_type", + sa.String(length=100), + nullable=False, + ), sa.Column("config", sa.Text(), nullable=True), sa.Column("created_by_fk", sa.Integer(), autoincrement=False, nullable=True), sa.Column("changed_by_fk", sa.Integer(), autoincrement=False, nullable=True), @@ -261,10 +269,22 @@ def downgrade(): sa.Column("created_by_fk", sa.Integer(), nullable=True), sa.Column("created_on", 
sa.DateTime(), nullable=True), sa.Column("slack_channel", sa.Text(), nullable=True), - sa.ForeignKeyConstraint(["dashboard_id"], ["dashboards.id"],), - sa.ForeignKeyConstraint(["slice_id"], ["slices.id"],), - sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"],), - sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"],), + sa.ForeignKeyConstraint( + ["dashboard_id"], + ["dashboards.id"], + ), + sa.ForeignKeyConstraint( + ["slice_id"], + ["slices.id"], + ), + sa.ForeignKeyConstraint( + ["created_by_fk"], + ["ab_user.id"], + ), + sa.ForeignKeyConstraint( + ["changed_by_fk"], + ["ab_user.id"], + ), sa.PrimaryKeyConstraint("id"), ) else: diff --git a/superset/migrations/versions/b4456560d4f3_change_table_unique_constraint.py b/superset/migrations/versions/b4456560d4f3_change_table_unique_constraint.py index d22f72cf03c7d..5ce049f8e4a2e 100644 --- a/superset/migrations/versions/b4456560d4f3_change_table_unique_constraint.py +++ b/superset/migrations/versions/b4456560d4f3_change_table_unique_constraint.py @@ -42,6 +42,6 @@ def upgrade(): def downgrade(): try: # Trying since sqlite doesn't like constraints - op.drop_constraint(u"_customer_location_uc", "tables", type_="unique") + op.drop_constraint("_customer_location_uc", "tables", type_="unique") except Exception: pass diff --git a/superset/migrations/versions/b56500de1855_add_uuid_column_to_import_mixin.py b/superset/migrations/versions/b56500de1855_add_uuid_column_to_import_mixin.py index 7bef33bc87260..747ec9fb4f77f 100644 --- a/superset/migrations/versions/b56500de1855_add_uuid_column_to_import_mixin.py +++ b/superset/migrations/versions/b56500de1855_add_uuid_column_to_import_mixin.py @@ -171,7 +171,10 @@ def upgrade(): with op.batch_alter_table(table_name) as batch_op: batch_op.add_column( sa.Column( - "uuid", UUIDType(binary=True), primary_key=False, default=uuid4, + "uuid", + UUIDType(binary=True), + primary_key=False, + default=uuid4, ), ) diff --git a/superset/migrations/versions/b5998378c225_add_certificate_to_dbs.py b/superset/migrations/versions/b5998378c225_add_certificate_to_dbs.py index ae99486807194..404ea96e4402a 100644 --- a/superset/migrations/versions/b5998378c225_add_certificate_to_dbs.py +++ b/superset/migrations/versions/b5998378c225_add_certificate_to_dbs.py @@ -36,7 +36,8 @@ def upgrade(): kwargs: Dict[str, str] = {} bind = op.get_bind() op.add_column( - "dbs", sa.Column("server_cert", sa.LargeBinary(), nullable=True, **kwargs), + "dbs", + sa.Column("server_cert", sa.LargeBinary(), nullable=True, **kwargs), ) diff --git a/superset/migrations/versions/b8d3a24d9131_new_dataset_models.py b/superset/migrations/versions/b8d3a24d9131_new_dataset_models.py index d0a999d8542d1..35419e0066cd4 100644 --- a/superset/migrations/versions/b8d3a24d9131_new_dataset_models.py +++ b/superset/migrations/versions/b8d3a24d9131_new_dataset_models.py @@ -379,16 +379,46 @@ def upgrade(): sa.Column("name", sa.TEXT(), nullable=False), sa.Column("type", sa.TEXT(), nullable=False), sa.Column("expression", sa.TEXT(), nullable=False), - sa.Column("is_physical", sa.BOOLEAN(), nullable=False, default=True,), + sa.Column( + "is_physical", + sa.BOOLEAN(), + nullable=False, + default=True, + ), sa.Column("description", sa.TEXT(), nullable=True), sa.Column("warning_text", sa.TEXT(), nullable=True), sa.Column("unit", sa.TEXT(), nullable=True), sa.Column("is_temporal", sa.BOOLEAN(), nullable=False), - sa.Column("is_spatial", sa.BOOLEAN(), nullable=False, default=False,), - sa.Column("is_partition", sa.BOOLEAN(), nullable=False, default=False,), - 
sa.Column("is_aggregation", sa.BOOLEAN(), nullable=False, default=False,), - sa.Column("is_additive", sa.BOOLEAN(), nullable=False, default=False,), - sa.Column("is_increase_desired", sa.BOOLEAN(), nullable=False, default=True,), + sa.Column( + "is_spatial", + sa.BOOLEAN(), + nullable=False, + default=False, + ), + sa.Column( + "is_partition", + sa.BOOLEAN(), + nullable=False, + default=False, + ), + sa.Column( + "is_aggregation", + sa.BOOLEAN(), + nullable=False, + default=False, + ), + sa.Column( + "is_additive", + sa.BOOLEAN(), + nullable=False, + default=False, + ), + sa.Column( + "is_increase_desired", + sa.BOOLEAN(), + nullable=False, + default=True, + ), sa.Column( "is_managed_externally", sa.Boolean(), @@ -459,7 +489,12 @@ def upgrade(): sa.Column("sqlatable_id", sa.INTEGER(), nullable=True), sa.Column("name", sa.TEXT(), nullable=False), sa.Column("expression", sa.TEXT(), nullable=False), - sa.Column("is_physical", sa.BOOLEAN(), nullable=False, default=False,), + sa.Column( + "is_physical", + sa.BOOLEAN(), + nullable=False, + default=False, + ), sa.Column( "is_managed_externally", sa.Boolean(), diff --git a/superset/migrations/versions/c25cb2c78727_security_converge_annotations.py b/superset/migrations/versions/c25cb2c78727_security_converge_annotations.py index 33099dd2e74b2..eedc721c98788 100644 --- a/superset/migrations/versions/c25cb2c78727_security_converge_annotations.py +++ b/superset/migrations/versions/c25cb2c78727_security_converge_annotations.py @@ -39,19 +39,51 @@ down_revision = "ccb74baaa89b" -NEW_PVMS = {"Annotation": ("can_read", "can_write",)} +NEW_PVMS = { + "Annotation": ( + "can_read", + "can_write", + ) +} PVM_MAP = { Pvm("AnnotationLayerModelView", "can_delete"): (Pvm("Annotation", "can_write"),), Pvm("AnnotationLayerModelView", "can_list"): (Pvm("Annotation", "can_read"),), - Pvm("AnnotationLayerModelView", "can_show",): (Pvm("Annotation", "can_read"),), - Pvm("AnnotationLayerModelView", "can_add",): (Pvm("Annotation", "can_write"),), - Pvm("AnnotationLayerModelView", "can_edit",): (Pvm("Annotation", "can_write"),), - Pvm("AnnotationModelView", "can_annotation",): (Pvm("Annotation", "can_read"),), - Pvm("AnnotationModelView", "can_show",): (Pvm("Annotation", "can_read"),), - Pvm("AnnotationModelView", "can_add",): (Pvm("Annotation", "can_write"),), - Pvm("AnnotationModelView", "can_delete",): (Pvm("Annotation", "can_write"),), - Pvm("AnnotationModelView", "can_edit",): (Pvm("Annotation", "can_write"),), - Pvm("AnnotationModelView", "can_list",): (Pvm("Annotation", "can_read"),), + Pvm( + "AnnotationLayerModelView", + "can_show", + ): (Pvm("Annotation", "can_read"),), + Pvm( + "AnnotationLayerModelView", + "can_add", + ): (Pvm("Annotation", "can_write"),), + Pvm( + "AnnotationLayerModelView", + "can_edit", + ): (Pvm("Annotation", "can_write"),), + Pvm( + "AnnotationModelView", + "can_annotation", + ): (Pvm("Annotation", "can_read"),), + Pvm( + "AnnotationModelView", + "can_show", + ): (Pvm("Annotation", "can_read"),), + Pvm( + "AnnotationModelView", + "can_add", + ): (Pvm("Annotation", "can_write"),), + Pvm( + "AnnotationModelView", + "can_delete", + ): (Pvm("Annotation", "can_write"),), + Pvm( + "AnnotationModelView", + "can_edit", + ): (Pvm("Annotation", "can_write"),), + Pvm( + "AnnotationModelView", + "can_list", + ): (Pvm("Annotation", "can_read"),), } diff --git a/superset/migrations/versions/c501b7c653a3_add_missing_uuid_column.py b/superset/migrations/versions/c501b7c653a3_add_missing_uuid_column.py index f8b252ed0cba3..4cfbc104c01db 100644 --- 
a/superset/migrations/versions/c501b7c653a3_add_missing_uuid_column.py +++ b/superset/migrations/versions/c501b7c653a3_add_missing_uuid_column.py @@ -67,7 +67,10 @@ def upgrade(): with op.batch_alter_table(table_name) as batch_op: batch_op.add_column( sa.Column( - "uuid", UUIDType(binary=True), primary_key=False, default=uuid4, + "uuid", + UUIDType(binary=True), + primary_key=False, + default=uuid4, ), ) add_uuids(model, table_name, session) diff --git a/superset/migrations/versions/c82ee8a39623_add_implicit_tags.py b/superset/migrations/versions/c82ee8a39623_add_implicit_tags.py index 3bab3f6ec3af9..ad809d3e4564e 100644 --- a/superset/migrations/versions/c82ee8a39623_add_implicit_tags.py +++ b/superset/migrations/versions/c82ee8a39623_add_implicit_tags.py @@ -52,7 +52,10 @@ class AuditMixinNullable(AuditMixin): @declared_attr def created_by_fk(self) -> Column: return Column( - Integer, ForeignKey("ab_user.id"), default=self.get_user_id, nullable=True, + Integer, + ForeignKey("ab_user.id"), + default=self.get_user_id, + nullable=True, ) @declared_attr diff --git a/superset/migrations/versions/c878781977c6_alert_reports_shared_uniqueness.py b/superset/migrations/versions/c878781977c6_alert_reports_shared_uniqueness.py index 62b4501bc343c..bb8f628bd7c24 100644 --- a/superset/migrations/versions/c878781977c6_alert_reports_shared_uniqueness.py +++ b/superset/migrations/versions/c878781977c6_alert_reports_shared_uniqueness.py @@ -80,7 +80,8 @@ def upgrade(): if isinstance(bind.dialect, MySQLDialect): op.drop_index( - op.f("name"), table_name="report_schedule", + op.f("name"), + table_name="report_schedule", ) if isinstance(bind.dialect, PGDialect): diff --git a/superset/migrations/versions/ccb74baaa89b_security_converge_charts.py b/superset/migrations/versions/ccb74baaa89b_security_converge_charts.py index d025cd5762f98..66fc547d54496 100644 --- a/superset/migrations/versions/ccb74baaa89b_security_converge_charts.py +++ b/superset/migrations/versions/ccb74baaa89b_security_converge_charts.py @@ -39,22 +39,63 @@ Pvm, ) -NEW_PVMS = {"Chart": ("can_read", "can_write",)} +NEW_PVMS = { + "Chart": ( + "can_read", + "can_write", + ) +} PVM_MAP = { Pvm("SliceModelView", "can_list"): (Pvm("Chart", "can_read"),), Pvm("SliceModelView", "can_show"): (Pvm("Chart", "can_read"),), - Pvm("SliceModelView", "can_edit",): (Pvm("Chart", "can_write"),), - Pvm("SliceModelView", "can_delete",): (Pvm("Chart", "can_write"),), - Pvm("SliceModelView", "can_add",): (Pvm("Chart", "can_write"),), - Pvm("SliceModelView", "can_download",): (Pvm("Chart", "can_read"),), - Pvm("SliceModelView", "muldelete",): (Pvm("Chart", "can_write"),), - Pvm("SliceModelView", "can_mulexport",): (Pvm("Chart", "can_read"),), - Pvm("SliceModelView", "can_favorite_status",): (Pvm("Chart", "can_read"),), - Pvm("SliceModelView", "can_cache_screenshot",): (Pvm("Chart", "can_read"),), - Pvm("SliceModelView", "can_screenshot",): (Pvm("Chart", "can_read"),), - Pvm("SliceModelView", "can_data_from_cache",): (Pvm("Chart", "can_read"),), - Pvm("SliceAsync", "can_list",): (Pvm("Chart", "can_read"),), - Pvm("SliceAsync", "muldelete",): (Pvm("Chart", "can_write"),), + Pvm( + "SliceModelView", + "can_edit", + ): (Pvm("Chart", "can_write"),), + Pvm( + "SliceModelView", + "can_delete", + ): (Pvm("Chart", "can_write"),), + Pvm( + "SliceModelView", + "can_add", + ): (Pvm("Chart", "can_write"),), + Pvm( + "SliceModelView", + "can_download", + ): (Pvm("Chart", "can_read"),), + Pvm( + "SliceModelView", + "muldelete", + ): (Pvm("Chart", "can_write"),), + Pvm( + 
"SliceModelView", + "can_mulexport", + ): (Pvm("Chart", "can_read"),), + Pvm( + "SliceModelView", + "can_favorite_status", + ): (Pvm("Chart", "can_read"),), + Pvm( + "SliceModelView", + "can_cache_screenshot", + ): (Pvm("Chart", "can_read"),), + Pvm( + "SliceModelView", + "can_screenshot", + ): (Pvm("Chart", "can_read"),), + Pvm( + "SliceModelView", + "can_data_from_cache", + ): (Pvm("Chart", "can_read"),), + Pvm( + "SliceAsync", + "can_list", + ): (Pvm("Chart", "can_read"),), + Pvm( + "SliceAsync", + "muldelete", + ): (Pvm("Chart", "can_write"),), } diff --git a/superset/migrations/versions/d416d0d715cc_add_limiting_factor_column_to_query_.py b/superset/migrations/versions/d416d0d715cc_add_limiting_factor_column_to_query_.py index 47db03d4555b5..532da44886643 100644 --- a/superset/migrations/versions/d416d0d715cc_add_limiting_factor_column_to_query_.py +++ b/superset/migrations/versions/d416d0d715cc_add_limiting_factor_column_to_query_.py @@ -33,7 +33,11 @@ def upgrade(): with op.batch_alter_table("query") as batch_op: batch_op.add_column( - sa.Column("limiting_factor", sa.VARCHAR(255), server_default="UNKNOWN",) + sa.Column( + "limiting_factor", + sa.VARCHAR(255), + server_default="UNKNOWN", + ) ) diff --git a/superset/migrations/versions/e38177dbf641_security_converge_saved_queries.py b/superset/migrations/versions/e38177dbf641_security_converge_saved_queries.py index 85ce431758dd4..d3342fe535686 100644 --- a/superset/migrations/versions/e38177dbf641_security_converge_saved_queries.py +++ b/superset/migrations/versions/e38177dbf641_security_converge_saved_queries.py @@ -39,20 +39,55 @@ Pvm, ) -NEW_PVMS = {"SavedQuery": ("can_read", "can_write",)} +NEW_PVMS = { + "SavedQuery": ( + "can_read", + "can_write", + ) +} PVM_MAP = { Pvm("SavedQueryView", "can_list"): (Pvm("SavedQuery", "can_read"),), Pvm("SavedQueryView", "can_show"): (Pvm("SavedQuery", "can_read"),), - Pvm("SavedQueryView", "can_add",): (Pvm("SavedQuery", "can_write"),), - Pvm("SavedQueryView", "can_edit",): (Pvm("SavedQuery", "can_write"),), - Pvm("SavedQueryView", "can_delete",): (Pvm("SavedQuery", "can_write"),), - Pvm("SavedQueryView", "muldelete",): (Pvm("SavedQuery", "can_write"),), - Pvm("SavedQueryView", "can_mulexport",): (Pvm("SavedQuery", "can_read"),), - Pvm("SavedQueryViewApi", "can_show",): (Pvm("SavedQuery", "can_read"),), - Pvm("SavedQueryViewApi", "can_edit",): (Pvm("SavedQuery", "can_write"),), - Pvm("SavedQueryViewApi", "can_list",): (Pvm("SavedQuery", "can_read"),), - Pvm("SavedQueryViewApi", "can_add",): (Pvm("SavedQuery", "can_write"),), - Pvm("SavedQueryViewApi", "muldelete",): (Pvm("SavedQuery", "can_write"),), + Pvm( + "SavedQueryView", + "can_add", + ): (Pvm("SavedQuery", "can_write"),), + Pvm( + "SavedQueryView", + "can_edit", + ): (Pvm("SavedQuery", "can_write"),), + Pvm( + "SavedQueryView", + "can_delete", + ): (Pvm("SavedQuery", "can_write"),), + Pvm( + "SavedQueryView", + "muldelete", + ): (Pvm("SavedQuery", "can_write"),), + Pvm( + "SavedQueryView", + "can_mulexport", + ): (Pvm("SavedQuery", "can_read"),), + Pvm( + "SavedQueryViewApi", + "can_show", + ): (Pvm("SavedQuery", "can_read"),), + Pvm( + "SavedQueryViewApi", + "can_edit", + ): (Pvm("SavedQuery", "can_write"),), + Pvm( + "SavedQueryViewApi", + "can_list", + ): (Pvm("SavedQuery", "can_read"),), + Pvm( + "SavedQueryViewApi", + "can_add", + ): (Pvm("SavedQuery", "can_write"),), + Pvm( + "SavedQueryViewApi", + "muldelete", + ): (Pvm("SavedQuery", "can_write"),), } diff --git 
a/superset/migrations/versions/e5ef6828ac4e_add_rls_filter_type_and_grouping_key.py b/superset/migrations/versions/e5ef6828ac4e_add_rls_filter_type_and_grouping_key.py index 01fcf60e93357..7e1de6112889e 100644 --- a/superset/migrations/versions/e5ef6828ac4e_add_rls_filter_type_and_grouping_key.py +++ b/superset/migrations/versions/e5ef6828ac4e_add_rls_filter_type_and_grouping_key.py @@ -53,6 +53,8 @@ def upgrade(): def downgrade(): with op.batch_alter_table("row_level_security_filters") as batch_op: - batch_op.drop_index(op.f("ix_row_level_security_filters_filter_type"),) + batch_op.drop_index( + op.f("ix_row_level_security_filters_filter_type"), + ) batch_op.drop_column("filter_type") batch_op.drop_column("group_key") diff --git a/superset/models/core.py b/superset/models/core.py index 450ddc965e12f..fcc7cf16d8ef2 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -322,7 +322,9 @@ def set_sqlalchemy_uri(self, uri: str) -> None: self.sqlalchemy_uri = str(conn) # hides the password def get_effective_user( - self, object_url: URL, user_name: Optional[str] = None, + self, + object_url: URL, + user_name: Optional[str] = None, ) -> Optional[str]: """ Get the effective user, especially during impersonation. diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index 812b0544689e3..7a8710af0d4af 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -344,7 +344,8 @@ def clear_cache_for_slice(cls, slice_id: int) -> None: @debounce(0.1) def clear_cache_for_datasource(cls, datasource_id: int) -> None: filter_query = select( - [dashboard_slices.c.dashboard_id], distinct=True, + [dashboard_slices.c.dashboard_id], + distinct=True, ).select_from( join( dashboard_slices, diff --git a/superset/queries/saved_queries/schemas.py b/superset/queries/saved_queries/schemas.py index ca2ef800a67e9..1fbaf758a95a2 100644 --- a/superset/queries/saved_queries/schemas.py +++ b/superset/queries/saved_queries/schemas.py @@ -18,7 +18,11 @@ from marshmallow.validate import Length openapi_spec_methods_override = { - "get": {"get": {"description": "Get a saved query",}}, + "get": { + "get": { + "description": "Get a saved query", + } + }, "get_list": { "get": { "description": "Get a list of saved queries, use Rison or JSON " diff --git a/superset/reports/commands/base.py b/superset/reports/commands/base.py index 3582767ef65f2..b17975a5a63e4 100644 --- a/superset/reports/commands/base.py +++ b/superset/reports/commands/base.py @@ -47,7 +47,7 @@ def validate(self) -> None: def validate_chart_dashboard( self, exceptions: List[ValidationError], update: bool = False ) -> None: - """ Validate chart or dashboard relation """ + """Validate chart or dashboard relation""" chart_id = self._properties.get("chart") dashboard_id = self._properties.get("dashboard") creation_method = self._properties.get("creation_method") diff --git a/superset/reports/commands/execute.py b/superset/reports/commands/execute.py index 2a75d2e738ff3..c006d007c4c57 100644 --- a/superset/reports/commands/execute.py +++ b/superset/reports/commands/execute.py @@ -95,7 +95,9 @@ def __init__( self._execution_id = execution_id def set_state_and_log( - self, state: ReportState, error_message: Optional[str] = None, + self, + state: ReportState, + error_message: Optional[str] = None, ) -> None: """ Updates current ReportSchedule state and TS. 
If on final state writes the log @@ -104,7 +106,8 @@ def set_state_and_log( now_dttm = datetime.utcnow() self.set_state(state, now_dttm) self.create_log( - state, error_message=error_message, + state, + error_message=error_message, ) def set_state(self, state: ReportState, dttm: datetime) -> None: @@ -531,12 +534,14 @@ def next(self) -> None: if self.is_on_working_timeout(): exception_timeout = ReportScheduleWorkingTimeoutError() self.set_state_and_log( - ReportState.ERROR, error_message=str(exception_timeout), + ReportState.ERROR, + error_message=str(exception_timeout), ) raise exception_timeout exception_working = ReportSchedulePreviousWorkingError() self.set_state_and_log( - ReportState.WORKING, error_message=str(exception_working), + ReportState.WORKING, + error_message=str(exception_working), ) raise exception_working diff --git a/superset/reports/dao.py b/superset/reports/dao.py index 11947a349a650..079788fbd0e5f 100644 --- a/superset/reports/dao.py +++ b/superset/reports/dao.py @@ -227,7 +227,8 @@ def find_active(session: Optional[Session] = None) -> List[ReportSchedule]: @staticmethod def find_last_success_log( - report_schedule: ReportSchedule, session: Optional[Session] = None, + report_schedule: ReportSchedule, + session: Optional[Session] = None, ) -> Optional[ReportExecutionLog]: """ Finds last success execution log for a given report @@ -245,7 +246,8 @@ def find_last_success_log( @staticmethod def find_last_entered_working_log( - report_schedule: ReportSchedule, session: Optional[Session] = None, + report_schedule: ReportSchedule, + session: Optional[Session] = None, ) -> Optional[ReportExecutionLog]: """ Finds last success execution log for a given report @@ -264,7 +266,8 @@ def find_last_entered_working_log( @staticmethod def find_last_error_notification( - report_schedule: ReportSchedule, session: Optional[Session] = None, + report_schedule: ReportSchedule, + session: Optional[Session] = None, ) -> Optional[ReportExecutionLog]: """ Finds last error email sent diff --git a/superset/reports/schemas.py b/superset/reports/schemas.py index 733903f8ea705..4076a2041cd99 100644 --- a/superset/reports/schemas.py +++ b/superset/reports/schemas.py @@ -203,7 +203,9 @@ class ReportSchedulePostSchema(Schema): default=ReportDataFormat.VISUALIZATION, validate=validate.OneOf(choices=tuple(key.value for key in ReportDataFormat)), ) - extra = fields.Dict(default=None,) + extra = fields.Dict( + default=None, + ) force_screenshot = fields.Boolean(default=False) @validates_schema diff --git a/superset/security/manager.py b/superset/security/manager.py index ac764d240549d..57f3f2f9586e3 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -1307,7 +1307,7 @@ def can_access_based_on_dashboard(datasource: "BaseDatasource") -> bool: @staticmethod def _get_current_epoch_time() -> float: - """ This is used so the tests can mock time """ + """This is used so the tests can mock time""" return time.time() @staticmethod @@ -1376,7 +1376,8 @@ def get_guest_user_from_request(self, req: Request) -> Optional[GuestUser]: def get_guest_user_from_token(self, token: GuestToken) -> GuestUser: return self.guest_user_cls( - token=token, roles=[self.find_role(current_app.config["GUEST_ROLE_NAME"])], + token=token, + roles=[self.find_role(current_app.config["GUEST_ROLE_NAME"])], ) def parse_jwt_guest_token(self, raw_token: str) -> Dict[str, Any]: diff --git a/superset/sqllab/command.py b/superset/sqllab/command.py index 1c2674b060f19..bdc570b603433 100644 --- a/superset/sqllab/command.py +++ 
b/superset/sqllab/command.py @@ -170,7 +170,10 @@ def _validate_access(self, query: Query) -> None: except Exception as ex: raise QueryIsForbiddenToAccessException(self._execution_context, ex) from ex - def _set_query_limit_if_required(self, rendered_query: str,) -> None: + def _set_query_limit_if_required( + self, + rendered_query: str, + ) -> None: if self._is_required_to_set_limit(): self._set_query_limit(rendered_query) diff --git a/superset/sqllab/query_render.py b/superset/sqllab/query_render.py index a280e303a3669..c2f96542898d9 100644 --- a/superset/sqllab/query_render.py +++ b/superset/sqllab/query_render.py @@ -100,7 +100,12 @@ def _raise_undefined_parameter_exception( extra={ "undefined_parameters": list(undefined_parameters), "template_parameters": execution_context.template_params, - "issue_codes": [{"code": 1006, "message": MSG_OF_1006,}], + "issue_codes": [ + { + "code": 1006, + "message": MSG_OF_1006, + } + ], }, ) diff --git a/superset/stats_logger.py b/superset/stats_logger.py index 29f77112e0239..4b869042a90df 100644 --- a/superset/stats_logger.py +++ b/superset/stats_logger.py @@ -107,6 +107,5 @@ def timing(self, key: str, value: float) -> None: def gauge(self, key: str, value: float) -> None: self.client.gauge(key, value) - except Exception: # pylint: disable=broad-except pass diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py index 9b17a465c9fcd..6a42d961e9d9d 100644 --- a/superset/tasks/async_queries.py +++ b/superset/tasks/async_queries.py @@ -71,7 +71,8 @@ def _create_query_context_from_form(form_data: Dict[str, Any]) -> QueryContext: @celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout) def load_chart_data_into_cache( - job_metadata: Dict[str, Any], form_data: Dict[str, Any], + job_metadata: Dict[str, Any], + form_data: Dict[str, Any], ) -> None: # pylint: disable=import-outside-toplevel from superset.charts.data.commands.get_data_command import ChartDataCommand @@ -85,7 +86,9 @@ def load_chart_data_into_cache( cache_key = result["cache_key"] result_url = f"/api/v1/chart/data/{cache_key}" async_query_manager.update_job( - job_metadata, async_query_manager.STATUS_DONE, result_url=result_url, + job_metadata, + async_query_manager.STATUS_DONE, + result_url=result_url, ) except SoftTimeLimitExceeded as ex: logger.warning("A timeout occurred while loading chart data, error: %s", ex) @@ -140,7 +143,9 @@ def load_explore_json_into_cache( # pylint: disable=too-many-locals set_and_log_cache(cache_manager.cache, cache_key, cache_value) result_url = f"/superset/explore_json/data/{cache_key}" async_query_manager.update_job( - job_metadata, async_query_manager.STATUS_DONE, result_url=result_url, + job_metadata, + async_query_manager.STATUS_DONE, + result_url=result_url, ) except SoftTimeLimitExceeded as ex: logger.warning("A timeout occurred while loading explore json, error: %s", ex) diff --git a/superset/tasks/scheduler.py b/superset/tasks/scheduler.py index 4b1debea4b1b2..f4c6365df36e6 100644 --- a/superset/tasks/scheduler.py +++ b/superset/tasks/scheduler.py @@ -61,7 +61,13 @@ def scheduler() -> None: active_schedule.working_timeout + app.config["ALERT_REPORTS_WORKING_SOFT_TIME_OUT_LAG"] ) - execute.apply_async((active_schedule.id, schedule,), **async_options) + execute.apply_async( + ( + active_schedule.id, + schedule, + ), + **async_options + ) @celery_app.task(name="reports.execute") @@ -70,7 +76,9 @@ def execute(report_schedule_id: int, scheduled_dttm: str) -> None: task_id = execute.request.id scheduled_dttm_ = 
parser.parse(scheduled_dttm) AsyncExecuteReportScheduleCommand( - task_id, report_schedule_id, scheduled_dttm_, + task_id, + report_schedule_id, + scheduled_dttm_, ).run() except ReportScheduleUnexpectedError as ex: logger.error( diff --git a/superset/tasks/slack_util.py b/superset/tasks/slack_util.py index 769f7da8a0ff9..2f44d92605272 100644 --- a/superset/tasks/slack_util.py +++ b/superset/tasks/slack_util.py @@ -54,7 +54,8 @@ def deliver_slack_msg( assert response["file"], str(response) # the uploaded file else: response = cast( - SlackResponse, client.chat_postMessage(channel=slack_channel, text=body), + SlackResponse, + client.chat_postMessage(channel=slack_channel, text=body), ) assert response["message"]["text"], str(response) logger.info("Sent the report to the slack %s", slack_channel) diff --git a/superset/tasks/thumbnails.py b/superset/tasks/thumbnails.py index 5e4b8dfb755dc..94b83ddb372cf 100644 --- a/superset/tasks/thumbnails.py +++ b/superset/tasks/thumbnails.py @@ -72,5 +72,8 @@ def cache_dashboard_thumbnail( current_app.config["THUMBNAIL_SELENIUM_USER"], session=session ) screenshot.compute_and_cache( - user=user, cache=thumbnail_cache, force=force, thumb_size=thumb_size, + user=user, + cache=thumbnail_cache, + force=force, + thumb_size=thumb_size, ) diff --git a/superset/temporary_cache/api.py b/superset/temporary_cache/api.py index b1c5999630b68..e91a2886691f4 100644 --- a/superset/temporary_cache/api.py +++ b/superset/temporary_cache/api.py @@ -63,10 +63,12 @@ class TemporaryCacheRestApi(BaseApi, ABC): def add_apispec_components(self, api_spec: APISpec) -> None: try: api_spec.components.schema( - TemporaryCachePostSchema.__name__, schema=TemporaryCachePostSchema, + TemporaryCachePostSchema.__name__, + schema=TemporaryCachePostSchema, ) api_spec.components.schema( - TemporaryCachePutSchema.__name__, schema=TemporaryCachePutSchema, + TemporaryCachePutSchema.__name__, + schema=TemporaryCachePutSchema, ) except DuplicateComponentNameError: pass diff --git a/superset/temporary_cache/commands/update.py b/superset/temporary_cache/commands/update.py index 584e16690b61f..92af8c14f20af 100644 --- a/superset/temporary_cache/commands/update.py +++ b/superset/temporary_cache/commands/update.py @@ -29,7 +29,8 @@ class UpdateTemporaryCacheCommand(BaseCommand, ABC): def __init__( - self, cmd_params: CommandParameters, + self, + cmd_params: CommandParameters, ): self._parameters = cmd_params diff --git a/superset/utils/cache.py b/superset/utils/cache.py index c10f296e1cac4..02a4cdfecc0ee 100644 --- a/superset/utils/cache.py +++ b/superset/utils/cache.py @@ -94,7 +94,8 @@ def view_cache_key(*args: Any, **kwargs: Any) -> str: # pylint: disable=unused- def memoized_func( - key: Callable[..., str] = view_cache_key, cache: Cache = cache_manager.cache, + key: Callable[..., str] = view_cache_key, + cache: Cache = cache_manager.cache, ) -> Callable[..., Any]: """Use this decorator to cache functions that have predefined first arg. 
diff --git a/superset/utils/core.py b/superset/utils/core.py index b6ae3272cfed5..fbfbbf52f5699 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -357,7 +357,6 @@ def __init__(self, **args: Any) -> None: } } - except NameError: pass @@ -513,7 +512,9 @@ def format_timedelta(time_delta: timedelta) -> str: return str(time_delta) -def base_json_conv(obj: Any,) -> Any: # pylint: disable=inconsistent-return-statements +def base_json_conv( # pylint: disable=inconsistent-return-statements + obj: Any, +) -> Any: if isinstance(obj, memoryview): obj = obj.tobytes() if isinstance(obj, np.int64): @@ -1014,7 +1015,8 @@ def zlib_decompress(blob: bytes, decode: Optional[bool] = True) -> Union[bytes, def simple_filter_to_adhoc( - filter_clause: QueryObjectFilterClause, clause: str = "where", + filter_clause: QueryObjectFilterClause, + clause: str = "where", ) -> AdhocFilterClause: result: AdhocFilterClause = { "clause": clause.upper(), @@ -1277,7 +1279,8 @@ def get_metric_name( def get_column_names( - columns: Optional[Sequence[Column]], verbose_map: Optional[Dict[str, Any]] = None, + columns: Optional[Sequence[Column]], + verbose_map: Optional[Dict[str, Any]] = None, ) -> List[str]: return [ column @@ -1287,7 +1290,8 @@ def get_column_names( def get_metric_names( - metrics: Optional[Sequence[Metric]], verbose_map: Optional[Dict[str, Any]] = None, + metrics: Optional[Sequence[Metric]], + verbose_map: Optional[Dict[str, Any]] = None, ) -> List[str]: return [ metric @@ -1297,7 +1301,8 @@ def get_metric_names( def get_first_metric_name( - metrics: Optional[Sequence[Metric]], verbose_map: Optional[Dict[str, Any]] = None, + metrics: Optional[Sequence[Metric]], + verbose_map: Optional[Dict[str, Any]] = None, ) -> Optional[str]: metric_labels = get_metric_names(metrics, verbose_map) return metric_labels[0] if metric_labels else None @@ -1571,7 +1576,8 @@ def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]: def extract_dataframe_dtypes( - df: pd.DataFrame, datasource: Optional["BaseDatasource"] = None, + df: pd.DataFrame, + datasource: Optional["BaseDatasource"] = None, ) -> List[GenericDataType]: """Serialize pandas/numpy dtypes to generic types""" @@ -1632,7 +1638,8 @@ def is_test() -> bool: def get_time_filter_status( - datasource: "BaseDatasource", applied_time_extras: Dict[str, str], + datasource: "BaseDatasource", + applied_time_extras: Dict[str, str], ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]: temporal_columns = {col.column_name for col in datasource.columns if col.is_dttm} applied: List[Dict[str, str]] = [] @@ -1786,7 +1793,10 @@ def parse_boolean_string(bool_str: Optional[str]) -> bool: return False -def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int: +def apply_max_row_limit( + limit: int, + max_limit: Optional[int] = None, +) -> int: """ Override row limit if max global limit is defined diff --git a/superset/utils/date_parser.py b/superset/utils/date_parser.py index 76a13696952ef..cc0693770bfa7 100644 --- a/superset/utils/date_parser.py +++ b/superset/utils/date_parser.py @@ -99,7 +99,8 @@ def dttm_from_timetuple(date_: struct_time) -> datetime: def get_past_or_future( - human_readable: Optional[str], source_time: Optional[datetime] = None, + human_readable: Optional[str], + source_time: Optional[datetime] = None, ) -> datetime: cal = parsedatetime.Calendar() source_dttm = dttm_from_timetuple( @@ -109,7 +110,8 @@ def get_past_or_future( def parse_human_timedelta( - human_readable: Optional[str], source_time: Optional[datetime] = 
None, + human_readable: Optional[str], + source_time: Optional[datetime] = None, ) -> timedelta: """ Returns ``datetime.timedelta`` from natural language time deltas @@ -135,7 +137,8 @@ def parse_past_timedelta( or datetime.timedelta(365). """ return -parse_human_timedelta( - delta_str if delta_str.startswith("-") else f"-{delta_str}", source_time, + delta_str if delta_str.startswith("-") else f"-{delta_str}", + source_time, ) diff --git a/superset/utils/encrypt.py b/superset/utils/encrypt.py index 0fa2a6d177a06..7c93764f691ba 100644 --- a/superset/utils/encrypt.py +++ b/superset/utils/encrypt.py @@ -156,7 +156,8 @@ def _re_encrypt_row( raise Exception from exc re_encrypted_columns[column_name] = encrypted_type.process_bind_param( - unencrypted_value, self._dialect, + unencrypted_value, + self._dialect, ) set_cols = ",".join( diff --git a/superset/utils/log.py b/superset/utils/log.py index ecee213ebef29..4eafb669328fc 100644 --- a/superset/utils/log.py +++ b/superset/utils/log.py @@ -195,7 +195,10 @@ def log_with_context( # pylint: disable=too-many-locals @contextmanager def log_context( - self, action: str, object_ref: Optional[str] = None, log_to_statsd: bool = True, + self, + action: str, + object_ref: Optional[str] = None, + log_to_statsd: bool = True, ) -> Iterator[Callable[..., None]]: """ Log an event with additional information from the request context. diff --git a/superset/utils/machine_auth.py b/superset/utils/machine_auth.py index 778f2a6b1de04..01347f90f1c25 100644 --- a/superset/utils/machine_auth.py +++ b/superset/utils/machine_auth.py @@ -40,10 +40,14 @@ def __init__( # overridden via config, as opposed to the entire provider implementation self._auth_webdriver_func_override = auth_webdriver_func_override - def authenticate_webdriver(self, driver: WebDriver, user: "User",) -> WebDriver: + def authenticate_webdriver( + self, + driver: WebDriver, + user: "User", + ) -> WebDriver: """ - Default AuthDriverFuncType type that sets a session cookie flask-login style - :return: The WebDriver passed in (fluent) + Default AuthDriverFuncType type that sets a session cookie flask-login style + :return: The WebDriver passed in (fluent) """ # Short-circuit this method if we have an override configured if self._auth_webdriver_func_override: diff --git a/superset/utils/mock_data.py b/superset/utils/mock_data.py index 1c6515804b7fa..ea83f7398251f 100644 --- a/superset/utils/mock_data.py +++ b/superset/utils/mock_data.py @@ -112,7 +112,9 @@ def get_type_generator( # pylint: disable=too-many-return-statements,too-many-b if isinstance(sqltype, sqlalchemy.sql.sqltypes.TIME): return lambda: time( - random.randrange(24), random.randrange(60), random.randrange(60), + random.randrange(24), + random.randrange(60), + random.randrange(60), ) if isinstance( diff --git a/superset/utils/pandas_postprocessing/cum.py b/superset/utils/pandas_postprocessing/cum.py index d2bd5761176fa..b94f048e5cd62 100644 --- a/superset/utils/pandas_postprocessing/cum.py +++ b/superset/utils/pandas_postprocessing/cum.py @@ -28,7 +28,11 @@ @validate_column_args("columns") -def cum(df: DataFrame, operator: str, columns: Dict[str, str],) -> DataFrame: +def cum( + df: DataFrame, + operator: str, + columns: Dict[str, str], +) -> DataFrame: """ Calculate cumulative sum/product/min/max for select columns. 
diff --git a/superset/utils/pandas_postprocessing/flatten.py b/superset/utils/pandas_postprocessing/flatten.py index a348801e31f3b..49f250ec1c9b9 100644 --- a/superset/utils/pandas_postprocessing/flatten.py +++ b/superset/utils/pandas_postprocessing/flatten.py @@ -22,7 +22,10 @@ ) -def flatten(df: pd.DataFrame, reset_index: bool = True,) -> pd.DataFrame: +def flatten( + df: pd.DataFrame, + reset_index: bool = True, +) -> pd.DataFrame: """ Convert N-dimensional DataFrame to a flat DataFrame diff --git a/superset/utils/pandas_postprocessing/geography.py b/superset/utils/pandas_postprocessing/geography.py index 8ea75d2450293..33a27c2df4074 100644 --- a/superset/utils/pandas_postprocessing/geography.py +++ b/superset/utils/pandas_postprocessing/geography.py @@ -50,7 +50,10 @@ def geohash_decode( def geohash_encode( - df: DataFrame, geohash: str, longitude: str, latitude: str, + df: DataFrame, + geohash: str, + longitude: str, + latitude: str, ) -> DataFrame: """ Encode longitude and latitude into geohash @@ -65,7 +68,8 @@ def geohash_encode( encode_df = df[[latitude, longitude]] encode_df.columns = ["latitude", "longitude"] encode_df["geohash"] = encode_df.apply( - lambda row: geohash_lib.encode(row["latitude"], row["longitude"]), axis=1, + lambda row: geohash_lib.encode(row["latitude"], row["longitude"]), + axis=1, ) return _append_columns(df, encode_df, {"geohash": geohash}) except ValueError as ex: diff --git a/superset/utils/pandas_postprocessing/prophet.py b/superset/utils/pandas_postprocessing/prophet.py index 8a85e581b52af..d66298b1790cc 100644 --- a/superset/utils/pandas_postprocessing/prophet.py +++ b/superset/utils/pandas_postprocessing/prophet.py @@ -114,7 +114,10 @@ def prophet( # pylint: disable=too-many-arguments raise InvalidPostProcessingError(_("Time grain missing")) if time_grain not in PROPHET_TIME_GRAIN_MAP: raise InvalidPostProcessingError( - _("Unsupported time grain: %(time_grain)s", time_grain=time_grain,) + _( + "Unsupported time grain: %(time_grain)s", + time_grain=time_grain, + ) ) freq = PROPHET_TIME_GRAIN_MAP[time_grain] # check type at runtime due to marhsmallow schema not being able to handle diff --git a/superset/utils/pandas_postprocessing/utils.py b/superset/utils/pandas_postprocessing/utils.py index dc48cd1145c87..46f5a0c50529c 100644 --- a/superset/utils/pandas_postprocessing/utils.py +++ b/superset/utils/pandas_postprocessing/utils.py @@ -148,7 +148,8 @@ def wrapped(df: DataFrame, **options: Any) -> Any: def _get_aggregate_funcs( - df: DataFrame, aggregates: Dict[str, Dict[str, Any]], + df: DataFrame, + aggregates: Dict[str, Dict[str, Any]], ) -> Dict[str, NamedAgg]: """ Converts a set of aggregate config objects into functions that pandas can use as @@ -170,7 +171,10 @@ def _get_aggregate_funcs( ) if "operator" not in agg_obj: raise InvalidPostProcessingError( - _("Operator undefined for aggregator: %(name)s", name=name,) + _( + "Operator undefined for aggregator: %(name)s", + name=name, + ) ) operator = agg_obj["operator"] if callable(operator): @@ -179,7 +183,10 @@ def _get_aggregate_funcs( func = NUMPY_FUNCTIONS.get(operator) if not func: raise InvalidPostProcessingError( - _("Invalid numpy function: %(operator)s", operator=operator,) + _( + "Invalid numpy function: %(operator)s", + operator=operator, + ) ) options = agg_obj.get("options", {}) aggfunc = partial(func, **options) diff --git a/superset/utils/profiler.py b/superset/utils/profiler.py index a4710bb24341e..d17b69d8bdc9a 100644 --- a/superset/utils/profiler.py +++ b/superset/utils/profiler.py 
@@ -35,7 +35,9 @@ class SupersetProfiler: # pylint: disable=too-few-public-methods """ def __init__( - self, app: Callable[[Any, Any], Any], interval: float = 0.0001, + self, + app: Callable[[Any, Any], Any], + interval: float = 0.0001, ): self.app = app self.interval = interval diff --git a/superset/views/base.py b/superset/views/base.py index ae1f7a66b4f1f..863ca2f84ab67 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -629,7 +629,7 @@ def apply(self, query: Query, value: Any) -> Query: ) -class CsvResponse(Response): # pylint: disable=too-many-ancestors +class CsvResponse(Response): """ Override Response to take into account csv encoding from config.py """ diff --git a/superset/views/core.py b/superset/views/core.py index a3356e77aa08d..5f6160642d85a 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -319,7 +319,9 @@ def approve(self) -> FlaskResponse: # pylint: disable=too-many-locals,no-self-u def clean_fulfilled_requests(session: Session) -> None: for dar in session.query(DAR).all(): datasource = ConnectorRegistry.get_datasource( - dar.datasource_type, dar.datasource_id, session, + dar.datasource_type, + dar.datasource_id, + session, ) if not datasource or security_manager.can_access_datasource(datasource): # Dataset does not exist anymore @@ -627,7 +629,8 @@ def explore_json( and not security_manager.can_access("can_csv", "Superset") ): return json_error_response( - _("You don't have the rights to ") + _("download as csv"), status=403, + _("You don't have the rights to ") + _("download as csv"), + status=403, ) form_data = get_form_data()[0] @@ -943,7 +946,9 @@ def filter( # pylint: disable=no-self-use """ # TODO: Cache endpoint by user, datasource and column datasource = ConnectorRegistry.get_datasource( - datasource_type, datasource_id, db.session, + datasource_type, + datasource_id, + db.session, ) if not datasource: return json_error_response(DATASOURCE_MISSING_ERR) @@ -1426,7 +1431,10 @@ def get_user_activity_access_error(user_id: int) -> Optional[FlaskResponse]: try: security_manager.raise_for_user_activity_access(user_id) except SupersetSecurityException as ex: - return json_error_response(ex.message, status=403,) + return json_error_response( + ex.message, + status=403, + ) return None @api @@ -1449,7 +1457,8 @@ def recent_activity( # pylint: disable=too-many-locals has_subject_title = or_( and_( - Dashboard.dashboard_title is not None, Dashboard.dashboard_title != "", + Dashboard.dashboard_title is not None, + Dashboard.dashboard_title != "", ), and_(Slice.slice_name is not None, Slice.slice_name != ""), ) @@ -1483,7 +1492,10 @@ def recent_activity( # pylint: disable=too-many-locals Slice.slice_name, ) .outerjoin(Dashboard, Dashboard.id == subqry.c.dashboard_id) - .outerjoin(Slice, Slice.id == subqry.c.slice_id,) + .outerjoin( + Slice, + Slice.id == subqry.c.slice_id, + ) .filter(has_subject_title) .order_by(subqry.c.dttm.desc()) .limit(limit) @@ -2005,7 +2017,8 @@ def dashboard( @has_access @expose("/dashboard/p//", methods=["GET"]) def dashboard_permalink( # pylint: disable=no-self-use - self, key: str, + self, + key: str, ) -> FlaskResponse: try: value = GetDashboardPermalinkCommand(g.user, key).run() @@ -2579,7 +2592,8 @@ def _set_http_status_into_Sql_lab_exception(ex: SqlLabException) -> None: ex.status = 403 def _create_response_from_execution_context( # pylint: disable=invalid-name, no-self-use - self, command_result: CommandResult, + self, + command_result: CommandResult, ) -> FlaskResponse: status_code = 200 @@ -2670,7 +2684,9 @@ 
def fetch_datasource_metadata(self) -> FlaskResponse: # pylint: disable=no-self datasource_id, datasource_type = request.args["datasourceKey"].split("__") datasource = ConnectorRegistry.get_datasource( - datasource_type, datasource_id, db.session, + datasource_type, + datasource_id, + db.session, ) # Check if datasource exists if not datasource: diff --git a/superset/views/dashboard/views.py b/superset/views/dashboard/views.py index 471d168e1ec64..49bdf76edcb0b 100644 --- a/superset/views/dashboard/views.py +++ b/superset/views/dashboard/views.py @@ -155,7 +155,8 @@ def embedded( login_manager.reload_user(AnonymousUserMixin()) add_extra_log_payload( - dashboard_id=dashboard_id_or_slug, dashboard_version="v2", + dashboard_id=dashboard_id_or_slug, + dashboard_version="v2", ) bootstrap_data = { diff --git a/superset/views/database/views.py b/superset/views/database/views.py index 659a1be78d8e3..bb2e018994e44 100644 --- a/superset/views/database/views.py +++ b/superset/views/database/views.py @@ -275,11 +275,13 @@ def form_post(self, form: ExcelToDatabaseForm) -> Response: flash(message, "danger") return redirect("/exceltodatabaseview/form") - uploaded_tmp_file_path = tempfile.NamedTemporaryFile( # pylint: disable=consider-using-with - dir=app.config["UPLOAD_FOLDER"], - suffix=os.path.splitext(form.excel_file.data.filename)[1].lower(), - delete=False, - ).name + uploaded_tmp_file_path = ( + tempfile.NamedTemporaryFile( # pylint: disable=consider-using-with + dir=app.config["UPLOAD_FOLDER"], + suffix=os.path.splitext(form.excel_file.data.filename)[1].lower(), + delete=False, + ).name + ) try: utils.ensure_path_exists(config["UPLOAD_FOLDER"]) diff --git a/superset/views/datasource/schemas.py b/superset/views/datasource/schemas.py index 162f35b4730b1..64b2b854bb148 100644 --- a/superset/views/datasource/schemas.py +++ b/superset/views/datasource/schemas.py @@ -44,7 +44,9 @@ class ExternalMetadataSchema(Schema): # pylint: disable=no-self-use,unused-argument @post_load def normalize( - self, data: ExternalMetadataParams, **kwargs: Any, + self, + data: ExternalMetadataParams, + **kwargs: Any, ) -> ExternalMetadataParams: return ExternalMetadataParams( datasource_type=data["datasource_type"], diff --git a/superset/views/datasource/views.py b/superset/views/datasource/views.py index 7e1ffa0468e90..2504f458eabc2 100644 --- a/superset/views/datasource/views.py +++ b/superset/views/datasource/views.py @@ -161,8 +161,8 @@ def external_metadata( def external_metadata_by_name(self, **kwargs: Any) -> FlaskResponse: """Gets table metadata from the source system and SQLAlchemy inspector""" try: - params: ExternalMetadataParams = ( - ExternalMetadataSchema().load(kwargs.get("rison")) + params: ExternalMetadataParams = ExternalMetadataSchema().load( + kwargs.get("rison") ) except ValidationError as err: return json_error_response(str(err), status=400) diff --git a/superset/views/users/api.py b/superset/views/users/api.py index 8945be9b0c55b..584e8145ec391 100644 --- a/superset/views/users/api.py +++ b/superset/views/users/api.py @@ -24,7 +24,7 @@ class CurrentUserRestApi(BaseApi): - """ An api to get information about the current user """ + """An api to get information about the current user""" resource_name = "me" openapi_spec_tag = "Current User" diff --git a/superset/views/utils.py b/superset/views/utils.py index 19c9a2eaf05af..202f87d996976 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -418,7 +418,9 @@ def is_owner(obj: Union[Dashboard, Slice], user: User) -> bool: return obj and user 
in obj.owners -def check_resource_permissions(check_perms: Callable[..., Any],) -> Callable[..., Any]: +def check_resource_permissions( + check_perms: Callable[..., Any], +) -> Callable[..., Any]: """ A decorator for checking permissions on a request using the passed-in function. """ diff --git a/superset/viz.py b/superset/viz.py index 7544af5078059..e83e0127775cb 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -1086,7 +1086,7 @@ def get_data(self, df: pd.DataFrame) -> VizData: # pylint: disable=too-many-loc v = query_obj[DTTM_ALIAS] if hasattr(v, "value"): v = v.value - values[str(v / 10 ** 9)] = query_obj.get(metric) + values[str(v / 10**9)] = query_obj.get(metric) data[metric] = values try: @@ -1945,7 +1945,12 @@ def get_data(self, df: pd.DataFrame) -> VizData: source, target = get_column_names(self.groupby) (value,) = self.metric_labels df.rename( - columns={source: "source", target: "target", value: "value",}, inplace=True, + columns={ + source: "source", + target: "target", + value: "value", + }, + inplace=True, ) df["source"] = df["source"].astype(str) df["target"] = df["target"].astype(str) diff --git a/tests/common/query_context_generator.py b/tests/common/query_context_generator.py index d97b270002f5c..8fddeb92ffda9 100644 --- a/tests/common/query_context_generator.py +++ b/tests/common/query_context_generator.py @@ -45,7 +45,9 @@ QUERY_OBJECTS: Dict[str, Dict[str, object]] = { "birth_names": query_birth_names, # `:suffix` are overrides only - "birth_names:include_time": {"groupby": [DTTM_ALIAS, "name"],}, + "birth_names:include_time": { + "groupby": [DTTM_ALIAS, "name"], + }, "birth_names:orderby_dup_alias": { "metrics": [ { @@ -93,7 +95,9 @@ ], ], }, - "birth_names:only_orderby_has_metric": {"metrics": [],}, + "birth_names:only_orderby_has_metric": { + "metrics": [], + }, } ANNOTATION_LAYERS = { @@ -182,17 +186,27 @@ # https://numpy.org/doc/stable/reference/generated/numpy.percentile.html "options": {"q": 25, "interpolation": "lower"}, }, - "median": {"operator": "median", "column": "sum__num",}, + "median": { + "operator": "median", + "column": "sum__num", + }, }, }, }, - {"operation": "sort", "options": {"columns": {"q1": False, "name": True},},}, + { + "operation": "sort", + "options": { + "columns": {"q1": False, "name": True}, + }, + }, ] } def get_query_object( - query_name: str, add_postprocessing_operations: bool, add_time_offsets: bool, + query_name: str, + add_postprocessing_operations: bool, + add_time_offsets: bool, ) -> Dict[str, Any]: if query_name not in QUERY_OBJECTS: raise Exception(f"QueryObject fixture not defined for datasource: {query_name}") @@ -247,7 +261,9 @@ def generate( "datasource": {"id": table.id, "type": table.type}, "queries": [ get_query_object( - query_name, add_postprocessing_operations, add_time_offsets, + query_name, + add_postprocessing_operations, + add_time_offsets, ) ], "result_type": ChartDataResultType.FULL, diff --git a/tests/conftest.py b/tests/conftest.py index 6350d3235c5b6..92f9b10d955ad 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -71,7 +71,9 @@ def example_db_engine(example_db_provider: Callable[[], Database]) -> Engine: @fixture(scope="session") -def pandas_loader_configuration(support_datetime_type,) -> PandasLoaderConfigurations: +def pandas_loader_configuration( + support_datetime_type, +) -> PandasLoaderConfigurations: return PandasLoaderConfigurations.make_from_dict( {SUPPORT_DATETIME_TYPE: support_datetime_type} ) diff --git a/tests/fixtures/birth_names.py b/tests/fixtures/birth_names.py index 
bcf2a5aa91b8a..5a0135b456f25 100644 --- a/tests/fixtures/birth_names.py +++ b/tests/fixtures/birth_names.py @@ -41,7 +41,8 @@ def birth_names_data_generator() -> BirthNamesGenerator: @fixture(scope="session") def birth_names_table_factory( - birth_names_data_generator: BirthNamesGenerator, support_datetime_type: bool, + birth_names_data_generator: BirthNamesGenerator, + support_datetime_type: bool, ) -> Callable[[], Table]: def _birth_names_table_factory() -> Table: return BirthNamesMetaDataFactory(support_datetime_type).make_table( diff --git a/tests/integration_tests/annotation_layers/fixtures.py b/tests/integration_tests/annotation_layers/fixtures.py index 0b9e19e21fb70..52cb8ca2abf41 100644 --- a/tests/integration_tests/annotation_layers/fixtures.py +++ b/tests/integration_tests/annotation_layers/fixtures.py @@ -39,7 +39,10 @@ def get_end_dttm(annotation_id: int) -> datetime: def _insert_annotation_layer(name: str = "", descr: str = "") -> AnnotationLayer: - annotation_layer = AnnotationLayer(name=name, descr=descr,) + annotation_layer = AnnotationLayer( + name=name, + descr=descr, + ) db.session.add(annotation_layer) db.session.commit() return annotation_layer diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py index f68ad85cb5c21..802684ba3bd07 100644 --- a/tests/integration_tests/celery_tests.py +++ b/tests/integration_tests/celery_tests.py @@ -256,7 +256,11 @@ def test_run_async_query_cta_config(setup_sqllab, ctas_method): return tmp_table_name = f"{TEST_ASYNC_CTA_CONFIG}_{ctas_method.lower()}" result = run_sql( - QUERY, cta=True, ctas_method=ctas_method, async_=True, tmp_table=tmp_table_name, + QUERY, + cta=True, + ctas_method=ctas_method, + async_=True, + tmp_table=tmp_table_name, ) query = wait_for_success(result) diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index daff4f14f6733..3c92caceead73 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -917,7 +917,11 @@ def test_admin_gets_filtered_energy_slices(self): # test filtering on datasource_name arguments = { "filters": [ - {"col": "slice_name", "opr": "chart_all_text", "value": "energy",} + { + "col": "slice_name", + "opr": "chart_all_text", + "value": "energy", + } ], "keys": ["none"], "columns": ["slice_name"], @@ -933,7 +937,13 @@ def test_admin_gets_filtered_energy_slices(self): @pytest.mark.usefixtures("create_certified_charts") def test_gets_certified_charts_filter(self): arguments = { - "filters": [{"col": "id", "opr": "chart_is_certified", "value": True,}], + "filters": [ + { + "col": "id", + "opr": "chart_is_certified", + "value": True, + } + ], "keys": ["none"], "columns": ["slice_name"], } @@ -948,7 +958,13 @@ def test_gets_certified_charts_filter(self): @pytest.mark.usefixtures("create_charts") def test_gets_not_certified_charts_filter(self): arguments = { - "filters": [{"col": "id", "opr": "chart_is_certified", "value": False,}], + "filters": [ + { + "col": "id", + "opr": "chart_is_certified", + "value": False, + } + ], "keys": ["none"], "columns": ["slice_name"], } @@ -965,7 +981,11 @@ def test_user_gets_none_filtered_energy_slices(self): # test filtering on datasource_name arguments = { "filters": [ - {"col": "slice_name", "opr": "chart_all_text", "value": "energy",} + { + "col": "slice_name", + "opr": "chart_all_text", + "value": "energy", + } ], "keys": ["none"], "columns": ["slice_name"], diff --git a/tests/integration_tests/charts/data/api_tests.py 
b/tests/integration_tests/charts/data/api_tests.py index c45c8d5064eaa..73425fb58f68c 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -129,7 +129,8 @@ def assert_row_count(rv: Response, expected_row_count: int): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch( - "superset.common.query_context_factory.config", {**app.config, "ROW_LIMIT": 7}, + "superset.common.query_context_factory.config", + {**app.config, "ROW_LIMIT": 7}, ) def test_without_row_limit__row_count_as_default_row_limit(self): # arrange @@ -161,7 +162,8 @@ def test_as_samples_without_row_limit__row_count_as_default_samples_row_limit(se @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch( - "superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 10}, + "superset.utils.core.current_app.config", + {**app.config, "SQL_MAX_ROW": 10}, ) def test_with_row_limit_bigger_then_sql_max_row__rowcount_as_sql_max_row(self): # arrange @@ -176,7 +178,8 @@ def test_with_row_limit_bigger_then_sql_max_row__rowcount_as_sql_max_row(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch( - "superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 5}, + "superset.utils.core.current_app.config", + {**app.config, "SQL_MAX_ROW": 5}, ) def test_as_samples_with_row_limit_bigger_then_sql_max_row__rowcount_as_sql_max_row( self, @@ -334,7 +337,9 @@ def test_chart_data_applied_time_extras(self): ) self.assertEqual( data["result"][0]["rejected_filters"], - [{"column": "__time_origin", "reason": "not_druid_datasource"},], + [ + {"column": "__time_origin", "reason": "not_druid_datasource"}, + ], ) expected_row_count = self.get_expected_row_count("client_id_2") self.assertEqual(data["result"][0]["rowcount"], expected_row_count) @@ -384,7 +389,8 @@ def test_chart_data_dttm_filter(self): dttm_col = col if dttm_col: dttm_expression = table.database.db_engine_spec.convert_dttm( - dttm_col.type, dttm, + dttm_col.type, + dttm, ) self.assertIn(dttm_expression, result["query"]) else: @@ -479,7 +485,10 @@ def test_with_orderby_parameter_with_second_query__400(self): self.query_context_payload["queries"][0]["filters"] = [] self.query_context_payload["queries"][0]["orderby"] = [ [ - {"expressionType": "SQL", "sqlExpression": "sum__num; select 1, 1",}, + { + "expressionType": "SQL", + "sqlExpression": "sum__num; select 1, 1", + }, True, ], ] @@ -736,7 +745,11 @@ def test_with_virtual_table_with_colons_as_datasource(self): } ] request_payload["queries"][0]["filters"] = [ - {"col": "foo", "op": "!=", "val": ":qwerty:",} + { + "col": "foo", + "op": "!=", + "val": ":qwerty:", + } ] rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") @@ -783,7 +796,11 @@ def test_chart_data_get(self): "time_range": "1900-01-01T00:00:00 : 2000-01-01T00:00:00", "granularity": "ds", "filters": [], - "extras": {"having": "", "having_druid": [], "where": "",}, + "extras": { + "having": "", + "having_druid": [], + "where": "", + }, "applied_time_extras": {}, "columns": ["gender"], "metrics": ["sum__num"], @@ -868,7 +885,9 @@ def mock_run(self, **kwargs): return orig_run(self, force_cached=False) with mock.patch.object(ChartDataCommand, "run", new=mock_run): - rv = self.client.get(f"{CHART_DATA_URI}/test-cache-key",) + rv = self.client.get( + f"{CHART_DATA_URI}/test-cache-key", + ) self.assertEqual(rv.status_code, 401) diff --git a/tests/integration_tests/cli_tests.py 
b/tests/integration_tests/cli_tests.py index 7426d90ea88af..f69efc6253943 100644 --- a/tests/integration_tests/cli_tests.py +++ b/tests/integration_tests/cli_tests.py @@ -373,7 +373,9 @@ def test_import_datasets_sync_argument_columns_metrics( assert response.exit_code == 0 expected_contents = {"dataset.yaml": "hello: world"} import_datasets_command.assert_called_with( - expected_contents, sync_columns=True, sync_metrics=True, + expected_contents, + sync_columns=True, + sync_metrics=True, ) @@ -408,7 +410,9 @@ def test_import_datasets_sync_argument_columns( assert response.exit_code == 0 expected_contents = {"dataset.yaml": "hello: world"} import_datasets_command.assert_called_with( - expected_contents, sync_columns=True, sync_metrics=False, + expected_contents, + sync_columns=True, + sync_metrics=False, ) @@ -443,7 +447,9 @@ def test_import_datasets_sync_argument_metrics( assert response.exit_code == 0 expected_contents = {"dataset.yaml": "hello: world"} import_datasets_command.assert_called_with( - expected_contents, sync_columns=False, sync_metrics=True, + expected_contents, + sync_columns=False, + sync_metrics=True, ) diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index 83ea63daf0b70..ca0b374e9734c 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -974,7 +974,8 @@ def test_explore_json(self): } self.login(username="admin") rv = self.client.post( - "/superset/explore_json/", data={"form_data": json.dumps(form_data)}, + "/superset/explore_json/", + data={"form_data": json.dumps(form_data)}, ) data = json.loads(rv.data.decode("utf-8")) @@ -1049,7 +1050,8 @@ def test_explore_json_dist_bar_order(self): self.login(username="admin") rv = self.client.post( - "/superset/explore_json/", data={"form_data": json.dumps(form_data)}, + "/superset/explore_json/", + data={"form_data": json.dumps(form_data)}, ) data = json.loads(rv.data.decode("utf-8")) @@ -1096,7 +1098,8 @@ def test_explore_json_async(self): async_query_manager.init_app(app) self.login(username="admin") rv = self.client.post( - "/superset/explore_json/", data={"form_data": json.dumps(form_data)}, + "/superset/explore_json/", + data={"form_data": json.dumps(form_data)}, ) data = json.loads(rv.data.decode("utf-8")) keys = list(data.keys()) @@ -1634,7 +1637,8 @@ def test_stop_query_not_implemented( mock_superset_db_session.query().filter_by().one().return_value = query_mock mock_sql_lab_cancel_query.return_value = False rv = self.client.post( - "/superset/stop_query/", data={"form_data": json.dumps(form_data)}, + "/superset/stop_query/", + data={"form_data": json.dumps(form_data)}, ) assert rv.status_code == 422 diff --git a/tests/integration_tests/css_templates/api_tests.py b/tests/integration_tests/css_templates/api_tests.py index 8f7e580e8ce14..b28cca955ca8d 100644 --- a/tests/integration_tests/css_templates/api_tests.py +++ b/tests/integration_tests/css_templates/api_tests.py @@ -34,7 +34,10 @@ class TestCssTemplateApi(SupersetTestCase): def insert_css_template( - self, template_name: str, css: str, created_by_username: str = "admin", + self, + template_name: str, + css: str, + created_by_username: str = "admin", ) -> CssTemplate: admin = self.get_user(created_by_username) css_template = CssTemplate( diff --git a/tests/integration_tests/dashboard_utils.py b/tests/integration_tests/dashboard_utils.py index c50b75b8ae152..fa6efd60b4dac 100644 --- a/tests/integration_tests/dashboard_utils.py +++ b/tests/integration_tests/dashboard_utils.py @@ -30,7 
+30,9 @@ def get_table( - table_name: str, database: Database, schema: Optional[str] = None, + table_name: str, + database: Database, + schema: Optional[str] = None, ): schema = schema or get_example_default_schema() table_source = ConnectorRegistry.sources["table"] diff --git a/tests/integration_tests/dashboards/api_tests.py b/tests/integration_tests/dashboards/api_tests.py index 938de31414393..7fc8c2e2518fc 100644 --- a/tests/integration_tests/dashboards/api_tests.py +++ b/tests/integration_tests/dashboards/api_tests.py @@ -329,7 +329,11 @@ def test_get_dashboard(self): "changed_by_name": "", "changed_by_url": "", "charts": [], - "created_by": {"id": 1, "first_name": "admin", "last_name": "user",}, + "created_by": { + "id": 1, + "first_name": "admin", + "last_name": "user", + }, "id": dashboard.id, "css": "", "dashboard_title": "title", @@ -356,7 +360,10 @@ def test_get_dashboard(self): self.assertIn("changed_on_delta_humanized", data["result"]) for key, value in data["result"].items(): # We can't assert timestamp values - if key not in ("changed_on", "changed_on_delta_humanized",): + if key not in ( + "changed_on", + "changed_on_delta_humanized", + ): self.assertEqual(value, expected_result[key]) # rollback changes db.session.delete(dashboard) @@ -623,7 +630,13 @@ def test_get_dashboards_not_favorite_filter(self): @pytest.mark.usefixtures("create_dashboards") def test_gets_certified_dashboards_filter(self): arguments = { - "filters": [{"col": "id", "opr": "dashboard_is_certified", "value": True,}], + "filters": [ + { + "col": "id", + "opr": "dashboard_is_certified", + "value": True, + } + ], "keys": ["none"], "columns": ["dashboard_title"], } @@ -639,7 +652,11 @@ def test_gets_certified_dashboards_filter(self): def test_gets_not_certified_dashboards_filter(self): arguments = { "filters": [ - {"col": "id", "opr": "dashboard_is_certified", "value": False,} + { + "col": "id", + "opr": "dashboard_is_certified", + "value": False, + } ], "keys": ["none"], "columns": ["dashboard_title"], @@ -1135,7 +1152,12 @@ def test_update_dashboard_chart_owners(self): slices.append(db.session.query(Slice).filter_by(slice_name="Trends").first()) slices.append(db.session.query(Slice).filter_by(slice_name="Boys").first()) - dashboard = self.insert_dashboard("title1", "slug1", [admin.id], slices=slices,) + dashboard = self.insert_dashboard( + "title1", + "slug1", + [admin.id], + slices=slices, + ) self.login(username="admin") uri = f"api/v1/dashboard/{dashboard.id}" dashboard_data = {"owners": [user_alpha1.id, user_alpha2.id]} diff --git a/tests/integration_tests/dashboards/filter_sets/get_api_tests.py b/tests/integration_tests/dashboards/filter_sets/get_api_tests.py index 0d36a0a593e5b..7be6f367dd6b9 100644 --- a/tests/integration_tests/dashboards/filter_sets/get_api_tests.py +++ b/tests/integration_tests/dashboards/filter_sets/get_api_tests.py @@ -37,7 +37,9 @@ class TestGetFilterSetsApi: def test_with_dashboard_not_exists__404( - self, not_exists_dashboard_id: int, client: FlaskClient[Any], + self, + not_exists_dashboard_id: int, + client: FlaskClient[Any], ): # arrange login(client, "admin") diff --git a/tests/integration_tests/dashboards/permalink/api_tests.py b/tests/integration_tests/dashboards/permalink/api_tests.py index 7a8a2906bcdfb..33186131d559f 100644 --- a/tests/integration_tests/dashboards/permalink/api_tests.py +++ b/tests/integration_tests/dashboards/permalink/api_tests.py @@ -61,7 +61,8 @@ def permalink_salt() -> Iterator[str]: yield salt namespace = get_uuid_namespace(salt) 
db.session.query(KeyValueEntry).filter_by( - resource=KeyValueResource.APP, uuid=uuid3(namespace, key), + resource=KeyValueResource.APP, + uuid=uuid3(namespace, key), ) db.session.commit() diff --git a/tests/integration_tests/dashboards/security/security_dataset_tests.py b/tests/integration_tests/dashboards/security/security_dataset_tests.py index 01bce014afdc5..34a5fedad0bfb 100644 --- a/tests/integration_tests/dashboards/security/security_dataset_tests.py +++ b/tests/integration_tests/dashboards/security/security_dataset_tests.py @@ -91,11 +91,14 @@ def test_get_dashboards__users_are_dashboards_owners(self): username = "gamma" user = security_manager.find_user(username) my_owned_dashboard = create_dashboard_to_db( - dashboard_title="My Dashboard", published=False, owners=[user], + dashboard_title="My Dashboard", + published=False, + owners=[user], ) not_my_owned_dashboard = create_dashboard_to_db( - dashboard_title="Not My Dashboard", published=False, + dashboard_title="Not My Dashboard", + published=False, ) self.login(user.username) diff --git a/tests/integration_tests/dashboards/security/security_rbac_tests.py b/tests/integration_tests/dashboards/security/security_rbac_tests.py index 62b0e1b4c2755..d425c0e71118f 100644 --- a/tests/integration_tests/dashboards/security/security_rbac_tests.py +++ b/tests/integration_tests/dashboards/security/security_rbac_tests.py @@ -41,7 +41,8 @@ @mock.patch.dict( - "superset.extensions.feature_flag_manager._feature_flags", DASHBOARD_RBAC=True, + "superset.extensions.feature_flag_manager._feature_flags", + DASHBOARD_RBAC=True, ) class TestDashboardRoleBasedSecurity(BaseTestDashboardSecurity): def test_get_dashboard_view__admin_can_access(self): @@ -293,7 +294,10 @@ def test_get_dashboards_api__user_get_only_published_permitted_dashboards(self): # assert self.assert_dashboards_api_response( - response, len(published_dashboards), published_dashboards, draft_dashboards, + response, + len(published_dashboards), + published_dashboards, + draft_dashboards, ) # post @@ -337,7 +341,10 @@ def test_get_dashboards_api__public_user_get_only_published_permitted_dashboards # assert self.assert_dashboards_api_response( - response, len(published_dashboards), published_dashboards, draft_dashboards, + response, + len(published_dashboards), + published_dashboards, + draft_dashboards, ) # post diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 928f3d595730d..4f29600bdabb7 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -458,7 +458,8 @@ def test_create_database_uri_validate(self): response = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 400) self.assertIn( - "Invalid connection string", response["message"]["sqlalchemy_uri"][0], + "Invalid connection string", + response["message"]["sqlalchemy_uri"][0], ) @mock.patch( @@ -623,7 +624,8 @@ def test_update_database_uri_validate(self): response = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 400) self.assertIn( - "Invalid connection string", response["message"]["sqlalchemy_uri"][0], + "Invalid connection string", + response["message"]["sqlalchemy_uri"][0], ) db.session.delete(test_database) @@ -1038,7 +1040,9 @@ def test_test_connection_unsafe_uri(self): @mock.patch( "superset.databases.commands.test_connection.DatabaseDAO.build_db_for_connection_test", ) - @mock.patch("superset.databases.commands.test_connection.event_logger",) + @mock.patch( + 
"superset.databases.commands.test_connection.event_logger", + ) def test_test_connection_failed_invalid_hostname( self, mock_event_logger, mock_build_db ): @@ -1419,7 +1423,9 @@ def test_import_database_masked_password_provided(self): db.session.delete(database) db.session.commit() - @mock.patch("superset.db_engine_specs.base.BaseEngineSpec.get_function_names",) + @mock.patch( + "superset.db_engine_specs.base.BaseEngineSpec.get_function_names", + ) def test_function_names(self, mock_get_function_names): example_db = get_example_database() if example_db.backend in {"hive", "presto"}: diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index fe493a5504aed..83849931f4d76 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -107,7 +107,10 @@ def create_virtual_datasets(self): for table_name in self.fixture_virtual_table_names: datasets.append( self.insert_dataset( - table_name, [admin.id], main_db, "SELECT * from ab_view_menu;", + table_name, + [admin.id], + main_db, + "SELECT * from ab_view_menu;", ) ) yield datasets @@ -286,7 +289,10 @@ def pg_test_query_parameter(query_parameter, expected_response): ) datasets.append( self.insert_dataset( - "columns", [], get_main_database(), schema="information_schema", + "columns", + [], + get_main_database(), + schema="information_schema", ) ) schema_values = [ @@ -570,7 +576,12 @@ def test_create_dataset_validate_view_exists( """ mock_get_columns.return_value = [ - {"name": "col", "type": "VARCHAR", "type_generic": None, "is_dttm": None,} + { + "name": "col", + "type": "VARCHAR", + "type_generic": None, + "is_dttm": None, + } ] mock_has_table_by_name.return_value = False diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index eecc0501b3ffe..1428a20c48352 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -158,7 +158,11 @@ def test_external_metadata_by_name_from_sqla_inspector(self): # No databases found params = prison.dumps( - {"datasource_type": "table", "database_name": "foo", "table_name": "bar",} + { + "datasource_type": "table", + "database_name": "foo", + "table_name": "bar", + } ) url = f"/datasource/external_metadata_by_name/?q={params}" resp = self.client.get(url) @@ -185,7 +189,11 @@ def test_external_metadata_by_name_from_sqla_inspector(self): ) # invalid query params - params = prison.dumps({"datasource_type": "table",}) + params = prison.dumps( + { + "datasource_type": "table", + } + ) url = f"/datasource/external_metadata_by_name/?q={params}" resp = self.get_json_resp(url) self.assertIn("error", resp) diff --git a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py index d0e20ccdf516e..c4432e3ad1f93 100644 --- a/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py +++ b/tests/integration_tests/db_engine_specs/base_engine_spec_tests.py @@ -93,7 +93,9 @@ def test_limit_query_with_limit_subquery(self): # pylint: disable=invalid-name def test_limit_query_without_force(self): self.sql_limit_regex( - "SELECT * FROM a LIMIT 10", "SELECT * FROM a LIMIT 10", limit=11, + "SELECT * FROM a LIMIT 10", + "SELECT * FROM a LIMIT 10", + limit=11, ) def test_limit_query_with_force(self): @@ -399,7 +401,11 @@ def test_get_time_grain_with_unkown_values(): config = app.config.copy() app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] = { - "mysql": 
{"PT2H": "foo", "weird": "foo", "PT12H": "foo",} + "mysql": { + "PT2H": "foo", + "weird": "foo", + "PT12H": "foo", + } } with app.app_context(): diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py index b7405092c5446..549b1109529e9 100644 --- a/tests/integration_tests/db_engine_specs/bigquery_tests.py +++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py @@ -139,8 +139,14 @@ def test_extra_table_metadata(self): self.assertEqual(result, {}) index_metadata = [ - {"name": "clustering", "column_names": ["c_col1", "c_col2", "c_col3"],}, - {"name": "partition", "column_names": ["p_col1", "p_col2", "p_col3"],}, + { + "name": "clustering", + "column_names": ["c_col1", "c_col2", "c_col3"], + }, + { + "name": "partition", + "column_names": ["p_col1", "p_col2", "p_col3"], + }, ] expected_result = { "partitions": {"cols": [["p_col1", "p_col2", "p_col3"]]}, @@ -247,7 +253,12 @@ def test_extract_errors(self): level=ErrorLevel.ERROR, extra={ "engine_name": "Google BigQuery", - "issue_codes": [{"code": 1017, "message": "",}], + "issue_codes": [ + { + "code": 1017, + "message": "", + } + ], }, ) ] diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py index 7cc6b27fa7c48..ad80f8397ffe1 100644 --- a/tests/integration_tests/db_engine_specs/hive_tests.py +++ b/tests/integration_tests/db_engine_specs/hive_tests.py @@ -166,7 +166,10 @@ def test_convert_dttm(): def test_df_to_csv() -> None: with pytest.raises(SupersetException): HiveEngineSpec.df_to_sql( - mock.MagicMock(), Table("foobar"), pd.DataFrame(), {"if_exists": "append"}, + mock.MagicMock(), + Table("foobar"), + pd.DataFrame(), + {"if_exists": "append"}, ) diff --git a/tests/integration_tests/db_engine_specs/pinot_tests.py b/tests/integration_tests/db_engine_specs/pinot_tests.py index fa31efdb388d3..803dd67cbacfa 100644 --- a/tests/integration_tests/db_engine_specs/pinot_tests.py +++ b/tests/integration_tests/db_engine_specs/pinot_tests.py @@ -21,7 +21,7 @@ class TestPinotDbEngineSpec(TestDbEngineSpec): - """ Tests pertaining to our Pinot database support """ + """Tests pertaining to our Pinot database support""" def test_pinot_time_expression_sec_one_1d_grain(self): col = column("tstamp") @@ -62,7 +62,8 @@ def test_pinot_time_expression_sec_one_1m_grain(self): expr = PinotEngineSpec.get_timestamp_expr(col, "epoch_s", "P1M") result = str(expr.compile()) self.assertEqual( - result, "DATETRUNC('month', tstamp, 'SECONDS')", + result, + "DATETRUNC('month', tstamp, 'SECONDS')", ) def test_invalid_get_time_expression_arguments(self): diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index dcf5310fecac5..e6eb4fc1d13ea 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -179,7 +179,11 @@ def test_estimate_statement_cost_select_star(self): sql = "SELECT * FROM birth_names" results = PostgresEngineSpec.estimate_statement_cost(sql, cursor) self.assertEqual( - results, {"Start-up cost": 0.00, "Total cost": 1537.91,}, + results, + { + "Start-up cost": 0.00, + "Total cost": 1537.91, + }, ) def test_estimate_statement_invalid_syntax(self): @@ -205,15 +209,27 @@ def test_query_cost_formatter_example_costs(self): DB Eng Specs (postgres): Test test_query_cost_formatter example costs """ raw_cost = [ - {"Start-up cost": 0.00, "Total cost": 1537.91,}, - 
{"Start-up cost": 10.00, "Total cost": 1537.00,}, + { + "Start-up cost": 0.00, + "Total cost": 1537.91, + }, + { + "Start-up cost": 10.00, + "Total cost": 1537.00, + }, ] result = PostgresEngineSpec.query_cost_formatter(raw_cost) self.assertEqual( result, [ - {"Start-up cost": "0.0", "Total cost": "1537.91",}, - {"Start-up cost": "10.0", "Total cost": "1537.0",}, + { + "Start-up cost": "0.0", + "Total cost": "1537.91", + }, + { + "Start-up cost": "10.0", + "Total cost": "1537.0", + }, ], ) diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py index 5833c6bdcbfcb..558f4322a0e5d 100644 --- a/tests/integration_tests/db_engine_specs/presto_tests.py +++ b/tests/integration_tests/db_engine_specs/presto_tests.py @@ -293,7 +293,10 @@ def test_presto_expand_data_with_complex_row_columns(self): ) def test_presto_expand_data_with_complex_row_columns_and_null_values(self): cols = [ - {"name": "row_column", "type": "ROW(NESTED_ROW ROW(NESTED_OBJ VARCHAR))",} + { + "name": "row_column", + "type": "ROW(NESTED_ROW ROW(NESTED_OBJ VARCHAR))", + } ] data = [ {"row_column": '[["a"]]'}, @@ -305,7 +308,10 @@ def test_presto_expand_data_with_complex_row_columns_and_null_values(self): cols, data ) expected_cols = [ - {"name": "row_column", "type": "ROW(NESTED_ROW ROW(NESTED_OBJ VARCHAR))",}, + { + "name": "row_column", + "type": "ROW(NESTED_ROW ROW(NESTED_OBJ VARCHAR))", + }, {"name": "row_column.nested_row", "type": "ROW(NESTED_OBJ VARCHAR)"}, {"name": "row_column.nested_row.nested_obj", "type": "VARCHAR"}, ] @@ -786,7 +792,10 @@ def test_select_star_presto_expand_data( True, True, True, - [{"name": "val1"}, {"name": "val2 Iterator[str]: yield salt namespace = get_uuid_namespace(salt) db.session.query(KeyValueEntry).filter_by( - resource=KeyValueResource.APP, uuid=uuid3(namespace, key), + resource=KeyValueResource.APP, + uuid=uuid3(namespace, key), ) db.session.commit() diff --git a/tests/integration_tests/extensions/metastore_cache_test.py b/tests/integration_tests/extensions/metastore_cache_test.py index eb264c983f95b..d9e0e9ee26f5e 100644 --- a/tests/integration_tests/extensions/metastore_cache_test.py +++ b/tests/integration_tests/extensions/metastore_cache_test.py @@ -40,7 +40,8 @@ def cache() -> SupersetMetastoreCache: from superset.extensions.metastore_cache import SupersetMetastoreCache return SupersetMetastoreCache( - namespace=UUID("ee173d1b-ccf3-40aa-941c-985c15224496"), default_timeout=600, + namespace=UUID("ee173d1b-ccf3-40aa-941c-985c15224496"), + default_timeout=600, ) diff --git a/tests/integration_tests/fixtures/birth_names_dashboard.py b/tests/integration_tests/fixtures/birth_names_dashboard.py index 685cf43f581fa..ef71803aa5db7 100644 --- a/tests/integration_tests/fixtures/birth_names_dashboard.py +++ b/tests/integration_tests/fixtures/birth_names_dashboard.py @@ -76,7 +76,9 @@ def _create_dashboards(): def _create_table( - table_name: str, database: "Database", fetch_values_predicate: Optional[str] = None, + table_name: str, + database: "Database", + fetch_values_predicate: Optional[str] = None, ): table = create_table_metadata( table_name=table_name, diff --git a/tests/integration_tests/fixtures/importexport.py b/tests/integration_tests/fixtures/importexport.py index 996c77f12e0f2..18bec4f17995b 100644 --- a/tests/integration_tests/fixtures/importexport.py +++ b/tests/integration_tests/fixtures/importexport.py @@ -492,7 +492,9 @@ }, "metadata": { "timed_refresh_immune_slices": [83], - "filter_scopes": {"83": {"region": 
{"scope": ["ROOT_ID"], "immune": [83]}},}, + "filter_scopes": { + "83": {"region": {"scope": ["ROOT_ID"], "immune": [83]}}, + }, "expanded_slices": {"83": True}, "refresh_frequency": 0, "default_filters": "{}", diff --git a/tests/integration_tests/form_tests.py b/tests/integration_tests/form_tests.py index b15dfdea79a54..078a9866ee975 100644 --- a/tests/integration_tests/form_tests.py +++ b/tests/integration_tests/form_tests.py @@ -23,11 +23,11 @@ class TestForm(SupersetTestCase): def test_comma_separated_list_field(self): field = CommaSeparatedListField().bind(Form(), "foo") - field.process_formdata([u""]) - self.assertEqual(field.data, [u""]) + field.process_formdata([""]) + self.assertEqual(field.data, [""]) field.process_formdata(["a,comma,separated,list"]) - self.assertEqual(field.data, [u"a", u"comma", u"separated", u"list"]) + self.assertEqual(field.data, ["a", "comma", "separated", "list"]) def test_filter_not_empty_values(self): self.assertEqual(filter_not_empty_values(None), None) diff --git a/tests/integration_tests/import_export_tests.py b/tests/integration_tests/import_export_tests.py index f34cf621506d2..67d2a89d866ff 100644 --- a/tests/integration_tests/import_export_tests.py +++ b/tests/integration_tests/import_export_tests.py @@ -287,8 +287,10 @@ def test_export_2_dashboards(self): self.login("admin") birth_dash = self.get_dash_by_slug("births") world_health_dash = self.get_dash_by_slug("world_health") - export_dash_url = "/dashboard/export_dashboards_form?id={}&id={}&action=go".format( - birth_dash.id, world_health_dash.id + export_dash_url = ( + "/dashboard/export_dashboards_form?id={}&id={}&action=go".format( + birth_dash.id, world_health_dash.id + ) ) resp = self.client.get(export_dash_url) resp_data = json.loads(resp.data.decode("utf-8"), object_hook=decode_dashboards) diff --git a/tests/integration_tests/key_value/commands/delete_test.py b/tests/integration_tests/key_value/commands/delete_test.py index 67623461246f6..62f9883370cf1 100644 --- a/tests/integration_tests/key_value/commands/delete_test.py +++ b/tests/integration_tests/key_value/commands/delete_test.py @@ -39,7 +39,10 @@ def key_value_entry() -> KeyValueEntry: from superset.key_value.models import KeyValueEntry entry = KeyValueEntry( - id=ID_KEY, uuid=UUID_KEY, resource=RESOURCE, value=pickle.dumps(VALUE), + id=ID_KEY, + uuid=UUID_KEY, + resource=RESOURCE, + value=pickle.dumps(VALUE), ) db.session.add(entry) db.session.commit() @@ -47,7 +50,9 @@ def key_value_entry() -> KeyValueEntry: def test_delete_id_entry( - app_context: AppContext, admin: User, key_value_entry: KeyValueEntry, + app_context: AppContext, + admin: User, + key_value_entry: KeyValueEntry, ) -> None: from superset.key_value.commands.delete import DeleteKeyValueCommand from superset.key_value.models import KeyValueEntry @@ -56,7 +61,9 @@ def test_delete_id_entry( def test_delete_uuid_entry( - app_context: AppContext, admin: User, key_value_entry: KeyValueEntry, + app_context: AppContext, + admin: User, + key_value_entry: KeyValueEntry, ) -> None: from superset.key_value.commands.delete import DeleteKeyValueCommand from superset.key_value.models import KeyValueEntry @@ -65,7 +72,9 @@ def test_delete_uuid_entry( def test_delete_entry_missing( - app_context: AppContext, admin: User, key_value_entry: KeyValueEntry, + app_context: AppContext, + admin: User, + key_value_entry: KeyValueEntry, ) -> None: from superset.key_value.commands.delete import DeleteKeyValueCommand from superset.key_value.models import KeyValueEntry diff --git 
a/tests/integration_tests/key_value/commands/fixtures.py b/tests/integration_tests/key_value/commands/fixtures.py
index de77a6c46badb..2fd4fde4e1dc3 100644
--- a/tests/integration_tests/key_value/commands/fixtures.py
+++ b/tests/integration_tests/key_value/commands/fixtures.py
@@ -43,7 +43,10 @@ def key_value_entry() -> Generator[KeyValueEntry, None, None]:
     from superset.key_value.models import KeyValueEntry
 
     entry = KeyValueEntry(
-        id=ID_KEY, uuid=UUID_KEY, resource=RESOURCE, value=pickle.dumps(VALUE),
+        id=ID_KEY,
+        uuid=UUID_KEY,
+        resource=RESOURCE,
+        value=pickle.dumps(VALUE),
     )
     db.session.add(entry)
     db.session.commit()
diff --git a/tests/integration_tests/key_value/commands/get_test.py b/tests/integration_tests/key_value/commands/get_test.py
index c2c85e987534f..b1800a4c3b9a3 100644
--- a/tests/integration_tests/key_value/commands/get_test.py
+++ b/tests/integration_tests/key_value/commands/get_test.py
@@ -53,7 +53,8 @@ def test_get_uuid_entry(
 
 
 def test_get_id_entry_missing(
-    app_context: AppContext, key_value_entry: KeyValueEntry,
+    app_context: AppContext,
+    key_value_entry: KeyValueEntry,
 ) -> None:
     from superset.key_value.commands.get import GetKeyValueCommand
 
diff --git a/tests/integration_tests/key_value/commands/update_test.py b/tests/integration_tests/key_value/commands/update_test.py
index 36de8972a0c72..62a8126ba2ac6 100644
--- a/tests/integration_tests/key_value/commands/update_test.py
+++ b/tests/integration_tests/key_value/commands/update_test.py
@@ -40,13 +40,18 @@
 
 
 def test_update_id_entry(
-    app_context: AppContext, admin: User, key_value_entry: KeyValueEntry,
+    app_context: AppContext,
+    admin: User,
+    key_value_entry: KeyValueEntry,
 ) -> None:
     from superset.key_value.commands.update import UpdateKeyValueCommand
     from superset.key_value.models import KeyValueEntry
 
     key = UpdateKeyValueCommand(
-        actor=admin, resource=RESOURCE, key=ID_KEY, value=NEW_VALUE,
+        actor=admin,
+        resource=RESOURCE,
+        key=ID_KEY,
+        value=NEW_VALUE,
     ).run()
     assert key.id == ID_KEY
     entry = db.session.query(KeyValueEntry).filter_by(id=ID_KEY).autoflush(False).one()
@@ -55,13 +60,18 @@ def test_update_id_entry(
 
 
 def test_update_uuid_entry(
-    app_context: AppContext, admin: User, key_value_entry: KeyValueEntry,
+    app_context: AppContext,
+    admin: User,
+    key_value_entry: KeyValueEntry,
 ) -> None:
     from superset.key_value.commands.update import UpdateKeyValueCommand
     from superset.key_value.models import KeyValueEntry
 
     key = UpdateKeyValueCommand(
-        actor=admin, resource=RESOURCE, key=UUID_KEY, value=NEW_VALUE,
+        actor=admin,
+        resource=RESOURCE,
+        key=UUID_KEY,
+        value=NEW_VALUE,
     ).run()
     assert key.uuid == UUID_KEY
     entry = (
@@ -75,6 +85,9 @@ def test_update_missing_entry(app_context: AppContext, admin: User) -> None:
     from superset.key_value.commands.update import UpdateKeyValueCommand
 
     key = UpdateKeyValueCommand(
-        actor=admin, resource=RESOURCE, key=456, value=NEW_VALUE,
+        actor=admin,
+        resource=RESOURCE,
+        key=456,
+        value=NEW_VALUE,
     ).run()
     assert key is None
diff --git a/tests/integration_tests/key_value/commands/upsert_test.py b/tests/integration_tests/key_value/commands/upsert_test.py
index 8038614ce5aa9..adb652e66a195 100644
--- a/tests/integration_tests/key_value/commands/upsert_test.py
+++ b/tests/integration_tests/key_value/commands/upsert_test.py
@@ -40,13 +40,18 @@
 
 
 def test_upsert_id_entry(
-    app_context: AppContext, admin: User, key_value_entry: KeyValueEntry,
+    app_context: AppContext,
+    admin: User,
+    key_value_entry: KeyValueEntry,
 ) -> None:
     from superset.key_value.commands.upsert 
import UpsertKeyValueCommand
     from superset.key_value.models import KeyValueEntry
 
     key = UpsertKeyValueCommand(
-        actor=admin, resource=RESOURCE, key=ID_KEY, value=NEW_VALUE,
+        actor=admin,
+        resource=RESOURCE,
+        key=ID_KEY,
+        value=NEW_VALUE,
     ).run()
     assert key.id == ID_KEY
     entry = (
@@ -57,13 +62,18 @@ def test_upsert_id_entry(
 
 
 def test_upsert_uuid_entry(
-    app_context: AppContext, admin: User, key_value_entry: KeyValueEntry,
+    app_context: AppContext,
+    admin: User,
+    key_value_entry: KeyValueEntry,
 ) -> None:
     from superset.key_value.commands.upsert import UpsertKeyValueCommand
     from superset.key_value.models import KeyValueEntry
 
     key = UpsertKeyValueCommand(
-        actor=admin, resource=RESOURCE, key=UUID_KEY, value=NEW_VALUE,
+        actor=admin,
+        resource=RESOURCE,
+        key=UUID_KEY,
+        value=NEW_VALUE,
     ).run()
     assert key.uuid == UUID_KEY
     entry = (
@@ -78,7 +88,10 @@ def test_upsert_missing_entry(app_context: AppContext, admin: User) -> None:
     from superset.key_value.models import KeyValueEntry
 
     key = UpsertKeyValueCommand(
-        actor=admin, resource=RESOURCE, key=456, value=NEW_VALUE,
+        actor=admin,
+        resource=RESOURCE,
+        key=456,
+        value=NEW_VALUE,
     ).run()
     assert key.id == 456
     db.session.query(KeyValueEntry).filter_by(id=456).delete()
diff --git a/tests/integration_tests/log_api_tests.py b/tests/integration_tests/log_api_tests.py
index e5f78c754146e..089ac07921c5a 100644
--- a/tests/integration_tests/log_api_tests.py
+++ b/tests/integration_tests/log_api_tests.py
@@ -77,7 +77,7 @@ def test_not_enabled(self):
 
     def test_get_list(self):
         """
-        Log API: Test get list
+        Log API: Test get list
         """
         admin_user = self.get_user("admin")
         log = self.insert_log("some_action", admin_user)
@@ -95,7 +95,7 @@ def test_get_list(self):
 
     def test_get_list_not_allowed(self):
         """
-        Log API: Test get list
+        Log API: Test get list
         """
         admin_user = self.get_user("admin")
         log = self.insert_log("action", admin_user)
@@ -109,7 +109,7 @@ def test_get_list_not_allowed(self):
 
     def test_get_item(self):
         """
-        Log API: Test get item
+        Log API: Test get item
         """
         admin_user = self.get_user("admin")
         log = self.insert_log("some_action", admin_user)
@@ -127,7 +127,7 @@ def test_get_item(self):
 
     def test_delete_log(self):
         """
-        Log API: Test delete (does not exist)
+        Log API: Test delete (does not exist)
         """
         admin_user = self.get_user("admin")
         log = self.insert_log("action", admin_user)
@@ -140,7 +140,7 @@ def test_delete_log(self):
 
     def test_update_log(self):
         """
-        Log API: Test update (does not exist)
+        Log API: Test update (does not exist)
         """
         admin_user = self.get_user("admin")
         log = self.insert_log("action", admin_user)
diff --git a/tests/integration_tests/migrations/f1410ed7ec95_tests.py b/tests/integration_tests/migrations/f1410ed7ec95_tests.py
index 2b48b56762b05..c60d0a74beb5c 100644
--- a/tests/integration_tests/migrations/f1410ed7ec95_tests.py
+++ b/tests/integration_tests/migrations/f1410ed7ec95_tests.py
@@ -48,7 +48,11 @@
         {
             "filterType": "filter_select",
             "cascadingFilters": True,
-            "defaultDataMask": {"filterState": {"value": ["Albania", "Algeria"],},},
+            "defaultDataMask": {
+                "filterState": {
+                    "value": ["Albania", "Algeria"],
+                },
+            },
         }
     ],
     "filter_sets_configuration": [
@@ -58,7 +62,9 @@ {
                 "filterType": "filter_select",
                 "cascadingFilters": True,
                 "defaultDataMask": {
-                    "filterState": {"value": ["Albania", "Algeria"],},
+                    "filterState": {
+                        "value": ["Albania", "Algeria"],
+                    },
                 },
             },
         },
diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py
index c6388601354d4..8f90d46a9aa9b 100644
--- 
a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -355,13 +355,15 @@ def test_username_param(self, mocked_get_sqla_engine): if main_db.backend == "mysql": main_db.get_df("USE superset; SELECT 1", username=test_username) mocked_get_sqla_engine.assert_called_with( - schema=None, user_name="test_username_param", + schema=None, + user_name="test_username_param", ) @mock.patch("superset.models.core.create_engine") def test_get_sqla_engine(self, mocked_create_engine): model = Database( - database_name="test_database", sqlalchemy_uri="mysql://root@localhost", + database_name="test_database", + sqlalchemy_uri="mysql://root@localhost", ) model.db_engine_spec.get_dbapi_exception_mapping = mock.Mock( return_value={Exception: SupersetException} @@ -568,7 +570,9 @@ def test_data_for_slices_with_no_query_context(self): slc = ( metadata_db.session.query(Slice) .filter_by( - datasource_id=tbl.id, datasource_type=tbl.type, slice_name="Genders", + datasource_id=tbl.id, + datasource_type=tbl.type, + slice_name="Genders", ) .first() ) diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index b2f28186ada32..816267678f9e0 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -91,7 +91,8 @@ def test_schema_deserialization(self): def test_cache(self): table_name = "birth_names" payload = get_query_context( - query_name=table_name, add_postprocessing_operations=True, + query_name=table_name, + add_postprocessing_operations=True, ) payload["force"] = True @@ -443,12 +444,16 @@ def test_handle_sort_by_metrics(self): else: # Should reference the adhoc metric by alias when possible assert re.search( - r'ORDER BY [`"\[]?num_girls[`"\]]? DESC', sql_text, re.IGNORECASE, + r'ORDER BY [`"\[]?num_girls[`"\]]? 
DESC', + sql_text, + re.IGNORECASE, ) # ORDER BY only columns should always be expressions assert re.search( - r'AVG\([`"\[]?num_boys[`"\]]?\) DESC', sql_text, re.IGNORECASE, + r'AVG\([`"\[]?num_boys[`"\]]?\) DESC', + sql_text, + re.IGNORECASE, ) assert re.search( r"MAX\(CASE.*END\) ASC", sql_text, re.IGNORECASE | re.DOTALL @@ -573,7 +578,10 @@ def test_processing_time_offsets_cache(self): payload["queries"][0]["time_offsets"] = [] query_context = ChartDataQueryContextSchema().load(payload) query_object = query_context.queries[0] - rv = query_context.processing_time_offsets(df, query_object,) + rv = query_context.processing_time_offsets( + df, + query_object, + ) self.assertIs(rv["df"], df) self.assertEqual(rv["queries"], []) self.assertEqual(rv["cache_keys"], []) diff --git a/tests/integration_tests/reports/commands_tests.py b/tests/integration_tests/reports/commands_tests.py index f914941e919e9..a334854d58de3 100644 --- a/tests/integration_tests/reports/commands_tests.py +++ b/tests/integration_tests/reports/commands_tests.py @@ -308,7 +308,12 @@ def create_report_email_tabbed_dashboard(tabbed_dashboard): report_schedule = create_report_notification( email_target="target@email.com", dashboard=tabbed_dashboard, - extra={"dashboard_tab_ids": ["TAB-j53G4gtKGF", "TAB-nerWR09Ju",]}, + extra={ + "dashboard_tab_ids": [ + "TAB-j53G4gtKGF", + "TAB-nerWR09Ju", + ] + }, ) yield report_schedule cleanup_report_schedule(report_schedule) @@ -409,7 +414,9 @@ def create_alert_slack_chart_success(): @pytest.fixture( - params=["alert1",] + params=[ + "alert1", + ] ) def create_alert_slack_chart_grace(request): param_config = { @@ -687,7 +694,9 @@ def create_invalid_sql_alert_email_chart(request): @patch("superset.reports.notifications.email.send_email_smtp") @patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") def test_email_chart_report_schedule( - screenshot_mock, email_mock, create_report_email_chart, + screenshot_mock, + email_mock, + create_report_email_chart, ): """ ExecuteReport Command: Test chart email report schedule with screenshot @@ -727,7 +736,9 @@ def test_email_chart_report_schedule( @patch("superset.reports.notifications.email.send_email_smtp") @patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") def test_email_chart_report_schedule_force_screenshot( - screenshot_mock, email_mock, create_report_email_chart_force_screenshot, + screenshot_mock, + email_mock, + create_report_email_chart_force_screenshot, ): """ ExecuteReport Command: Test chart email report schedule with screenshot @@ -769,7 +780,9 @@ def test_email_chart_report_schedule_force_screenshot( @patch("superset.reports.notifications.email.send_email_smtp") @patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") def test_email_chart_alert_schedule( - screenshot_mock, email_mock, create_alert_email_chart, + screenshot_mock, + email_mock, + create_alert_email_chart, ): """ ExecuteReport Command: Test chart email alert schedule with screenshot @@ -806,7 +819,9 @@ def test_email_chart_alert_schedule( @patch("superset.reports.notifications.email.send_email_smtp") @patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") def test_email_chart_report_dry_run( - screenshot_mock, email_mock, create_report_email_chart, + screenshot_mock, + email_mock, + create_report_email_chart, ): """ ExecuteReport Command: Test chart email report schedule dry run @@ -831,7 +846,11 @@ def test_email_chart_report_dry_run( @patch("superset.reports.notifications.email.send_email_smtp") 
@patch("superset.utils.csv.get_chart_csv_data") def test_email_chart_report_schedule_with_csv( - csv_mock, email_mock, mock_open, mock_urlopen, create_report_email_chart_with_csv, + csv_mock, + email_mock, + mock_open, + mock_urlopen, + create_report_email_chart_with_csv, ): """ ExecuteReport Command: Test chart email report schedule with CSV @@ -1055,7 +1074,9 @@ def test_email_dashboard_report_schedule_force_screenshot( @patch("superset.reports.notifications.slack.WebClient.files_upload") @patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") def test_slack_chart_report_schedule( - screenshot_mock, file_upload_mock, create_report_slack_chart, + screenshot_mock, + file_upload_mock, + create_report_slack_chart, ): """ ExecuteReport Command: Test chart slack report schedule @@ -1279,7 +1300,9 @@ def test_report_schedule_success_grace_end( @patch("superset.reports.notifications.email.send_email_smtp") @patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") def test_alert_limit_is_applied( - screenshot_mock, email_mock, create_alert_email_chart, + screenshot_mock, + email_mock, + create_alert_email_chart, ): """ ExecuteReport Command: Test that all alerts apply a SQL limit to stmts @@ -1335,7 +1358,9 @@ def test_email_dashboard_report_fails( ALERTS_ATTACH_REPORTS=True, ) def test_slack_chart_alert( - screenshot_mock, email_mock, create_alert_email_chart, + screenshot_mock, + email_mock, + create_alert_email_chart, ): """ ExecuteReport Command: Test chart slack alert @@ -1392,7 +1417,9 @@ def test_slack_chart_alert_no_attachment(email_mock, create_alert_email_chart): @patch("superset.reports.notifications.slack.WebClient") @patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") def test_slack_token_callable_chart_report( - screenshot_mock, slack_client_mock_class, create_report_slack_chart, + screenshot_mock, + slack_client_mock_class, + create_report_slack_chart, ): """ ExecuteReport Command: Test chart slack alert (slack token callable) @@ -1504,7 +1531,11 @@ def test_soft_timeout_screenshot(screenshot_mock, email_mock, create_alert_email @patch("superset.reports.notifications.email.send_email_smtp") @patch("superset.utils.csv.get_chart_csv_data") def test_soft_timeout_csv( - csv_mock, email_mock, mock_open, mock_urlopen, create_report_email_chart_with_csv, + csv_mock, + email_mock, + mock_open, + mock_urlopen, + create_report_email_chart_with_csv, ): """ ExecuteReport Command: Test fail on generating csv @@ -1528,7 +1559,8 @@ def test_soft_timeout_csv( assert email_mock.call_args[0][0] == OWNER_EMAIL assert_log( - ReportState.ERROR, error_message="A timeout occurred while generating a csv.", + ReportState.ERROR, + error_message="A timeout occurred while generating a csv.", ) @@ -1540,7 +1572,11 @@ def test_soft_timeout_csv( @patch("superset.reports.notifications.email.send_email_smtp") @patch("superset.utils.csv.get_chart_csv_data") def test_generate_no_csv( - csv_mock, email_mock, mock_open, mock_urlopen, create_report_email_chart_with_csv, + csv_mock, + email_mock, + mock_open, + mock_urlopen, + create_report_email_chart_with_csv, ): """ ExecuteReport Command: Test fail on generating csv @@ -1723,7 +1759,9 @@ def test_grace_period_error(email_mock, create_invalid_sql_alert_email_chart): @patch("superset.reports.notifications.email.send_email_smtp") @patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") def test_grace_period_error_flap( - screenshot_mock, email_mock, create_invalid_sql_alert_email_chart, + screenshot_mock, + email_mock, + 
create_invalid_sql_alert_email_chart, ): """ ExecuteReport Command: Test alert grace period on error @@ -1800,9 +1838,13 @@ def test_prune_log_soft_time_out(bulk_delete_logs, create_report_email_dashboard assert str(excinfo.value) == "SoftTimeLimitExceeded()" -@pytest.mark.usefixtures("create_report_email_tabbed_dashboard",) +@pytest.mark.usefixtures( + "create_report_email_tabbed_dashboard", +) @patch("superset.reports.notifications.email.send_email_smtp") -@patch("superset.reports.commands.execute.DashboardScreenshot",) +@patch( + "superset.reports.commands.execute.DashboardScreenshot", +) def test_when_tabs_are_selected_it_takes_screenshots_for_every_tabs( dashboard_screenshot_mock, send_email_smtp_mock, diff --git a/tests/integration_tests/security/guest_token_security_tests.py b/tests/integration_tests/security/guest_token_security_tests.py index 9ca34198dbdf2..e4d55d9747b98 100644 --- a/tests/integration_tests/security/guest_token_security_tests.py +++ b/tests/integration_tests/security/guest_token_security_tests.py @@ -34,7 +34,8 @@ @mock.patch.dict( - "superset.extensions.feature_flag_manager._feature_flags", EMBEDDED_SUPERSET=True, + "superset.extensions.feature_flag_manager._feature_flags", + EMBEDDED_SUPERSET=True, ) class TestGuestUserSecurity(SupersetTestCase): # This test doesn't use a dashboard fixture, the next test does. @@ -150,7 +151,8 @@ def test_get_guest_user_roles_implicit(self): @mock.patch.dict( - "superset.extensions.feature_flag_manager._feature_flags", EMBEDDED_SUPERSET=True, + "superset.extensions.feature_flag_manager._feature_flags", + EMBEDDED_SUPERSET=True, ) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") class TestGuestUserDashboardAccess(SupersetTestCase): diff --git a/tests/integration_tests/security/migrate_roles_tests.py b/tests/integration_tests/security/migrate_roles_tests.py index 2f88cfa9b49af..ad8d01691a599 100644 --- a/tests/integration_tests/security/migrate_roles_tests.py +++ b/tests/integration_tests/security/migrate_roles_tests.py @@ -102,7 +102,12 @@ def create_old_role(pvm_map: PvmMigrationMapType, external_pvms): ), ( "Many to one with multiple permissions", - {"NewDummy": ("can_read", "can_write",)}, + { + "NewDummy": ( + "can_read", + "can_write", + ) + }, { Pvm("DummyView", "can_list"): (Pvm("NewDummy", "can_read"),), Pvm("DummyView", "can_show"): (Pvm("NewDummy", "can_read"),), @@ -115,7 +120,12 @@ def create_old_role(pvm_map: PvmMigrationMapType, external_pvms): ), ( "Many to one with multiple views", - {"NewDummy": ("can_read", "can_write",)}, + { + "NewDummy": ( + "can_read", + "can_write", + ) + }, { Pvm("DummyView", "can_list"): (Pvm("NewDummy", "can_read"),), Pvm("DummyView", "can_show"): (Pvm("NewDummy", "can_read"),), @@ -132,7 +142,12 @@ def create_old_role(pvm_map: PvmMigrationMapType, external_pvms): ), ( "Many to one with existing permission-view (pvm)", - {"NewDummy": ("can_read", "can_write",)}, + { + "NewDummy": ( + "can_read", + "can_write", + ) + }, { Pvm("DummyView", "can_list"): (Pvm("NewDummy", "can_read"),), Pvm("DummyView", "can_add"): (Pvm("NewDummy", "can_write"),), @@ -143,20 +158,33 @@ def create_old_role(pvm_map: PvmMigrationMapType, external_pvms): ), ( "Many to one with existing multiple permission-view (pvm)", - {"NewDummy": ("can_read", "can_write",)}, + { + "NewDummy": ( + "can_read", + "can_write", + ) + }, { Pvm("DummyView", "can_list"): (Pvm("NewDummy", "can_read"),), Pvm("DummyView", "can_add"): (Pvm("NewDummy", "can_write"),), Pvm("DummySecondView", "can_list"): (Pvm("NewDummy", 
"can_read"),), Pvm("DummySecondView", "can_add"): (Pvm("NewDummy", "can_write"),), }, - (Pvm("UserDBModelView", "can_list"), Pvm("UserDBModelView", "can_add"),), + ( + Pvm("UserDBModelView", "can_list"), + Pvm("UserDBModelView", "can_add"), + ), ("DummyView",), (), ), ( "Many to one with with old permission that gets deleted", - {"NewDummy": ("can_read", "can_write",)}, + { + "NewDummy": ( + "can_read", + "can_write", + ) + }, { Pvm("DummyView", "can_new_perm"): (Pvm("NewDummy", "can_read"),), Pvm("DummyView", "can_add"): (Pvm("NewDummy", "can_write"),), @@ -167,7 +195,13 @@ def create_old_role(pvm_map: PvmMigrationMapType, external_pvms): ), ( "Many to Many (normally should be a downgrade)", - {"DummyView": ("can_list", "can_show", "can_add",)}, + { + "DummyView": ( + "can_list", + "can_show", + "can_add", + ) + }, { Pvm("NewDummy", "can_read"): ( Pvm("DummyView", "can_list"), @@ -181,13 +215,22 @@ def create_old_role(pvm_map: PvmMigrationMapType, external_pvms): ), ( "Many to Many delete old permissions", - {"DummyView": ("can_list", "can_show", "can_add",)}, + { + "DummyView": ( + "can_list", + "can_show", + "can_add", + ) + }, { Pvm("NewDummy", "can_new_perm1"): ( Pvm("DummyView", "can_list"), Pvm("DummyView", "can_show"), ), - Pvm("NewDummy", "can_new_perm2",): (Pvm("DummyView", "can_add"),), + Pvm( + "NewDummy", + "can_new_perm2", + ): (Pvm("DummyView", "can_add"),), }, (), ("NewDummy",), diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py index 94369c41ca22b..1e46bfb996c5b 100644 --- a/tests/integration_tests/security/row_level_security_tests.py +++ b/tests/integration_tests/security/row_level_security_tests.py @@ -218,7 +218,8 @@ def test_rls_filter_doesnt_alter_admin_birth_names_query(self): @mock.patch.dict( - "superset.extensions.feature_flag_manager._feature_flags", EMBEDDED_SUPERSET=True, + "superset.extensions.feature_flag_manager._feature_flags", + EMBEDDED_SUPERSET=True, ) class GuestTokenRowLevelSecurityTests(SupersetTestCase): query_obj: Dict[str, Any] = dict( diff --git a/tests/integration_tests/sql_validator_tests.py b/tests/integration_tests/sql_validator_tests.py index 3468b8c047fb7..b1e661cc2c5bb 100644 --- a/tests/integration_tests/sql_validator_tests.py +++ b/tests/integration_tests/sql_validator_tests.py @@ -111,7 +111,10 @@ def test_validate_sql_endpoint_mocked_params(self, get_validator_by_name): get_validator_by_name.return_value = validator validator.validate.return_value = [ SQLValidationAnnotation( - message="This worked", line_number=4, start_column=12, end_column=42, + message="This worked", + line_number=4, + start_column=12, + end_column=42, ) ] diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 223d48a4899a7..bbe062e509ba9 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -598,7 +598,13 @@ def test_filter_on_text_column(text_column_table): result_object = table.query( { "metrics": ["count"], - "filter": [{"col": "foo", "val": ['"text in double quotes"'], "op": "IN",}], + "filter": [ + { + "col": "foo", + "val": ['"text in double quotes"'], + "op": "IN", + } + ], "is_timeseries": False, } ) @@ -608,7 +614,13 @@ def test_filter_on_text_column(text_column_table): result_object = table.query( { "metrics": ["count"], - "filter": [{"col": "foo", "val": ["'text in single quotes'"], "op": "IN",}], + "filter": [ + { + "col": "foo", + "val": ["'text in single 
quotes'"], + "op": "IN", + } + ], "is_timeseries": False, } ) @@ -618,7 +630,13 @@ def test_filter_on_text_column(text_column_table): result_object = table.query( { "metrics": ["count"], - "filter": [{"col": "foo", "val": ['double quotes " in text'], "op": "IN",}], + "filter": [ + { + "col": "foo", + "val": ['double quotes " in text'], + "op": "IN", + } + ], "is_timeseries": False, } ) @@ -628,7 +646,13 @@ def test_filter_on_text_column(text_column_table): result_object = table.query( { "metrics": ["count"], - "filter": [{"col": "foo", "val": ["single quotes ' in text"], "op": "IN",}], + "filter": [ + { + "col": "foo", + "val": ["single quotes ' in text"], + "op": "IN", + } + ], "is_timeseries": False, } ) @@ -652,7 +676,10 @@ def test_should_generate_closed_and_open_time_filter_range(): database=get_example_database(), ) TableColumn( - column_name="datetime_col", type="TIMESTAMP", table=table, is_dttm=True, + column_name="datetime_col", + type="TIMESTAMP", + table=table, + is_dttm=True, ) SqlMetric(metric_name="count", expression="count(*)", table=table) result_object = table.query( @@ -719,26 +746,48 @@ def _convert_dttm( columns_by_name = { "foo": TableColumn( - column_name="foo", is_dttm=False, table=table, type="STRING", + column_name="foo", + is_dttm=False, + table=table, + type="STRING", ), "bar": TableColumn( - column_name="bar", is_dttm=False, table=table, type="BOOLEAN", + column_name="bar", + is_dttm=False, + table=table, + type="BOOLEAN", ), "baz": TableColumn( - column_name="baz", is_dttm=False, table=table, type="INTEGER", + column_name="baz", + is_dttm=False, + table=table, + type="INTEGER", ), "qux": TableColumn( - column_name="qux", is_dttm=False, table=table, type="FLOAT", + column_name="qux", + is_dttm=False, + table=table, + type="FLOAT", ), "quux": TableColumn( - column_name="quuz", is_dttm=True, table=table, type="STRING", + column_name="quuz", + is_dttm=True, + table=table, + type="STRING", ), "quuz": TableColumn( - column_name="quux", is_dttm=True, table=table, type="TIMESTAMP", + column_name="quux", + is_dttm=True, + table=table, + type="TIMESTAMP", ), } - normalized = table._normalize_prequery_result_type(row, dimension, columns_by_name,) + normalized = table._normalize_prequery_result_type( + row, + dimension, + columns_by_name, + ) assert type(normalized) == type(result) diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index 9028e589252c1..5a98ddebfa945 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -260,8 +260,10 @@ def test_sql_json_schema_access(self): # sqlite doesn't support database creation return - sqllab_test_db_schema_permission_view = security_manager.add_permission_view_menu( - "schema_access", f"[{examples_db.name}].[{CTAS_SCHEMA_NAME}]" + sqllab_test_db_schema_permission_view = ( + security_manager.add_permission_view_menu( + "schema_access", f"[{examples_db.name}].[{CTAS_SCHEMA_NAME}]" + ) ) schema_perm_role = security_manager.add_role("SchemaPermission") security_manager.add_permission_role( @@ -587,13 +589,17 @@ def test_sql_limit(self): ) data = self.run_sql( - "SELECT * FROM birth_names", client_id="sql_limit_6", query_limit=10000, + "SELECT * FROM birth_names", + client_id="sql_limit_6", + query_limit=10000, ) self.assertEqual(len(data["data"]), 1200) self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.NOT_LIMITED) data = self.run_sql( - "SELECT * FROM birth_names", client_id="sql_limit_7", query_limit=1200, + "SELECT * FROM 
birth_names", + client_id="sql_limit_7", + query_limit=1200, ) self.assertEqual(len(data["data"]), 1200) self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.NOT_LIMITED) diff --git a/tests/integration_tests/tasks/async_queries_tests.py b/tests/integration_tests/tasks/async_queries_tests.py index 16299bebea37c..596505a32a485 100644 --- a/tests/integration_tests/tasks/async_queries_tests.py +++ b/tests/integration_tests/tasks/async_queries_tests.py @@ -113,7 +113,8 @@ def test_soft_timeout_load_chart_data_into_cache( with pytest.raises(SoftTimeLimitExceeded): with mock.patch.object( - async_queries, "ensure_user_is_set", + async_queries, + "ensure_user_is_set", ) as ensure_user_is_set: ensure_user_is_set.side_effect = SoftTimeLimitExceeded() load_chart_data_into_cache(job_metadata, form_data) @@ -199,7 +200,8 @@ def test_soft_timeout_load_explore_json_into_cache( with pytest.raises(SoftTimeLimitExceeded): with mock.patch.object( - async_queries, "ensure_user_is_set", + async_queries, + "ensure_user_is_set", ) as ensure_user_is_set: ensure_user_is_set.side_effect = SoftTimeLimitExceeded() load_explore_json_into_cache(job_metadata, form_data) diff --git a/tests/integration_tests/utils_tests.py b/tests/integration_tests/utils_tests.py index b402d82f1fc22..765f586ced6c5 100644 --- a/tests/integration_tests/utils_tests.py +++ b/tests/integration_tests/utils_tests.py @@ -794,7 +794,11 @@ def test_merge_extra_filters_with_no_extras(self): } merge_extra_form_data(form_data) self.assertEqual( - form_data, {"time_range": "Last 10 days", "adhoc_filters": [],}, + form_data, + { + "time_range": "Last 10 days", + "adhoc_filters": [], + }, ) def test_merge_extra_filters_with_unset_legacy_time_range(self): @@ -826,7 +830,9 @@ def test_merge_extra_filters_with_conflicting_time_ranges(self): form_data = { "time_range": "Last 10 days", "extra_filters": [{"col": "__time_range", "op": "==", "val": "Last week"}], - "extra_form_data": {"time_range": "Last year",}, + "extra_form_data": { + "time_range": "Last year", + }, } merge_extra_filters(form_data) self.assertEqual( diff --git a/tests/integration_tests/viz_tests.py b/tests/integration_tests/viz_tests.py index 465fdb26ef581..6eb3f8c611487 100644 --- a/tests/integration_tests/viz_tests.py +++ b/tests/integration_tests/viz_tests.py @@ -417,7 +417,9 @@ def run_test(metric): "label": "adhoc_metric", "expressionType": "SIMPLE", "aggregate": "SUM", - "column": {"column_name": "sort_column",}, + "column": { + "column_name": "sort_column", + }, } ) @@ -1505,13 +1507,29 @@ def test_get_data(self): test_viz = viz.FilterBoxViz(datasource, form_data) test_viz.dataframes = { "value1": pd.DataFrame( - data=[{"value1": "v1", "metric1": 1}, {"value1": "v2", "metric1": 2},] + data=[ + {"value1": "v1", "metric1": 1}, + {"value1": "v2", "metric1": 2}, + ] ), "value2": pd.DataFrame( - data=[{"value2": "v3", "metric2": 3}, {"value2": "v4", "metric2": 4},] + data=[ + {"value2": "v3", "metric2": 3}, + {"value2": "v4", "metric2": 4}, + ] + ), + "value3": pd.DataFrame( + data=[ + {"value3": "v5"}, + {"value3": "v6"}, + ] + ), + "value4": pd.DataFrame( + data=[ + {"value4": "v7"}, + {"value4": "v8"}, + ] ), - "value3": pd.DataFrame(data=[{"value3": "v5"}, {"value3": "v6"},]), - "value4": pd.DataFrame(data=[{"value4": "v7"}, {"value4": "v8"},]), "value5": pd.DataFrame(), } @@ -1526,8 +1544,14 @@ def test_get_data(self): {"id": "v3", "text": "v3", "metric": 3}, {"id": "v4", "text": "v4", "metric": 4}, ], - "value3": [{"id": "v6", "text": "v6"}, {"id": "v5", "text": "v5"},], - 
"value4": [{"id": "v7", "text": "v7"}, {"id": "v8", "text": "v8"},], + "value3": [ + {"id": "v6", "text": "v6"}, + {"id": "v5", "text": "v5"}, + ], + "value4": [ + {"id": "v7", "text": "v7"}, + {"id": "v8", "text": "v8"}, + ], "value5": [], "value6": [], } diff --git a/tests/unit_tests/columns/test_models.py b/tests/unit_tests/columns/test_models.py index 36c6b9b4e7301..40cc2075d380e 100644 --- a/tests/unit_tests/columns/test_models.py +++ b/tests/unit_tests/columns/test_models.py @@ -29,7 +29,11 @@ def test_column_model(app_context: None, session: Session) -> None: engine = session.get_bind() Column.metadata.create_all(engine) # pylint: disable=no-member - column = Column(name="ds", type="TIMESTAMP", expression="ds",) + column = Column( + name="ds", + type="TIMESTAMP", + expression="ds", + ) session.add(column) session.flush() diff --git a/tests/unit_tests/core_tests.py b/tests/unit_tests/core_tests.py index c9f96204cf0b0..f7a0047157bb8 100644 --- a/tests/unit_tests/core_tests.py +++ b/tests/unit_tests/core_tests.py @@ -140,7 +140,8 @@ def test_get_column_names(): "My Adhoc Column", ] assert get_column_names( - [STR_COLUMN, SQL_ADHOC_COLUMN], {"my_column": "My Column"}, + [STR_COLUMN, SQL_ADHOC_COLUMN], + {"my_column": "My Column"}, ) == ["My Column", "My Adhoc Column"] diff --git a/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py b/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py index 5f3015b1b01fa..bddc96eda36e6 100644 --- a/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py +++ b/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py @@ -35,12 +35,18 @@ def test_update_id_refs_immune_missing( # pylint: disable=invalid-name "position": { "CHART1": { "id": "CHART1", - "meta": {"chartId": 101, "uuid": "uuid1",}, + "meta": { + "chartId": 101, + "uuid": "uuid1", + }, "type": "CHART", }, "CHART2": { "id": "CHART2", - "meta": {"chartId": 102, "uuid": "uuid2",}, + "meta": { + "chartId": 102, + "uuid": "uuid2", + }, "type": "CHART", }, }, diff --git a/tests/unit_tests/datasets/commands/export_test.py b/tests/unit_tests/datasets/commands/export_test.py index cb5512448e299..a54f5cde61bbc 100644 --- a/tests/unit_tests/datasets/commands/export_test.py +++ b/tests/unit_tests/datasets/commands/export_test.py @@ -69,7 +69,11 @@ def test_export(app_context: None, session: Session) -> None: schema="my_schema", sql=None, params=json.dumps( - {"remote_id": 64, "database_name": "examples", "import_time": 1606677834,} + { + "remote_id": 64, + "database_name": "examples", + "import_time": 1606677834, + } ), perm=None, filter_select_enabled=1, diff --git a/tests/unit_tests/datasets/commands/importers/v1/import_test.py b/tests/unit_tests/datasets/commands/importers/v1/import_test.py index 667584c198cfc..07ea8c49d04d9 100644 --- a/tests/unit_tests/datasets/commands/importers/v1/import_test.py +++ b/tests/unit_tests/datasets/commands/importers/v1/import_test.py @@ -55,7 +55,9 @@ def test_import_dataset(app_context: None, session: Session) -> None: "database_name": "examples", "import_time": 1606677834, }, - "template_params": {"answer": "42",}, + "template_params": { + "answer": "42", + }, "filter_select_enabled": True, "fetch_values_predicate": "foo IN (1, 2)", "extra": {"warning_markdown": "*WARNING*"}, @@ -84,7 +86,9 @@ def test_import_dataset(app_context: None, session: Session) -> None: "expression": "revenue-expenses", "description": None, "python_date_format": None, - "extra": {"certified_by": "User",}, + "extra": { + "certified_by": "User", + }, } ], 
"database_uuid": database.uuid, @@ -165,7 +169,9 @@ def test_import_column_extra_is_string(app_context: None, session: Session) -> N "database_name": "examples", "import_time": 1606677834, }, - "template_params": {"answer": "42",}, + "template_params": { + "answer": "42", + }, "filter_select_enabled": True, "fetch_values_predicate": "foo IN (1, 2)", "extra": '{"warning_markdown": "*WARNING*"}', diff --git a/tests/unit_tests/datasets/test_models.py b/tests/unit_tests/datasets/test_models.py index 095b502760912..d21ef8ea60a94 100644 --- a/tests/unit_tests/datasets/test_models.py +++ b/tests/unit_tests/datasets/test_models.py @@ -57,7 +57,10 @@ def test_dataset_model(app_context: None, session: Session) -> None: """, tables=[table], columns=[ - Column(name="position", expression="array_agg(array[longitude,latitude])",), + Column( + name="position", + expression="array_agg(array[longitude,latitude])", + ), ], ) session.add(dataset) @@ -147,7 +150,10 @@ def test_cascade_delete_dataset(app_context: None, session: Session) -> None: """, tables=[table], columns=[ - Column(name="position", expression="array_agg(array[longitude,latitude])",), + Column( + name="position", + expression="array_agg(array[longitude,latitude])", + ), ], ) session.add(dataset) @@ -204,7 +210,11 @@ def test_dataset_attributes(app_context: None, session: Session) -> None: schema="my_schema", sql=None, params=json.dumps( - {"remote_id": 64, "database_name": "examples", "import_time": 1606677834,} + { + "remote_id": 64, + "database_name": "examples", + "import_time": 1606677834, + } ), perm=None, filter_select_enabled=1, @@ -301,7 +311,11 @@ def test_create_physical_sqlatable(app_context: None, session: Session) -> None: schema="my_schema", sql=None, params=json.dumps( - {"remote_id": 64, "database_name": "examples", "import_time": 1606677834,} + { + "remote_id": 64, + "database_name": "examples", + "import_time": 1606677834, + } ), perm=None, filter_select_enabled=1, @@ -576,7 +590,11 @@ def test_create_virtual_sqlatable( FROM some_table""", params=json.dumps( - {"remote_id": 64, "database_name": "examples", "import_time": 1606677834,} + { + "remote_id": 64, + "database_name": "examples", + "import_time": 1606677834, + } ), perm=None, filter_select_enabled=1, diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index 4dc27c0928f99..b112e2cec8ef4 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -77,7 +77,10 @@ def test_parse_sql_multi_statement(app_context: AppContext) -> None: ), None, ), - ("SELECT 1 as cnt", None,), + ( + "SELECT 1 as cnt", + None, + ), ( dedent( """ diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py b/tests/unit_tests/db_engine_specs/test_gsheets.py index ef3169febe74b..a13895e75e1d5 100644 --- a/tests/unit_tests/db_engine_specs/test_gsheets.py +++ b/tests/unit_tests/db_engine_specs/test_gsheets.py @@ -27,7 +27,8 @@ class ProgrammingError(Exception): def test_validate_parameters_simple( - mocker: MockFixture, app_context: AppContext, + mocker: MockFixture, + app_context: AppContext, ) -> None: from superset.db_engine_specs.gsheets import ( GSheetsEngineSpec, @@ -43,7 +44,8 @@ def test_validate_parameters_simple( def test_validate_parameters_catalog( - mocker: MockFixture, app_context: AppContext, + mocker: MockFixture, + app_context: AppContext, ) -> None: from superset.db_engine_specs.gsheets import ( GSheetsEngineSpec, @@ -78,7 +80,10 @@ def test_validate_parameters_catalog( 
error_type=SupersetErrorType.TABLE_DOES_NOT_EXIST_ERROR, level=ErrorLevel.WARNING, extra={ - "catalog": {"idx": 0, "url": True,}, + "catalog": { + "idx": 0, + "url": True, + }, "issue_codes": [ { "code": 1003, @@ -96,7 +101,10 @@ def test_validate_parameters_catalog( error_type=SupersetErrorType.TABLE_DOES_NOT_EXIST_ERROR, level=ErrorLevel.WARNING, extra={ - "catalog": {"idx": 2, "url": True,}, + "catalog": { + "idx": 2, + "url": True, + }, "issue_codes": [ { "code": 1003, @@ -112,12 +120,15 @@ def test_validate_parameters_catalog( ] create_engine.assert_called_with( - "gsheets://", service_account_info={}, subject="admin@example.com", + "gsheets://", + service_account_info={}, + subject="admin@example.com", ) def test_validate_parameters_catalog_and_credentials( - mocker: MockFixture, app_context: AppContext, + mocker: MockFixture, + app_context: AppContext, ) -> None: from superset.db_engine_specs.gsheets import ( GSheetsEngineSpec, @@ -151,7 +162,10 @@ def test_validate_parameters_catalog_and_credentials( error_type=SupersetErrorType.TABLE_DOES_NOT_EXIST_ERROR, level=ErrorLevel.WARNING, extra={ - "catalog": {"idx": 2, "url": True,}, + "catalog": { + "idx": 2, + "url": True, + }, "issue_codes": [ { "code": 1003, @@ -167,5 +181,7 @@ def test_validate_parameters_catalog_and_credentials( ] create_engine.assert_called_with( - "gsheets://", service_account_info={}, subject="admin@example.com", + "gsheets://", + service_account_info={}, + subject="admin@example.com", ) diff --git a/tests/unit_tests/db_engine_specs/test_kusto.py b/tests/unit_tests/db_engine_specs/test_kusto.py index 3c8a97aa44f1f..fca6ee5817de1 100644 --- a/tests/unit_tests/db_engine_specs/test_kusto.py +++ b/tests/unit_tests/db_engine_specs/test_kusto.py @@ -121,7 +121,10 @@ def test_kql_parse_sql(app_context: AppContext) -> None: ], ) def test_kql_convert_dttm( - app_context: AppContext, target_type: str, expected_dttm: str, dttm: datetime, + app_context: AppContext, + target_type: str, + expected_dttm: str, + dttm: datetime, ) -> None: """ Test that date objects are converted correctly. @@ -142,7 +145,10 @@ def test_kql_convert_dttm( ], ) def test_sql_convert_dttm( - app_context: AppContext, target_type: str, expected_dttm: str, dttm: datetime, + app_context: AppContext, + target_type: str, + expected_dttm: str, + dttm: datetime, ) -> None: """ Test that date objects are converted correctly. 
diff --git a/tests/unit_tests/db_engine_specs/test_mssql.py b/tests/unit_tests/db_engine_specs/test_mssql.py index 5c8848280b8ed..ddade3bfdb38c 100644 --- a/tests/unit_tests/db_engine_specs/test_mssql.py +++ b/tests/unit_tests/db_engine_specs/test_mssql.py @@ -107,13 +107,25 @@ def test_time_exp_mixd_case_col_1y(app_context: AppContext) -> None: @pytest.mark.parametrize( "actual,expected", [ - ("DATE", "CONVERT(DATE, '2019-01-02', 23)",), - ("DATETIME", "CONVERT(DATETIME, '2019-01-02T03:04:05.678', 126)",), - ("SMALLDATETIME", "CONVERT(SMALLDATETIME, '2019-01-02 03:04:05', 20)",), + ( + "DATE", + "CONVERT(DATE, '2019-01-02', 23)", + ), + ( + "DATETIME", + "CONVERT(DATETIME, '2019-01-02T03:04:05.678', 126)", + ), + ( + "SMALLDATETIME", + "CONVERT(SMALLDATETIME, '2019-01-02 03:04:05', 20)", + ), ], ) def test_convert_dttm( - app_context: AppContext, actual: str, expected: str, dttm: datetime, + app_context: AppContext, + actual: str, + expected: str, + dttm: datetime, ) -> None: from superset.db_engine_specs.mssql import MssqlEngineSpec @@ -151,7 +163,9 @@ def test_fetch_data(app_context: AppContext) -> None: from superset.db_engine_specs.mssql import MssqlEngineSpec with mock.patch.object( - MssqlEngineSpec, "pyodbc_rows_to_tuples", return_value="converted", + MssqlEngineSpec, + "pyodbc_rows_to_tuples", + return_value="converted", ) as mock_pyodbc_rows_to_tuples: data = [(1, "foo")] with mock.patch.object(BaseEngineSpec, "fetch_data", return_value=data): @@ -207,7 +221,10 @@ def test_column_datatype_to_string( )""" ), ), - ("SELECT 1 as cnt", None,), + ( + "SELECT 1 as cnt", + None, + ), ( dedent( """ diff --git a/tests/unit_tests/db_engine_specs/test_presto.py b/tests/unit_tests/db_engine_specs/test_presto.py index 370af3f48d604..eea6b6ec3c362 100644 --- a/tests/unit_tests/db_engine_specs/test_presto.py +++ b/tests/unit_tests/db_engine_specs/test_presto.py @@ -45,7 +45,10 @@ ], ) def test_convert_dttm( - app_context: AppContext, target_type: str, dttm: datetime, result: Optional[str], + app_context: AppContext, + target_type: str, + dttm: datetime, + result: Optional[str], ) -> None: from superset.db_engine_specs.presto import PrestoEngineSpec diff --git a/tests/unit_tests/db_engine_specs/test_teradata.py b/tests/unit_tests/db_engine_specs/test_teradata.py index 11978737abf6c..5887a9317c7f0 100644 --- a/tests/unit_tests/db_engine_specs/test_teradata.py +++ b/tests/unit_tests/db_engine_specs/test_teradata.py @@ -32,7 +32,10 @@ ], ) def test_apply_top_to_sql_limit( - app_context: AppContext, limit: int, original: str, expected: str, + app_context: AppContext, + limit: int, + original: str, + expected: str, ) -> None: """ Ensure limits are applied to the query correctly diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index ff00c4ff4595d..9962a0f66d0dc 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -45,7 +45,10 @@ ], ) def test_convert_dttm( - app_context: AppContext, target_type: str, dttm: datetime, result: Optional[str], + app_context: AppContext, + target_type: str, + dttm: datetime, + result: Optional[str], ) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec diff --git a/tests/unit_tests/fixtures/dataframes.py b/tests/unit_tests/fixtures/dataframes.py index 2a49bd3f8d951..31a275b735ac7 100644 --- a/tests/unit_tests/fixtures/dataframes.py +++ b/tests/unit_tests/fixtures/dataframes.py @@ -168,14 +168,28 @@ single_metric_df = DataFrame( { - 
"dttm": to_datetime(["2019-01-01", "2019-01-01", "2019-01-02", "2019-01-02",]), + "dttm": to_datetime( + [ + "2019-01-01", + "2019-01-01", + "2019-01-02", + "2019-01-02", + ] + ), "country": ["UK", "US", "UK", "US"], "sum_metric": [5, 6, 7, 8], } ) multiple_metrics_df = DataFrame( { - "dttm": to_datetime(["2019-01-01", "2019-01-01", "2019-01-02", "2019-01-02",]), + "dttm": to_datetime( + [ + "2019-01-01", + "2019-01-01", + "2019-01-02", + "2019-01-02", + ] + ), "country": ["UK", "US", "UK", "US"], "sum_metric": [5, 6, 7, 8], "count_metric": [1, 2, 3, 4], diff --git a/tests/unit_tests/pandas_postprocessing/test_contribution.py b/tests/unit_tests/pandas_postprocessing/test_contribution.py index a38551474770d..7eb34c4d13f7b 100644 --- a/tests/unit_tests/pandas_postprocessing/test_contribution.py +++ b/tests/unit_tests/pandas_postprocessing/test_contribution.py @@ -48,7 +48,8 @@ def test_contribution(): # cell contribution across row processed_df = contribution( - df, orientation=PostProcessingContributionOrientation.ROW, + df, + orientation=PostProcessingContributionOrientation.ROW, ) assert processed_df.columns.tolist() == [DTTM_ALIAS, "a", "b", "c"] assert_array_equal(processed_df["a"].tolist(), [0.5, 0.25, nan]) diff --git a/tests/unit_tests/pandas_postprocessing/test_cum.py b/tests/unit_tests/pandas_postprocessing/test_cum.py index 6cc5da2807ef2..17cd3c0efc8a4 100644 --- a/tests/unit_tests/pandas_postprocessing/test_cum.py +++ b/tests/unit_tests/pandas_postprocessing/test_cum.py @@ -31,33 +31,49 @@ def test_cum_should_not_side_effect(): _timeseries_df = timeseries_df.copy() pp.cum( - df=timeseries_df, columns={"y": "y2"}, operator="sum", + df=timeseries_df, + columns={"y": "y2"}, + operator="sum", ) assert _timeseries_df.equals(timeseries_df) def test_cum(): # create new column (cumsum) - post_df = pp.cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",) + post_df = pp.cum( + df=timeseries_df, + columns={"y": "y2"}, + operator="sum", + ) assert post_df.columns.tolist() == ["label", "y", "y2"] assert series_to_list(post_df["label"]) == ["x", "y", "z", "q"] assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0] assert series_to_list(post_df["y2"]) == [1.0, 3.0, 6.0, 10.0] # overwrite column (cumprod) - post_df = pp.cum(df=timeseries_df, columns={"y": "y"}, operator="prod",) + post_df = pp.cum( + df=timeseries_df, + columns={"y": "y"}, + operator="prod", + ) assert post_df.columns.tolist() == ["label", "y"] assert series_to_list(post_df["y"]) == [1.0, 2.0, 6.0, 24.0] # overwrite column (cummin) - post_df = pp.cum(df=timeseries_df, columns={"y": "y"}, operator="min",) + post_df = pp.cum( + df=timeseries_df, + columns={"y": "y"}, + operator="min", + ) assert post_df.columns.tolist() == ["label", "y"] assert series_to_list(post_df["y"]) == [1.0, 1.0, 1.0, 1.0] # invalid operator with pytest.raises(InvalidPostProcessingError): pp.cum( - df=timeseries_df, columns={"y": "y"}, operator="abc", + df=timeseries_df, + columns={"y": "y"}, + operator="abc", ) diff --git a/tests/unit_tests/pandas_postprocessing/test_diff.py b/tests/unit_tests/pandas_postprocessing/test_diff.py index a491d6ca2b0fb..c77195bbf6d71 100644 --- a/tests/unit_tests/pandas_postprocessing/test_diff.py +++ b/tests/unit_tests/pandas_postprocessing/test_diff.py @@ -41,7 +41,8 @@ def test_diff(): # invalid column reference with pytest.raises(InvalidPostProcessingError): diff( - df=timeseries_df, columns={"abc": "abc"}, + df=timeseries_df, + columns={"abc": "abc"}, ) # diff by columns diff --git 
a/tests/unit_tests/pandas_postprocessing/test_flatten.py b/tests/unit_tests/pandas_postprocessing/test_flatten.py index 01a180b2d511b..028d25e9ecdd0 100644 --- a/tests/unit_tests/pandas_postprocessing/test_flatten.py +++ b/tests/unit_tests/pandas_postprocessing/test_flatten.py @@ -21,7 +21,12 @@ def test_flat_should_not_change(): - df = pd.DataFrame(data={"foo": [1, 2, 3], "bar": [4, 5, 6],}) + df = pd.DataFrame( + data={ + "foo": [1, 2, 3], + "bar": [4, 5, 6], + } + ) assert pp.flatten(df).equals(df) @@ -40,7 +45,13 @@ def test_flat_should_flat_datetime_index(): df = pd.DataFrame(index=index, data={"foo": [1, 2, 3], "bar": [4, 5, 6]}) assert pp.flatten(df).equals( - pd.DataFrame({"__timestamp": index, "foo": [1, 2, 3], "bar": [4, 5, 6],}) + pd.DataFrame( + { + "__timestamp": index, + "foo": [1, 2, 3], + "bar": [4, 5, 6], + } + ) ) diff --git a/tests/unit_tests/pandas_postprocessing/test_pivot.py b/tests/unit_tests/pandas_postprocessing/test_pivot.py index e775df4e3f809..658cb4edcda86 100644 --- a/tests/unit_tests/pandas_postprocessing/test_pivot.py +++ b/tests/unit_tests/pandas_postprocessing/test_pivot.py @@ -34,37 +34,49 @@ def test_flatten_column_after_pivot(): """ # single aggregate cases assert ( - _flatten_column_after_pivot(aggregates=AGGREGATES_SINGLE, column="idx_nulls",) + _flatten_column_after_pivot( + aggregates=AGGREGATES_SINGLE, + column="idx_nulls", + ) == "idx_nulls" ) assert ( - _flatten_column_after_pivot(aggregates=AGGREGATES_SINGLE, column=1234,) + _flatten_column_after_pivot( + aggregates=AGGREGATES_SINGLE, + column=1234, + ) == "1234" ) assert ( _flatten_column_after_pivot( - aggregates=AGGREGATES_SINGLE, column=Timestamp("2020-09-29T00:00:00"), + aggregates=AGGREGATES_SINGLE, + column=Timestamp("2020-09-29T00:00:00"), ) == "2020-09-29 00:00:00" ) assert ( - _flatten_column_after_pivot(aggregates=AGGREGATES_SINGLE, column="idx_nulls",) + _flatten_column_after_pivot( + aggregates=AGGREGATES_SINGLE, + column="idx_nulls", + ) == "idx_nulls" ) assert ( _flatten_column_after_pivot( - aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1"), + aggregates=AGGREGATES_SINGLE, + column=("idx_nulls", "col1"), ) == "col1" ) assert ( _flatten_column_after_pivot( - aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1", 1234), + aggregates=AGGREGATES_SINGLE, + column=("idx_nulls", "col1", 1234), ) == "col1, 1234" ) @@ -72,7 +84,8 @@ def test_flatten_column_after_pivot(): # Multiple aggregate cases assert ( _flatten_column_after_pivot( - aggregates=AGGREGATES_MULTIPLE, column=("idx_nulls", "asc_idx", "col1"), + aggregates=AGGREGATES_MULTIPLE, + column=("idx_nulls", "asc_idx", "col1"), ) == "idx_nulls, asc_idx, col1" ) @@ -90,7 +103,11 @@ def test_pivot_without_columns(): """ Make sure pivot without columns returns correct DataFrame """ - df = pivot(df=categories_df, index=["name"], aggregates=AGGREGATES_SINGLE,) + df = pivot( + df=categories_df, + index=["name"], + aggregates=AGGREGATES_SINGLE, + ) assert df.columns.tolist() == ["name", "idx_nulls"] assert len(df) == 101 assert df.sum()[1] == 1050 @@ -235,7 +252,10 @@ def test_pivot_eliminate_cartesian_product_columns(): df=mock_df, index=["dttm"], columns=["a", "b"], - aggregates={"metric": {"operator": "mean"}, "metric2": {"operator": "mean"},}, + aggregates={ + "metric": {"operator": "mean"}, + "metric2": {"operator": "mean"}, + }, drop_missing_columns=False, ) assert list(df.columns) == [ diff --git a/tests/unit_tests/pandas_postprocessing/test_prophet.py b/tests/unit_tests/pandas_postprocessing/test_prophet.py index 
f341a5e250735..e4f3ed8cfc36d 100644 --- a/tests/unit_tests/pandas_postprocessing/test_prophet.py +++ b/tests/unit_tests/pandas_postprocessing/test_prophet.py @@ -84,31 +84,46 @@ def test_prophet_missing_temporal_column(): with pytest.raises(InvalidPostProcessingError): prophet( - df=df, time_grain="P1M", periods=3, confidence_interval=0.9, + df=df, + time_grain="P1M", + periods=3, + confidence_interval=0.9, ) def test_prophet_incorrect_confidence_interval(): with pytest.raises(InvalidPostProcessingError): prophet( - df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.0, + df=prophet_df, + time_grain="P1M", + periods=3, + confidence_interval=0.0, ) with pytest.raises(InvalidPostProcessingError): prophet( - df=prophet_df, time_grain="P1M", periods=3, confidence_interval=1.0, + df=prophet_df, + time_grain="P1M", + periods=3, + confidence_interval=1.0, ) def test_prophet_incorrect_periods(): with pytest.raises(InvalidPostProcessingError): prophet( - df=prophet_df, time_grain="P1M", periods=-1, confidence_interval=0.8, + df=prophet_df, + time_grain="P1M", + periods=-1, + confidence_interval=0.8, ) def test_prophet_incorrect_time_grain(): with pytest.raises(InvalidPostProcessingError): prophet( - df=prophet_df, time_grain="yearly", periods=10, confidence_interval=0.8, + df=prophet_df, + time_grain="yearly", + periods=10, + confidence_interval=0.8, ) diff --git a/tests/unit_tests/pandas_postprocessing/test_resample.py b/tests/unit_tests/pandas_postprocessing/test_resample.py index 9568d4ebd126a..9f1aaef3e62f6 100644 --- a/tests/unit_tests/pandas_postprocessing/test_resample.py +++ b/tests/unit_tests/pandas_postprocessing/test_resample.py @@ -107,7 +107,9 @@ def test_resample_after_pivot(): df=df, index=["__timestamp"], columns=["city"], - aggregates={"val": {"operator": "sum"},}, + aggregates={ + "val": {"operator": "sum"}, + }, flatten_columns=False, reset_index=False, ) @@ -118,7 +120,12 @@ def test_resample_after_pivot(): 2022-01-11 3.0 2.0 1.0 2022-01-13 6.0 5.0 4.0 """ - resample_df = pp.resample(df=pivot_df, rule="1D", method="asfreq", fill_value=0,) + resample_df = pp.resample( + df=pivot_df, + rule="1D", + method="asfreq", + fill_value=0, + ) """ val city Chicago LA NY @@ -151,12 +158,16 @@ def test_resample_after_pivot(): def test_resample_should_raise_ex(): with pytest.raises(InvalidPostProcessingError): pp.resample( - df=categories_df, rule="1D", method="asfreq", + df=categories_df, + rule="1D", + method="asfreq", ) with pytest.raises(InvalidPostProcessingError): pp.resample( - df=timeseries_df, rule="1D", method="foobar", + df=timeseries_df, + rule="1D", + method="foobar", ) diff --git a/tests/unit_tests/pandas_postprocessing/test_rolling.py b/tests/unit_tests/pandas_postprocessing/test_rolling.py index 616e4f5bd02d8..4d4c4341b895a 100644 --- a/tests/unit_tests/pandas_postprocessing/test_rolling.py +++ b/tests/unit_tests/pandas_postprocessing/test_rolling.py @@ -90,7 +90,10 @@ def test_rolling(): # incorrect rolling type with pytest.raises(InvalidPostProcessingError): pp.rolling( - df=timeseries_df, columns={"y": "y"}, rolling_type="abc", window=2, + df=timeseries_df, + columns={"y": "y"}, + rolling_type="abc", + window=2, ) # incorrect rolling type options @@ -191,7 +194,10 @@ def test_rolling_after_pivot_with_multiple_metrics(): """ rolling_df = pp.rolling( df=pivot_df, - columns={"count_metric": "count_metric", "sum_metric": "sum_metric",}, + columns={ + "count_metric": "count_metric", + "sum_metric": "sum_metric", + }, rolling_type="sum", window=2, min_periods=0, 
diff --git a/tests/unit_tests/tables/test_models.py b/tests/unit_tests/tables/test_models.py
index eb1f5f4611248..56ca5ba82fbfc 100644
--- a/tests/unit_tests/tables/test_models.py
+++ b/tests/unit_tests/tables/test_models.py
@@ -36,7 +36,13 @@ def test_table_model(app_context: None, session: Session) -> None:
         schema="my_schema",
         catalog="my_catalog",
         database=Database(database_name="my_database", sqlalchemy_uri="test://"),
-        columns=[Column(name="ds", type="TIMESTAMP", expression="ds",)],
+        columns=[
+            Column(
+                name="ds",
+                type="TIMESTAMP",
+                expression="ds",
+            )
+        ],
     )
     session.add(table)
     session.flush()