Skip to content

Commit

Permalink
Improve SQLAlchemy store parsing and documentation (mlflow#1297)
Browse files Browse the repository at this point in the history
* scrub '+{driver}' from sqlalchemy URI when identifying db_type

* enhancements to db uri parsing in SQLAlchemy store
- allow dialect+driver://... scheme format
- raise error if db type is unsupported
- raise error if db driver is not believable

* improve documentation on sqlalchemy connection strings/db_uris

* leave driver validation to SQLAlchemy
  • Loading branch information
drewmcdonald authored and mparkhe committed May 27, 2019
1 parent 3482ace commit e68b6c3
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 19 deletions.
13 changes: 7 additions & 6 deletions docs/source/tracking.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ call :py:func:`mlflow.set_tracking_uri`.
There are different kinds of remote tracking URIs:

- Local file path (specified as ``file:/my/local/dir``), where data is just directly stored locally.
- Database encoded as a connection string (specified as ``db_type://<user_name>:<password>@<host>:<port>/<database_name>``)
- Database encoded as ``<dialect>+<driver>://<username>:<password>@<host>:<port>/<database>``. Mlflow supports the dialects ``mysql``, ``mssql``, ``sqlite``, and ``postgresql``. For more details, see `SQLAlchemy database uri <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_.
- HTTP server (specified as ``https://my-server:5000``), which is a server hosting an :ref:`MLFlow tracking server <tracking_server>`.
- Databricks workspace (specified as ``databricks`` or as ``databricks://<profileName>``, a `Databricks CLI profile <https://github.com/databricks/databricks-cli#installation>`_.

Expand Down Expand Up @@ -265,11 +265,12 @@ The backend store is where MLflow Tracking Server stores experiment and run meta
params, metrics, and tags for runs. MLflow supports two types of backend stores: *file store* and
*database-backed store*.

Use ``--backend-store-uri`` to configure type of backend store. This can be a local path *file
store* specified as ``./path_to_store`` or ``file:/path_to_store``, or a SQL connection string
for a *database-backed store*. For the latter, the argument must be a SQL connection string
specified as ``db_type://<user_name>:<password>@<host>:<port>/<database_name>``. Supported
database types are ``mysql``, ``mssql``, ``sqlite``, and ``postgresql``.
Use ``--backend-store-uri`` to configure the type of backend store. You specify a *file store*
backend as ``./path_to_store`` or ``file:/path_to_store`` and a *database-backed store* as
`SQLAlchemy database URI <https://docs.sqlalchemy.org/en/latest/core/engines
.html#database-urls>`_. The database URI typically takes the format ``<dialect>+<driver>://<username>:<password>@<host>:<port>/<database>``.
MLflow supports the database dialects ``mysql``, ``mssql``, ``sqlite``, and ``postgresql``.
Drivers are optional. If you do not specify a driver, SQLAlchemy uses a dialect's default driver.
For backwards compatibility, ``--file-store`` is an alias for ``--backend-store-uri``.

.. important::
Expand Down
53 changes: 42 additions & 11 deletions mlflow/store/sqlalchemy_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,51 @@
INVALID_STATE, RESOURCE_DOES_NOT_EXIST, INTERNAL_ERROR
from mlflow.tracking.utils import _is_local_uri
from mlflow.utils.file_utils import mkdir, local_file_uri_to_path
from mlflow.utils.validation import _validate_batch_log_limits, _validate_batch_log_data,\
_validate_run_id, _validate_metric
from mlflow.utils.validation import _validate_batch_log_limits, _validate_batch_log_data, \
_validate_run_id, _validate_metric, _validate_db_type_string
from mlflow.store.db.utils import _upgrade_db, _get_alembic_config, _get_schema_version
from mlflow.store.dbmodels.initial_models import Base as InitialBase


_logger = logging.getLogger(__name__)


_INVALID_DB_URI_MSG = "Please refer to https://mlflow.org/docs/latest/tracking.html#storage for " \
"format specifications."


def _parse_db_uri_extract_db_type(db_uri):
"""
Parse the specified DB URI to extract the database type. Confirm the database type is
supported. If a driver is specified, confirm it passes a plausible regex.
"""
scheme = urllib.parse.urlparse(db_uri).scheme
scheme_plus_count = scheme.count('+')

if scheme_plus_count == 0:
db_type = scheme
elif scheme_plus_count == 1:
db_type, _ = scheme.split('+')
else:
error_msg = "Invalid database URI: '%s'. %s" % (db_uri, _INVALID_DB_URI_MSG)
raise MlflowException(error_msg, INVALID_PARAMETER_VALUE)

_validate_db_type_string(db_type)

return db_type


class SqlAlchemyStore(AbstractStore):
"""
SQLAlchemy compliant backend store for tracking meta data for MLflow entities. Currently
supported database types are ``mysql``, ``mssql``, ``sqlite``, and ``postgresql``. This store
interacts with SQL store using SQLAlchemy abstractions defined for MLflow entities.
SQLAlchemy compliant backend store for tracking meta data for MLflow entities. MLflow
supports the database dialects ``mysql``, ``mssql``, ``sqlite``, and ``postgresql``.
As specified in the
`SQLAlchemy docs <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_ ,
the database URI is expected in the format
``<dialect>+<driver>://<username>:<password>@<host>:<port>/<database>``. If you do not
specify a driver, SQLAlchemy uses a dialect's default driver.
This store interacts with SQL store using SQLAlchemy abstractions defined for MLflow entities.
:py:class:`mlflow.store.dbmodels.models.SqlExperiment`,
:py:class:`mlflow.store.dbmodels.models.SqlRun`,
:py:class:`mlflow.store.dbmodels.models.SqlTag`,
Expand All @@ -51,17 +82,17 @@ def __init__(self, db_uri, default_artifact_root):
"""
Create a database backed store.
:param db_uri: SQL connection string used by SQLAlchemy Engine to connect to the database.
Argument is expected to be in the format:
``db_type://<user_name>:<password>@<host>:<port>/<database_name>`
Supported database types are ``mysql``, ``mssql``, ``sqlite``,
and ``postgresql``.
:param db_uri: The SQLAlchemy database URI string to connect to the database. See
the `SQLAlchemy docs
<https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_
for format specifications. Mlflow supports the dialects ``mysql``,
``mssql``, ``sqlite``, and ``postgresql``.
:param default_artifact_root: Path/URI to location suitable for large data (such as a blob
store object, DBFS path, or shared NFS file system).
"""
super(SqlAlchemyStore, self).__init__()
self.db_uri = db_uri
self.db_type = urllib.parse.urlparse(db_uri).scheme
self.db_type = _parse_db_uri_extract_db_type(db_uri)
self.artifact_root_uri = default_artifact_root
self.engine = sqlalchemy.create_engine(db_uri)
insp = sqlalchemy.inspect(self.engine)
Expand Down
10 changes: 10 additions & 0 deletions mlflow/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.store.dbmodels.db_types import DATABASE_ENGINES

_VALID_PARAM_AND_METRIC_NAMES = re.compile(r"^[/\w.\- ]*$")

Expand All @@ -30,6 +31,8 @@
MAX_TAG_VAL_LENGTH = 250
MAX_ENTITY_KEY_LENGTH = 250

_UNSUPPORTED_DB_TYPE_MSG = "Supported database engines are {%s}" % ', '.join(DATABASE_ENGINES)


def bad_path_message(name):
return (
Expand Down Expand Up @@ -195,3 +198,10 @@ def _validate_experiment_artifact_location(artifact_location):
raise MlflowException("Artifact location cannot be a runs:/ URI. Given: '%s'"
% artifact_location,
error_code=INVALID_PARAMETER_VALUE)


def _validate_db_type_string(db_type):
"""validates db_type parsed from DB URI is supported"""
if db_type not in DATABASE_ENGINES:
error_msg = "Invalid database engine: '%s'. '%s'" % (db_type, _UNSUPPORTED_DB_TYPE_MSG)
raise MlflowException(error_msg, INVALID_PARAMETER_VALUE)
42 changes: 41 additions & 1 deletion tests/store/test_sqlalchemy_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from mlflow.store.dbmodels import models
from mlflow import entities
from mlflow.exceptions import MlflowException
from mlflow.store.sqlalchemy_store import SqlAlchemyStore
from mlflow.store.sqlalchemy_store import SqlAlchemyStore, _parse_db_uri_extract_db_type
from mlflow.utils.search_utils import SearchFilter
from tests.resources.db.initial_models import Base as InitialBase
from tests.integration.utils import invoke_cli_runner
Expand All @@ -31,6 +31,46 @@
ARTIFACT_URI = 'artifact_folder'


class TestParseDbUri(unittest.TestCase):

def test_correct_db_type_from_uri(self):
# try each the main drivers per supported database type
target_db_type_uris = {
'sqlite': ('pysqlite', 'pysqlcipher'),
'postgresql': ('psycopg2', 'pg8000', 'psycopg2cffi',
'pypostgresql', 'pygresql', 'zxjdbc'),
'mysql': ('mysqldb', 'pymysql', 'mysqlconnector', 'cymysql',
'oursql', 'mysqldb', 'gaerdbms', 'pyodbc', 'zxjdbc'),
'mssql': ('pyodbc', 'mxodbc', 'pymssql', 'zxjdbc', 'adodbapi')
}
for target_db_type, drivers in target_db_type_uris.items():
# try the driver-less version, which will revert SQLAlchemy to the default driver
uri = "%s://..." % target_db_type
parsed_db_type = _parse_db_uri_extract_db_type(uri)
self.assertEqual(target_db_type, parsed_db_type)
# try each of the popular drivers (per SQLAlchemy's dialect pages)
for driver in drivers:
uri = "%s+%s://..." % (target_db_type, driver)
parsed_db_type = _parse_db_uri_extract_db_type(uri)
self.assertEqual(target_db_type, parsed_db_type)

def _db_uri_error(self, db_uris, expected_message_part):
for db_uri in db_uris:
with self.assertRaises(MlflowException) as e:
_parse_db_uri_extract_db_type(db_uri)
self.assertIn(expected_message_part, e.exception.message)

def test_fail_on_unsupported_db_type(self):
bad_db_uri_strings = ['oracle://...', 'oracle+cx_oracle://...',
'snowflake://...', '://...', 'abcdefg']
self._db_uri_error(bad_db_uri_strings, "Supported database engines are ")

def test_fail_on_multiple_drivers(self):
bad_db_uri_strings = ['mysql+pymsql+pyodbc://...']
self._db_uri_error(bad_db_uri_strings,
"mlflow.org/docs/latest/tracking.html#storage for format specifications")


class TestSqlAlchemyStoreSqlite(unittest.TestCase):

def _get_store(self, db_uri=''):
Expand Down
14 changes: 13 additions & 1 deletion tests/utils/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mlflow.protos.databricks_pb2 import ErrorCode, INVALID_PARAMETER_VALUE
from mlflow.utils.validation import _validate_metric_name, _validate_param_name, \
_validate_tag_name, _validate_run_id, _validate_batch_log_data, \
_validate_batch_log_limits, _validate_experiment_artifact_location
_validate_batch_log_limits, _validate_experiment_artifact_location, _validate_db_type_string

GOOD_METRIC_OR_PARAM_NAMES = [
"a", "Ab-5_", "a/b/c", "a.b.c", ".a", "b.", "a..a/._./o_O/.e.", "a b/c d",
Expand Down Expand Up @@ -119,3 +119,15 @@ def test_validate_experiment_artifact_location():
_validate_experiment_artifact_location(None)
with pytest.raises(MlflowException):
_validate_experiment_artifact_location('runs:/blah/bleh/blergh')


def test_db_type():
for db_type in ["mysql", "mssql", "postgresql", "sqlite"]:
# should not raise an exception
_validate_db_type_string(db_type)

# error cases
for db_type in ["MySQL", "mongo", "cassandra", "sql", ""]:
with pytest.raises(MlflowException) as e:
_validate_db_type_string(db_type)
assert "Invalid database engine" in e.value.message

0 comments on commit e68b6c3

Please sign in to comment.