Skip to content

Commit

Permalink
SSH Tunnel:
Browse files Browse the repository at this point in the history
  - Check whether the feature flag is enabled when calling any of our API endpoints that use SSH Tunnels
  - Update tests
  - Add new test to check the error message from the new exception and its status
  • Loading branch information
Antonio-RiveroMartnez committed Jan 20, 2023
1 parent 13a186b commit 9e6516e
Show file tree
Hide file tree
Showing 10 changed files with 171 additions and 19 deletions.
20 changes: 18 additions & 2 deletions superset/databases/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelDeleteFailedError,
SSHTunnelingNotEnabledError,
SSHTunnelNotFoundError,
)
from superset.databases.utils import get_table_metadata
Expand Down Expand Up @@ -349,6 +350,8 @@ def post(self) -> FlaskResponse:
exc_info=True,
)
return self.response_422(message=str(ex))
except SSHTunnelingNotEnabledError as ex:
return self.response_400(message=str(ex))
except SupersetException as ex:
return self.response(ex.status, message=ex.message)

Expand Down Expand Up @@ -433,6 +436,8 @@ def put(self, pk: int) -> Response:
exc_info=True,
)
return self.response_422(message=str(ex))
except SSHTunnelingNotEnabledError as ex:
return self.response_400(message=str(ex))

@expose("/<int:pk>", methods=["DELETE"])
@protect()
Expand Down Expand Up @@ -782,8 +787,11 @@ def test_connection(self) -> FlaskResponse:
# This validates custom Schema with custom validations
except ValidationError as error:
return self.response_400(message=error.messages)
TestConnectionDatabaseCommand(item).run()
return self.response(200, message="OK")
try:
TestConnectionDatabaseCommand(item).run()
return self.response(200, message="OK")
except SSHTunnelingNotEnabledError as ex:
return self.response_400(message=str(ex))

@expose("/<int:pk>/related_objects/", methods=["GET"])
@protect()
Expand Down Expand Up @@ -1320,3 +1328,11 @@ def delete_ssh_tunnel(self, pk: int) -> Response:
exc_info=True,
)
return self.response_422(message=str(ex))
except SSHTunnelingNotEnabledError as ex:
logger.error(
"Error deleting SSH Tunnel %s: %s",
self.__class__.__name__,
str(ex),
exc_info=True,
)
return self.response_400(message=str(ex))
7 changes: 6 additions & 1 deletion superset/databases/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError

from superset import is_feature_enabled
from superset.commands.base import BaseCommand
from superset.dao.exceptions import DAOCreateFailedError
from superset.databases.commands.exceptions import (
Expand All @@ -34,6 +35,7 @@
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelCreateFailedError,
SSHTunnelingNotEnabledError,
SSHTunnelInvalidError,
)
from superset.exceptions import SupersetErrorsException
Expand All @@ -52,7 +54,7 @@ def run(self) -> Model:
try:
# Test connection before starting create transaction
TestConnectionDatabaseCommand(self._properties).run()
except SupersetErrorsException as ex:
except (SupersetErrorsException, SSHTunnelingNotEnabledError) as ex:
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
Expand All @@ -78,6 +80,9 @@ def run(self) -> Model:

ssh_tunnel = None
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
if not is_feature_enabled("SSH_TUNNELING"):
db.session.rollback()
raise SSHTunnelingNotEnabledError()
try:
# So database.id is not None
db.session.flush()
Expand Down
17 changes: 16 additions & 1 deletion superset/databases/commands/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@
from sqlalchemy.engine import Engine
from sqlalchemy.exc import DBAPIError, NoSuchModuleError

from superset import is_feature_enabled
from superset.commands.base import BaseCommand
from superset.databases.commands.exceptions import (
DatabaseSecurityUnsafeError,
DatabaseTestConnectionDriverError,
DatabaseTestConnectionUnexpectedError,
)
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelingNotEnabledError,
)
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
Expand Down Expand Up @@ -64,7 +68,7 @@ def __init__(self, data: Dict[str, Any]):
self._properties = data.copy()
self._model: Optional[Database] = None

