From a5cf0e4ad2e22cc0086dc16324da5800d9c73617 Mon Sep 17 00:00:00 2001 From: Antonio Rivero Date: Mon, 12 Dec 2022 17:28:23 -0300 Subject: [PATCH] SSH Tunnel: - Using nested transactions so we can rollback if anything fails --- superset/databases/commands/create.py | 4 +--- superset/databases/ssh_tunnel/commands/create.py | 16 +++++++++++++--- tests/integration_tests/databases/api_tests.py | 15 +++------------ 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/superset/databases/commands/create.py b/superset/databases/commands/create.py index c826d82835744..df12388cc5f08 100644 --- a/superset/databases/commands/create.py +++ b/superset/databases/commands/create.py @@ -81,9 +81,7 @@ def run(self) -> Model: try: # So database.id is not None db.session.flush() - ssh_tunnel = CreateSSHTunnelCommand( - database.id, ssh_tunnel_properties - ).run() + CreateSSHTunnelCommand(database.id, ssh_tunnel_properties).run() except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex: event_logger.log_with_context( action=f"db_creation_failed.{ex.__class__.__name__}", diff --git a/superset/databases/ssh_tunnel/commands/create.py b/superset/databases/ssh_tunnel/commands/create.py index b2e62f340b0bb..9c17149ba3d00 100644 --- a/superset/databases/ssh_tunnel/commands/create.py +++ b/superset/databases/ssh_tunnel/commands/create.py @@ -28,7 +28,7 @@ SSHTunnelRequiredFieldValidationError, ) from superset.databases.ssh_tunnel.dao import SSHTunnelDAO -from superset.extensions import event_logger +from superset.extensions import db, event_logger logger = logging.getLogger(__name__) @@ -39,12 +39,22 @@ def __init__(self, database_id: int, data: Dict[str, Any]): self._properties["database_id"] = database_id def run(self) -> Model: - self.validate() - try: + # Start nested transaction since we are always creating the tunnel + # through a DB command (Create or Update). Without this, we cannot + # safely rollback changes to databases if any, i.e, things like + # test_do_not_create_database_if_ssh_tunnel_creation_fails test will fail + db.session.begin_nested() + self.validate() tunnel = SSHTunnelDAO.create(self._properties, commit=False) except DAOCreateFailedError as ex: + # Rollback nested transaction + db.session.rollback() raise SSHTunnelCreateFailedError() from ex + except SSHTunnelInvalidError as ex: + # Rollback nested transaction + db.session.rollback() + raise ex return tunnel diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index cb170a9f2d972..1e83aa0f804e6 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -510,28 +510,19 @@ def test_do_not_create_database_if_ssh_tunnel_creation_fails( if example_db.backend == "sqlite": return ssh_tunnel_properties = { - "server_address_failure": "123.132.123.1", + "server_address": "123.132.123.1", } database_data = { "database_name": "test-db-failure-ssh-tunnel", "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, "ssh_tunnel": ssh_tunnel_properties, } - fail_message = { - "message": { - "ssh_tunnel": { - "server_address_failure": ["Unknown field."], - "server_address": ["Missing data for required field."], - "server_port": ["Missing data for required field."], - "username": ["Missing data for required field."], - } - } - } + fail_message = {"message": "SSH Tunnel parameters are invalid."} uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 400) + self.assertEqual(rv.status_code, 422) model_ssh_tunnel = ( db.session.query(SSHTunnel) .filter(SSHTunnel.database_id == response.get("id"))