From c6d5b5f73001064e79566fc8cc5bbae6d9aa5f1b Mon Sep 17 00:00:00 2001 From: Antonio Rivero Martinez <38889534+Antonio-RiveroMartnez@users.noreply.github.com> Date: Thu, 26 Jan 2023 21:53:36 -0300 Subject: [PATCH] feat(ssh_tunnel): Add feature flag to SSH Tunnel API (#22805) --- superset/databases/api.py | 20 +++- superset/databases/commands/create.py | 7 +- .../databases/commands/test_connection.py | 17 ++- superset/databases/commands/update.py | 8 +- .../databases/ssh_tunnel/commands/delete.py | 4 + .../ssh_tunnel/commands/exceptions.py | 5 + .../integration_tests/databases/api_tests.py | 106 ++++++++++++++++-- .../ssh_tunnel/commands/commands_tests.py | 4 +- tests/unit_tests/databases/api_test.py | 8 ++ .../ssh_tunnel/commands/delete_test.py | 11 +- 10 files changed, 171 insertions(+), 19 deletions(-) diff --git a/superset/databases/api.py b/superset/databases/api.py index 4866cbe775fad..572f3b340a9a5 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -75,6 +75,7 @@ from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand from superset.databases.ssh_tunnel.commands.exceptions import ( SSHTunnelDeleteFailedError, + SSHTunnelingNotEnabledError, SSHTunnelNotFoundError, ) from superset.databases.utils import get_table_metadata @@ -349,6 +350,8 @@ def post(self) -> FlaskResponse: exc_info=True, ) return self.response_422(message=str(ex)) + except SSHTunnelingNotEnabledError as ex: + return self.response_400(message=str(ex)) except SupersetException as ex: return self.response(ex.status, message=ex.message) @@ -433,6 +436,8 @@ def put(self, pk: int) -> Response: exc_info=True, ) return self.response_422(message=str(ex)) + except SSHTunnelingNotEnabledError as ex: + return self.response_400(message=str(ex)) @expose("/", methods=["DELETE"]) @protect() @@ -782,8 +787,11 @@ def test_connection(self) -> FlaskResponse: # This validates custom Schema with custom validations except ValidationError as error: return self.response_400(message=error.messages) - TestConnectionDatabaseCommand(item).run() - return self.response(200, message="OK") + try: + TestConnectionDatabaseCommand(item).run() + return self.response(200, message="OK") + except SSHTunnelingNotEnabledError as ex: + return self.response_400(message=str(ex)) @expose("//related_objects/", methods=["GET"]) @protect() @@ -1320,3 +1328,11 @@ def delete_ssh_tunnel(self, pk: int) -> Response: exc_info=True, ) return self.response_422(message=str(ex)) + except SSHTunnelingNotEnabledError as ex: + logger.error( + "Error deleting SSH Tunnel %s: %s", + self.__class__.__name__, + str(ex), + exc_info=True, + ) + return self.response_400(message=str(ex)) diff --git a/superset/databases/commands/create.py b/superset/databases/commands/create.py index c826d82835744..0ed23549608de 100644 --- a/superset/databases/commands/create.py +++ b/superset/databases/commands/create.py @@ -20,6 +20,7 @@ from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError +from superset import is_feature_enabled from superset.commands.base import BaseCommand from superset.dao.exceptions import DAOCreateFailedError from superset.databases.commands.exceptions import ( @@ -34,6 +35,7 @@ from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand from superset.databases.ssh_tunnel.commands.exceptions import ( SSHTunnelCreateFailedError, + SSHTunnelingNotEnabledError, SSHTunnelInvalidError, ) from superset.exceptions import SupersetErrorsException @@ -52,7 +54,7 @@ def run(self) -> Model: try: # Test connection before starting create transaction TestConnectionDatabaseCommand(self._properties).run() - except SupersetErrorsException as ex: + except (SupersetErrorsException, SSHTunnelingNotEnabledError) as ex: event_logger.log_with_context( action=f"db_creation_failed.{ex.__class__.__name__}", engine=self._properties.get("sqlalchemy_uri", "").split(":")[0], @@ -78,6 +80,9 @@ def run(self) -> Model: ssh_tunnel = None if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): + if not is_feature_enabled("SSH_TUNNELING"): + db.session.rollback() + raise SSHTunnelingNotEnabledError() try: # So database.id is not None db.session.flush() diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index 002adf12368f3..c5e7dc48f9831 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -25,6 +25,7 @@ from sqlalchemy.engine import Engine from sqlalchemy.exc import DBAPIError, NoSuchModuleError +from superset import is_feature_enabled from superset.commands.base import BaseCommand from superset.databases.commands.exceptions import ( DatabaseSecurityUnsafeError, @@ -32,6 +33,9 @@ DatabaseTestConnectionUnexpectedError, ) from superset.databases.dao import DatabaseDAO +from superset.databases.ssh_tunnel.commands.exceptions import ( + SSHTunnelingNotEnabledError, +) from superset.databases.ssh_tunnel.dao import SSHTunnelDAO from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe @@ -64,7 +68,7 @@ def __init__(self, data: Dict[str, Any]): self._properties = data.copy() self._model: Optional[Database] = None - def run(self) -> None: # pylint: disable=too-many-statements + def run(self) -> None: # pylint: disable=too-many-statements, too-many-branches self.validate() ex_str = "" uri = self._properties.get("sqlalchemy_uri", "") @@ -107,6 +111,8 @@ def run(self) -> None: # pylint: disable=too-many-statements # Generate tunnel if present in the properties if ssh_tunnel: + if not is_feature_enabled("SSH_TUNNELING"): + raise SSHTunnelingNotEnabledError() # If there's an existing tunnel for that DB we need to use the stored # password, private_key and private_key_password instead if ssh_tunnel_id := ssh_tunnel.pop("id", None): @@ -203,6 +209,15 @@ def ping(engine: Engine) -> bool: ) # bubble up the exception to return a 408 raise ex + except SSHTunnelingNotEnabledError as ex: + event_logger.log_with_context( + action=get_log_connection_action( + "test_connection_error", ssh_tunnel, ex + ), + engine=database.db_engine_spec.__name__, + ) + # bubble up the exception to return a 400 + raise ex except Exception as ex: event_logger.log_with_context( action=get_log_connection_action( diff --git a/superset/databases/commands/update.py b/superset/databases/commands/update.py index 2e5931788ee6c..03531803553a5 100644 --- a/superset/databases/commands/update.py +++ b/superset/databases/commands/update.py @@ -20,6 +20,7 @@ from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError +from superset import is_feature_enabled from superset.commands.base import BaseCommand from superset.dao.exceptions import DAOCreateFailedError, DAOUpdateFailedError from superset.databases.commands.exceptions import ( @@ -33,7 +34,9 @@ from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand from superset.databases.ssh_tunnel.commands.exceptions import ( SSHTunnelCreateFailedError, + SSHTunnelingNotEnabledError, SSHTunnelInvalidError, + SSHTunnelUpdateFailedError, ) from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand from superset.extensions import db, security_manager @@ -102,6 +105,9 @@ def run(self) -> Model: ) if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): + if not is_feature_enabled("SSH_TUNNELING"): + db.session.rollback() + raise SSHTunnelingNotEnabledError() existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id) if existing_ssh_tunnel_model is None: # We couldn't found an existing tunnel so we need to create one @@ -118,7 +124,7 @@ def run(self) -> Model: UpdateSSHTunnelCommand( existing_ssh_tunnel_model.id, ssh_tunnel_properties ).run() - except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex: + except (SSHTunnelInvalidError, SSHTunnelUpdateFailedError) as ex: # So we can show the original message raise ex except Exception as ex: diff --git a/superset/databases/ssh_tunnel/commands/delete.py b/superset/databases/ssh_tunnel/commands/delete.py index 3ad2fc2a1506c..235ceb697bede 100644 --- a/superset/databases/ssh_tunnel/commands/delete.py +++ b/superset/databases/ssh_tunnel/commands/delete.py @@ -19,10 +19,12 @@ from flask_appbuilder.models.sqla import Model +from superset import is_feature_enabled from superset.commands.base import BaseCommand from superset.dao.exceptions import DAODeleteFailedError from superset.databases.ssh_tunnel.commands.exceptions import ( SSHTunnelDeleteFailedError, + SSHTunnelingNotEnabledError, SSHTunnelNotFoundError, ) from superset.databases.ssh_tunnel.dao import SSHTunnelDAO @@ -37,6 +39,8 @@ def __init__(self, model_id: int): self._model: Optional[SSHTunnel] = None def run(self) -> Model: + if not is_feature_enabled("SSH_TUNNELING"): + raise SSHTunnelingNotEnabledError() self.validate() try: ssh_tunnel = SSHTunnelDAO.delete(self._model) diff --git a/superset/databases/ssh_tunnel/commands/exceptions.py b/superset/databases/ssh_tunnel/commands/exceptions.py index db2d3173de015..2495961c369a2 100644 --- a/superset/databases/ssh_tunnel/commands/exceptions.py +++ b/superset/databases/ssh_tunnel/commands/exceptions.py @@ -46,6 +46,11 @@ class SSHTunnelCreateFailedError(CommandException): message = _("Creating SSH Tunnel failed for an unknown reason") +class SSHTunnelingNotEnabledError(CommandException): + status = 400 + message = _("SSH Tunneling is not enabled") + + class SSHTunnelRequiredFieldValidationError(ValidationError): def __init__(self, field_name: str) -> None: super().__init__( diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index ae01ccdaf9689..eaa1653847c29 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -285,15 +285,20 @@ def test_create_database(self): @mock.patch( "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", ) + @mock.patch("superset.databases.commands.create.is_feature_enabled") @mock.patch( "superset.models.core.Database.get_all_schema_names", ) def test_create_database_with_ssh_tunnel( - self, mock_test_connection_database_command_run, mock_get_all_schema_names + self, + mock_test_connection_database_command_run, + mock_create_is_feature_enabled, + mock_get_all_schema_names, ): """ Database API: Test create with SSH Tunnel """ + mock_create_is_feature_enabled.return_value = True self.login(username="admin") example_db = get_example_database() if example_db.backend == "sqlite": @@ -328,15 +333,23 @@ def test_create_database_with_ssh_tunnel( @mock.patch( "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", ) + @mock.patch("superset.databases.commands.create.is_feature_enabled") + @mock.patch("superset.databases.commands.update.is_feature_enabled") @mock.patch( "superset.models.core.Database.get_all_schema_names", ) def test_update_database_with_ssh_tunnel( - self, mock_test_connection_database_command_run, mock_get_all_schema_names + self, + mock_test_connection_database_command_run, + mock_create_is_feature_enabled, + mock_update_is_feature_enabled, + mock_get_all_schema_names, ): """ - Database API: Test update with SSH Tunnel + Database API: Test update Database with SSH Tunnel """ + mock_create_is_feature_enabled.return_value = True + mock_update_is_feature_enabled.return_value = True self.login(username="admin") example_db = get_example_database() if example_db.backend == "sqlite": @@ -381,15 +394,23 @@ def test_update_database_with_ssh_tunnel( @mock.patch( "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", ) + @mock.patch("superset.databases.commands.create.is_feature_enabled") + @mock.patch("superset.databases.commands.update.is_feature_enabled") @mock.patch( "superset.models.core.Database.get_all_schema_names", ) def test_update_ssh_tunnel_via_database_api( - self, mock_test_connection_database_command_run, mock_get_all_schema_names + self, + mock_test_connection_database_command_run, + mock_create_is_feature_enabled, + mock_update_is_feature_enabled, + mock_get_all_schema_names, ): """ - Database API: Test update with SSH Tunnel + Database API: Test update SSH Tunnel via Database API """ + mock_create_is_feature_enabled.return_value = True + mock_update_is_feature_enabled.return_value = True self.login(username="admin") example_db = get_example_database() @@ -456,12 +477,17 @@ def test_update_ssh_tunnel_via_database_api( @mock.patch( "superset.models.core.Database.get_all_schema_names", ) + @mock.patch("superset.databases.commands.create.is_feature_enabled") def test_cascade_delete_ssh_tunnel( - self, mock_test_connection_database_command_run, mock_get_all_schema_names + self, + mock_test_connection_database_command_run, + mock_get_all_schema_names, + mock_create_is_feature_enabled, ): """ - Database API: Test create with SSH Tunnel + Database API: SSH Tunnel gets deleted if Database gets deleted """ + mock_create_is_feature_enabled.return_value = True self.login(username="admin") example_db = get_example_database() if example_db.backend == "sqlite": @@ -502,15 +528,20 @@ def test_cascade_delete_ssh_tunnel( @mock.patch( "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", ) + @mock.patch("superset.databases.commands.create.is_feature_enabled") @mock.patch( "superset.models.core.Database.get_all_schema_names", ) def test_do_not_create_database_if_ssh_tunnel_creation_fails( - self, mock_test_connection_database_command_run, mock_get_all_schema_names + self, + mock_test_connection_database_command_run, + mock_create_is_feature_enabled, + mock_get_all_schema_names, ): """ - Database API: Test create with SSH Tunnel + Database API: Test Database is not created if SSH Tunnel creation fails """ + mock_create_is_feature_enabled.return_value = True self.login(username="admin") example_db = get_example_database() if example_db.backend == "sqlite": @@ -548,15 +579,20 @@ def test_do_not_create_database_if_ssh_tunnel_creation_fails( @mock.patch( "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", ) + @mock.patch("superset.databases.commands.create.is_feature_enabled") @mock.patch( "superset.models.core.Database.get_all_schema_names", ) def test_get_database_returns_related_ssh_tunnel( - self, mock_test_connection_database_command_run, mock_get_all_schema_names + self, + mock_test_connection_database_command_run, + mock_create_is_feature_enabled, + mock_get_all_schema_names, ): """ Database API: Test GET Database returns its related SSH Tunnel """ + mock_create_is_feature_enabled.return_value = True self.login(username="admin") example_db = get_example_database() if example_db.backend == "sqlite": @@ -595,6 +631,56 @@ def test_get_database_returns_related_ssh_tunnel( db.session.delete(model) db.session.commit() + @mock.patch( + "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", + ) + @mock.patch( + "superset.models.core.Database.get_all_schema_names", + ) + def test_if_ssh_tunneling_flag_is_not_active_it_raises_new_exception( + self, + mock_test_connection_database_command_run, + mock_get_all_schema_names, + ): + """ + Database API: Test raises SSHTunneling feature flag not enabled + """ + self.login(username="admin") + example_db = get_example_database() + if example_db.backend == "sqlite": + return + ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "bar", + } + database_data = { + "database_name": "test-db-with-ssh-tunnel-7", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": ssh_tunnel_properties, + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 400) + self.assertEqual(response, {"message": "SSH Tunneling is not enabled"}) + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response.get("id")) + .one_or_none() + ) + assert model_ssh_tunnel is None + # Cleanup + model = ( + db.session.query(Database) + .filter(Database.database_name == "test-db-with-ssh-tunnel-7") + .one_or_none() + ) + # the DB should not be created + assert model is None + def test_create_database_invalid_configuration_method(self): """ Database API: Test create with an invalid configuration method. diff --git a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py index 75e5a55e862c9..86c280b9bb1c4 100644 --- a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py +++ b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py @@ -67,8 +67,10 @@ def test_update_ssh_tunnel_not_found(self, mock_g): class TestDeleteSSHTunnelCommand(SupersetTestCase): @mock.patch("superset.utils.core.g") - def test_delete_ssh_tunnel_not_found(self, mock_g): + @mock.patch("superset.databases.ssh_tunnel.commands.delete.is_feature_enabled") + def test_delete_ssh_tunnel_not_found(self, mock_g, mock_delete_is_feature_enabled): mock_g.user = security_manager.find_user("admin") + mock_delete_is_feature_enabled.return_value = True # We have not created a SSH Tunnel yet so id = 1 is invalid command = DeleteSSHTunnelCommand(1) with pytest.raises(SSHTunnelNotFoundError) as excinfo: diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index fe4211289caf8..68a9add12e9a4 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -241,6 +241,10 @@ def test_delete_ssh_tunnel( # mock the lookup so that we don't need to include the driver mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") mocker.patch("superset.utils.log.DBEventLogger.log") + mocker.patch( + "superset.databases.ssh_tunnel.commands.delete.is_feature_enabled", + return_value=True, + ) # Create our SSHTunnel tunnel = SSHTunnel( @@ -313,6 +317,10 @@ def test_delete_ssh_tunnel_not_found( # mock the lookup so that we don't need to include the driver mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") mocker.patch("superset.utils.log.DBEventLogger.log") + mocker.patch( + "superset.databases.ssh_tunnel.commands.delete.is_feature_enabled", + return_value=True, + ) # Create our SSHTunnel tunnel = SSHTunnel( diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py index 17afebfa0fecc..b5adf765fa5ab 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py @@ -18,6 +18,7 @@ from typing import Iterator import pytest +from pytest_mock import MockFixture from sqlalchemy.orm.session import Session @@ -50,7 +51,9 @@ def session_with_data(session: Session) -> Iterator[Session]: session.rollback() -def test_delete_ssh_tunnel_command(session_with_data: Session) -> None: +def test_delete_ssh_tunnel_command( + mocker: MockFixture, session_with_data: Session +) -> None: from superset.databases.dao import DatabaseDAO from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand from superset.databases.ssh_tunnel.models import SSHTunnel @@ -60,9 +63,11 @@ def test_delete_ssh_tunnel_command(session_with_data: Session) -> None: assert result assert isinstance(result, SSHTunnel) assert 1 == result.database_id - + mocker.patch( + "superset.databases.ssh_tunnel.commands.delete.is_feature_enabled", + return_value=True, + ) DeleteSSHTunnelCommand(1).run() - result = DatabaseDAO.get_ssh_tunnel(1) assert result is None