Skip to content

Commit

Permalink
feat(ssh_tunnel): APIs for SSH Tunnels (#22199)
Browse files Browse the repository at this point in the history
Co-authored-by: hughhhh <hughmil3s@gmail.com>
  • Loading branch information
Antonio-RiveroMartnez and hughhhh authored Jan 3, 2023
1 parent 394afc1 commit 9b09fc7
Show file tree
Hide file tree
Showing 10 changed files with 676 additions and 15 deletions.
1 change: 1 addition & 0 deletions superset/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ class RouteMethod: # pylint: disable=too-few-public-methods
"validate_sql": "read",
"get_data": "read",
"samples": "read",
"delete_ssh_tunnel": "write",
}

EXTRA_FORM_DATA_APPEND_KEYS = {
Expand Down
111 changes: 111 additions & 0 deletions superset/databases/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@
ValidateSQLRequest,
ValidateSQLResponse,
)
from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelDeleteFailedError,
SSHTunnelNotFoundError,
)
from superset.databases.utils import get_table_metadata
from superset.db_engine_specs import get_available_engine_specs
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
Expand All @@ -80,6 +85,7 @@
from superset.models.core import Database
from superset.superset_typing import FlaskResponse
from superset.utils.core import error_msg_from_exception, parse_js_uri_path_item
from superset.utils.ssh_tunnel import mask_password_info
from superset.views.base import json_errors_response
from superset.views.base_api import (
BaseSupersetModelRestApi,
Expand Down Expand Up @@ -107,6 +113,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"available",
"validate_parameters",
"validate_sql",
"delete_ssh_tunnel",
}
resource_name = "database"
class_permission_name = "Database"
Expand Down Expand Up @@ -219,6 +226,47 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
ValidateSQLResponse,
)

@expose("/<int:pk>", methods=["GET"])
@protect()
@safe
def get(self, pk: int, **kwargs: Any) -> Response:
"""Get a database
---
get:
description: >-
Get a database
parameters:
- in: path
schema:
type: integer
description: The database id
name: pk
responses:
200:
description: Database
content:
application/json:
schema:
type: object
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
data = self.get_headless(pk, **kwargs)
try:
if ssh_tunnel := DatabaseDAO.get_ssh_tunnel(pk):
payload = data.json
payload["result"]["ssh_tunnel"] = ssh_tunnel.data
return payload
return data
except SupersetException as ex:
return self.response(ex.status, message=ex.message)

@expose("/", methods=["POST"])
@protect()
@safe
Expand Down Expand Up @@ -280,6 +328,12 @@ def post(self) -> FlaskResponse:
if new_model.driver:
item["driver"] = new_model.driver

# Return SSH Tunnel and hide passwords if any
if item.get("ssh_tunnel"):
item["ssh_tunnel"] = mask_password_info(
new_model.ssh_tunnel # pylint: disable=no-member
)

return self.response(201, id=new_model.id, result=item)
except DatabaseInvalidError as ex:
return self.response_422(message=ex.normalized_messages())
Expand Down Expand Up @@ -361,6 +415,9 @@ def put(self, pk: int) -> Response:
item["sqlalchemy_uri"] = changed_model.sqlalchemy_uri
if changed_model.parameters:
item["parameters"] = changed_model.parameters
# Return SSH Tunnel and hide passwords if any
if item.get("ssh_tunnel"):
item["ssh_tunnel"] = mask_password_info(changed_model.ssh_tunnel)
return self.response(200, id=changed_model.id, result=item)
except DatabaseNotFoundError:
return self.response_404()
Expand Down Expand Up @@ -1206,3 +1263,57 @@ def validate_parameters(self) -> FlaskResponse:
command = ValidateDatabaseParametersCommand(payload)
command.run()
return self.response(200, message="OK")

@expose("/<int:pk>/ssh_tunnel/", methods=["DELETE"])
@protect()
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
f".delete_ssh_tunnel",
log_to_statsd=False,
)
def delete_ssh_tunnel(self, pk: int) -> Response:
"""Deletes a SSH Tunnel
---
delete:
description: >-
Deletes a SSH Tunnel.
parameters:
- in: path
schema:
type: integer
name: pk
responses:
200:
description: SSH Tunnel deleted
content:
application/json:
schema:
type: object
properties:
message:
type: string
401:
$ref: '#/components/responses/401'
403:
$ref: '#/components/responses/403'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
try:
DeleteSSHTunnelCommand(pk).run()
return self.response(200, message="OK")
except SSHTunnelNotFoundError:
return self.response_404()
except SSHTunnelDeleteFailedError as ex:
logger.error(
"Error deleting SSH Tunnel %s: %s",
self.__class__.__name__,
str(ex),
exc_info=True,
)
return self.response_422(message=str(ex))
33 changes: 24 additions & 9 deletions superset/databases/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
)
from superset.databases.commands.test_connection import TestConnectionDatabaseCommand
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelCreateFailedError,
SSHTunnelInvalidError,
)
from superset.exceptions import SupersetErrorsException
from superset.extensions import db, event_logger, security_manager