def run(self) -> None: # pylint: disable=too-many-statements
def run(self) -> None: # pylint: disable=too-many-statements, too-many-branches
self.validate()
ex_str = ""
uri = self._properties.get("sqlalchemy_uri", "")
Expand Down Expand Up @@ -107,6 +111,8 @@ def run(self) -> None: # pylint: disable=too-many-statements

# Generate tunnel if present in the properties
if ssh_tunnel:
if not is_feature_enabled("SSH_TUNNELING"):
raise SSHTunnelingNotEnabledError()
# If there's an existing tunnel for that DB we need to use the stored
# password, private_key and private_key_password instead
if ssh_tunnel_id := ssh_tunnel.pop("id", None):
Expand Down Expand Up @@ -203,6 +209,15 @@ def ping(engine: Engine) -> bool:
)
# bubble up the exception to return a 408
raise ex
except SSHTunnelingNotEnabledError as ex:
event_logger.log_with_context(
action=get_log_connection_action(
"test_connection_error", ssh_tunnel, ex
),
engine=database.db_engine_spec.__name__,
)
# bubble up the exception to return a 400
raise ex
except Exception as ex:
event_logger.log_with_context(
action=get_log_connection_action(
Expand Down
8 changes: 7 additions & 1 deletion superset/databases/commands/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError

from superset import is_feature_enabled
from superset.commands.base import BaseCommand
from superset.dao.exceptions import DAOCreateFailedError, DAOUpdateFailedError
from superset.databases.commands.exceptions import (
Expand All @@ -33,7 +34,9 @@
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelCreateFailedError,
SSHTunnelingNotEnabledError,
SSHTunnelInvalidError,
SSHTunnelUpdateFailedError,
)
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
from superset.extensions import db, security_manager
Expand Down Expand Up @@ -102,6 +105,9 @@ def run(self) -> Model:
)

if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
if not is_feature_enabled("SSH_TUNNELING"):
db.session.rollback()
raise SSHTunnelingNotEnabledError()
existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id)
if existing_ssh_tunnel_model is None:
# We couldn't found an existing tunnel so we need to create one
Expand All @@ -118,7 +124,7 @@ def run(self) -> Model:
UpdateSSHTunnelCommand(
existing_ssh_tunnel_model.id, ssh_tunnel_properties
).run()
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
except (SSHTunnelInvalidError, SSHTunnelUpdateFailedError) as ex:
# So we can show the original message
raise ex
except Exception as ex:
Expand Down
4 changes: 4 additions & 0 deletions superset/databases/ssh_tunnel/commands/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@

from flask_appbuilder.models.sqla import Model

from superset import is_feature_enabled
from superset.commands.base import BaseCommand
from superset.dao.exceptions import DAODeleteFailedError
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelDeleteFailedError,
SSHTunnelingNotEnabledError,
SSHTunnelNotFoundError,
)
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
Expand All @@ -37,6 +39,8 @@ def __init__(self, model_id: int):
self._model: Optional[SSHTunnel] = None

def run(self) -> Model:
if not is_feature_enabled("SSH_TUNNELING"):
raise SSHTunnelingNotEnabledError()
self.validate()
try:
ssh_tunnel = SSHTunnelDAO.delete(self._model)
Expand Down
5 changes: 5 additions & 0 deletions superset/databases/ssh_tunnel/commands/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ class SSHTunnelCreateFailedError(CommandException):
message = _("Creating SSH Tunnel failed for an unknown reason")


class SSHTunnelingNotEnabledError(CommandException):
status = 400
message = _("SSH Tunneling is not enabled")


class SSHTunnelRequiredFieldValidationError(ValidationError):
def __init__(self, field_name: str) -> None:
super().__init__(
Expand Down
Loading

0 comments on commit 9e6516e

Please sign in to comment.