Skip to content

Commit

Permalink
SSH Tunnel:
Browse files Browse the repository at this point in the history
- Using nested transactions so we can rollback if anything fails
  • Loading branch information
Antonio-RiveroMartnez committed Dec 12, 2022
1 parent 8e910f1 commit a5cf0e4
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 18 deletions.
4 changes: 1 addition & 3 deletions superset/databases/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ def run(self) -> Model:
try:
# So database.id is not None
db.session.flush()
ssh_tunnel = CreateSSHTunnelCommand(
database.id, ssh_tunnel_properties
).run()
CreateSSHTunnelCommand(database.id, ssh_tunnel_properties).run()
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",
Expand Down
16 changes: 13 additions & 3 deletions superset/databases/ssh_tunnel/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
SSHTunnelRequiredFieldValidationError,
)
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
from superset.extensions import event_logger
from superset.extensions import db, event_logger

logger = logging.getLogger(__name__)

Expand All @@ -39,12 +39,22 @@ def __init__(self, database_id: int, data: Dict[str, Any]):
self._properties["database_id"] = database_id

def run(self) -> Model:
self.validate()

try:
# Start nested transaction since we are always creating the tunnel
# through a DB command (Create or Update). Without this, we cannot
# safely rollback changes to databases if any, i.e, things like
# test_do_not_create_database_if_ssh_tunnel_creation_fails test will fail
db.session.begin_nested()
self.validate()
tunnel = SSHTunnelDAO.create(self._properties, commit=False)
except DAOCreateFailedError as ex:
# Rollback nested transaction
db.session.rollback()
raise SSHTunnelCreateFailedError() from ex
except SSHTunnelInvalidError as ex:
# Rollback nested transaction
db.session.rollback()
raise ex

return tunnel

Expand Down
15 changes: 3 additions & 12 deletions tests/integration_tests/databases/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,28 +510,19 @@ def test_do_not_create_database_if_ssh_tunnel_creation_fails(
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address_failure": "123.132.123.1",
"server_address": "123.132.123.1",
}
database_data = {
"database_name": "test-db-failure-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
fail_message = {
"message": {
"ssh_tunnel": {
"server_address_failure": ["Unknown field."],
"server_address": ["Missing data for required field."],
"server_port": ["Missing data for required field."],
"username": ["Missing data for required field."],
}
}
}
fail_message = {"message": "SSH Tunnel parameters are invalid."}

uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertEqual(rv.status_code, 422)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
Expand Down

0 comments on commit a5cf0e4

Please sign in to comment.