Expand Down Expand Up @@ -71,17 +75,28 @@ def run(self) -> Model:
try:
database = DatabaseDAO.create(self._properties, commit=False)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
db.session.flush()

ssh_tunnel = None
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
ssh_tunnel = SSHTunnelDAO.create(
{
**ssh_tunnel_properties,
"database_id": database.id,
},
commit=False,
)
try:
# So database.id is not None
db.session.flush()
ssh_tunnel = CreateSSHTunnelCommand(
database.id, ssh_tunnel_properties
).run()
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
)
# So we can show the original message
raise ex
except Exception as ex:
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
)
raise DatabaseCreateFailedError() from ex

# adding a new database we always want to force refresh schema list
schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel)
Expand Down
35 changes: 32 additions & 3 deletions superset/databases/commands/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from marshmallow import ValidationError

from superset.commands.base import BaseCommand
from superset.dao.exceptions import DAOUpdateFailedError
from superset.dao.exceptions import DAOCreateFailedError, DAOUpdateFailedError
from superset.databases.commands.exceptions import (
DatabaseConnectionFailedError,
DatabaseExistsValidationError,
Expand All @@ -30,6 +30,12 @@
DatabaseUpdateFailedError,
)
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelCreateFailedError,
SSHTunnelInvalidError,
)
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
from superset.extensions import db, security_manager
from superset.models.core import Database
from superset.utils.core import DatasourceType
Expand Down Expand Up @@ -94,10 +100,33 @@ def run(self) -> Model:
security_manager.add_permission_view_menu(
"schema_access", security_manager.get_schema_perm(database, schema)
)

if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id)
if existing_ssh_tunnel_model is None:
# We couldn't found an existing tunnel so we need to create one
try:
CreateSSHTunnelCommand(database.id, ssh_tunnel_properties).run()
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
# So we can show the original message
raise ex
except Exception as ex:
raise DatabaseUpdateFailedError() from ex
else:
# We found an existing tunnel so we need to update it
try:
UpdateSSHTunnelCommand(
existing_ssh_tunnel_model.id, ssh_tunnel_properties
).run()
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
# So we can show the original message
raise ex
except Exception as ex:
raise DatabaseUpdateFailedError() from ex

db.session.commit()

except DAOUpdateFailedError as ex:
logger.exception(ex.exception)
except (DAOUpdateFailedError, DAOCreateFailedError) as ex:
raise DatabaseUpdateFailedError() from ex
return database

Expand Down
1 change: 1 addition & 0 deletions superset/databases/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ class Meta: # pylint: disable=too-few-public-methods
)
is_managed_externally = fields.Boolean(allow_none=True, default=False)
external_url = fields.String(allow_none=True)
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)


class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin):
Expand Down
16 changes: 13 additions & 3 deletions superset/databases/ssh_tunnel/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
SSHTunnelRequiredFieldValidationError,
)
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
from superset.extensions import event_logger
from superset.extensions import db, event_logger

logger = logging.getLogger(__name__)

Expand All @@ -39,12 +39,22 @@ def __init__(self, database_id: int, data: Dict[str, Any]):
self._properties["database_id"] = database_id

def run(self) -> Model:
self.validate()

try:
# Start nested transaction since we are always creating the tunnel
# through a DB command (Create or Update). Without this, we cannot
# safely rollback changes to databases if any, i.e, things like
# test_do_not_create_database_if_ssh_tunnel_creation_fails test will fail
db.session.begin_nested()
self.validate()
tunnel = SSHTunnelDAO.create(self._properties, commit=False)
except DAOCreateFailedError as ex:
# Rollback nested transaction
db.session.rollback()
raise SSHTunnelCreateFailedError() from ex
except SSHTunnelInvalidError as ex:
# Rollback nested transaction
db.session.rollback()
raise ex

return tunnel

Expand Down
10 changes: 10 additions & 0 deletions superset/databases/ssh_tunnel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict

import sqlalchemy as sa
from flask import current_app
from flask_appbuilder import Model
Expand Down Expand Up @@ -64,3 +66,11 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
private_key_password = sa.Column(
EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True
)

@property
def data(self) -> Dict[str, Any]:
return {
"server_address": self.server_address,
"server_port": self.server_port,
"username": self.username,
}
30 changes: 30 additions & 0 deletions superset/utils/ssh_tunnel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict

from superset.constants import PASSWORD_MASK


def mask_password_info(ssh_tunnel: Dict[str, Any]) -> Dict[str, Any]:
if ssh_tunnel.pop("password", None) is not None:
ssh_tunnel["password"] = PASSWORD_MASK
if ssh_tunnel.pop("private_key", None) is not None:
ssh_tunnel["private_key"] = PASSWORD_MASK
if ssh_tunnel.pop("private_key_password", None) is not None:
ssh_tunnel["private_key_password"] = PASSWORD_MASK
return ssh_tunnel
Loading

0 comments on commit 9b09fc7

Please sign in to comment.