Skip to content

Changes to support MSSQL #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions dev/dev.env
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,6 @@ VERTICA_DB_NAME=vertica
# leave VMART_DIR and VMART_ETL_SCRIPT empty.
VMART_DIR=
VMART_ETL_SCRIPT=

ACCEPT_EULA=Y
MSSQL_SA_PASSWORD=<password!CAP>
17 changes: 16 additions & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,28 @@ services:
networks:
- local


mssql:
  container_name: dd-mssql
  image: mcr.microsoft.com/mssql/server:2022-latest
  restart: always
  volumes:
    - mssql-data:/var/opt/mssql/data:delegated
  ports:
    # host port 8020 -> container port 1433
    - '8020:1433'
  # `expose` takes container-side ports. Per the mapping above the server
  # is reached on 1433 inside the container; 8020 exists only on the host.
  expose:
    - '1433'
  env_file:
    - dev/dev.env
  tty: true
  networks:
    - local

volumes:
postgresql-data:
mysql-data:
clickhouse-data:
vertica-data:
mssql-data:

networks:
local:
Expand Down
1 change: 1 addition & 0 deletions docs/supported-databases.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
| Clickhouse | 💛 | `clickhouse://<username>:<password>@<hostname>:9000/<database>` |
| Vertica | 💛 | `vertica://<username>:<password>@<hostname>:5433/<database>` |
| DuckDB | 💛 | |
| MsSQL | ⏳ | `pymssql://<user>:<password>@<host>:<port>/<database>` |
| ElasticSearch | 📝 | |
| Planetscale | 📝 | |
| Pinot | 📝 | |
Expand Down
320 changes: 224 additions & 96 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ textual = {version=">=0.9.1", optional=true}
textual-select = {version="*", optional=true}
pygments = {version=">=2.13.0", optional=true}
prompt-toolkit = {version=">=3.0.36", optional=true}
pymssql = {version=">=2.3.2", optional=true}

[tool.poetry.dev-dependencies]
parameterized = "*"
Expand All @@ -56,6 +57,7 @@ trino = ">=0.314.0"
presto-python-client = "*"
clickhouse-driver = "*"
vertica-python = "*"
pymssql = "*"

[tool.poetry.extras]
mysql = ["mysql-connector-python"]
Expand All @@ -69,6 +71,7 @@ clickhouse = ["clickhouse-driver"]
vertica = ["vertica-python"]
duckdb = ["duckdb"]
tui = ["textual", "textual-select", "pygments", "prompt-toolkit"]
mssql = ["pymssql"]

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
1 change: 1 addition & 0 deletions sqeleton/databases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@
from .clickhouse import Clickhouse
from .vertica import Vertica
from .duckdb import DuckDB
from .mssql import MsSQL

connect = Connect()
2 changes: 2 additions & 0 deletions sqeleton/databases/_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .clickhouse import Clickhouse
from .vertica import Vertica
from .duckdb import DuckDB
from .mssql import MsSQL


@dataclass
Expand Down Expand Up @@ -87,6 +88,7 @@ def match_path(self, dsn):
"trino": Trino,
"clickhouse": Clickhouse,
"vertica": Vertica,
"pymssql": MsSQL,
}


Expand Down
230 changes: 211 additions & 19 deletions sqeleton/databases/mssql.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,217 @@
# class MsSQL(ThreadedDatabase):
# "AKA sql-server"
from typing import List
from datetime import datetime
from ..abcs.database_types import (
DbPath,
Timestamp,
TimestampTZ,
Float,
Decimal,
Integer,
TemporalType,
Text,
FractionalType,
Boolean,
Date,
)
from typing import Dict
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema
from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, Mixin_Schema
from ..abcs import Compilable
from ..queries import this, table, Select, SKIP
from ..queries.ast_classes import ForeignKey, TablePath
from .base import TIMESTAMP_PRECISION_POS, Mixin_RandomSample

# def __init__(self, host, port, user, password, *, database, thread_count, **kw):
# args = dict(server=host, port=port, database=database, user=user, password=password, **kw)
# self._args = {k: v for k, v in args.items() if v is not None}
SESSION_TIME_ZONE = None # Changed by the tests

# super().__init__(thread_count=thread_count)

# def create_connection(self):
# mssql = import_mssql()
# try:
# return mssql.connect(**self._args)
# except mssql.Error as e:
# raise ConnectError(*e.args) from e
@import_helper("mssql")
def import_mssql():
    """Import the ``pymssql`` driver at call time and return the module.

    The import lives inside the function, so loading this file does not
    itself require ``pymssql`` to be installed.
    """
    import pymssql as driver

    return driver

# def md5_as_int(self, s: str) -> str:
# return f"CONVERT(decimal(38,0), CONVERT(bigint, HashBytes('MD5', {s}), 2))"
# # return f"CONVERT(bigint, (CHECKSUM({s})))"

# def to_string(self, s: str):
# return f"CONVERT(varchar, {s})"
class Mixin_MD5(AbstractMixin_MD5):
    """MD5 checksum support for the MsSQL dialect."""

    def md5_as_int(self, s: str) -> str:
        """Return a T-SQL expression turning the MD5 of ``s`` into a number.

        The generated SQL applies ``HashBytes('MD5', ...)`` to ``s``,
        converts the result to ``bigint`` and then to ``decimal(38,0)``.

        Args:
            s: SQL expression (already compiled to text) to hash.
        """
        md5_bytes = f"HashBytes('MD5', {s})"
        as_bigint = f"CONVERT(bigint, {md5_bytes}, 2)"
        return f"CONVERT(decimal(38,0), {as_bigint})"

class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):
    """Build T-SQL expressions that render column values as text."""

    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Return SQL rendering a timestamp expression as fixed-width text.

        The value is taken ``AT TIME ZONE 'UTC'`` and converted to
        ``varchar(26)`` with convert style 25. ``coltype.precision`` spaces
        are appended, and the leftmost ``TIMESTAMP_PRECISION_POS + 6``
        characters are kept.
        """
        as_text = f"convert(varchar(26), {value} AT TIME ZONE 'UTC', 25)"
        padding = f"REPLICATE(' ', {coltype.precision})"
        width = TIMESTAMP_PRECISION_POS + 6
        return f"LEFT({as_text} + {padding}, {width})"

    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        """Return SQL rendering a number as text via ``decimal(38, precision)``."""
        as_decimal = f"convert(decimal(38, {coltype.precision}), {value})"
        return self.to_string(f"convert(varchar, {as_decimal})")

    def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
        """Return SQL rendering a boolean expression as text."""
        as_varchar = f"convert(varchar, {value})"
        return self.to_string(as_varchar)

class Mixin_Schema(AbstractMixin_Schema):
    """Schema-introspection queries for the MsSQL dialect."""

    def table_information(self) -> TablePath:
        """Return the table path for ``information_schema.tables``."""
        return table("information_schema", "tables")

    def list_tables(self, table_schema: str, like: Compilable = None) -> Select:
        """Return a query selecting the names of base tables.

        Args:
            table_schema: Schema to restrict to; ``None`` skips the filter.
            like: Optional LIKE pattern for the table name; ``None`` skips
                the filter.
        """
        schema_filter = SKIP if table_schema is None else this.table_schema == table_schema
        name_filter = SKIP if like is None else this.table_name.like(like)
        base_tables_only = this.table_type == "BASE TABLE"

        matching = self.table_information().where(schema_filter, name_filter, base_tables_only)
        return matching.select(this.table_name)


class MsSQLDialect(BaseDialect, Mixin_Schema):
    """SQL dialect settings for Microsoft SQL Server (T-SQL)."""

    name = "MsSQL"
    ROUNDS_ON_PREC_LOSS = True
    SUPPORTS_PRIMARY_KEY = True
    SUPPORTS_INDEXES = True
    MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample}
    AT_TIMEZONE = False

    # Maps SQL Server type names to the column-type classes imported from
    # ..abcs.database_types.
    # NOTE(review): presumably keyed by INFORMATION_SCHEMA.COLUMNS.DATA_TYPE,
    # which reports names in lowercase -- confirm against the base lookup.
    TYPE_CLASSES = {
        # Numbers
        "tinyint": Integer,
        "smallint": Integer,
        "int": Integer,
        "bigint": Integer,
        "decimal": Decimal,
        "numeric": Decimal,
        "money": Decimal,
        "smallmoney": Decimal,
        "float": Float,
        "real": Float,
        # Timestamps
        "date": Date,
        "time": Timestamp,
        "datetime2": Timestamp,
        "datetimeoffset": TimestampTZ,
        "datetime": Timestamp,
        "smalldatetime": Date,
        # Text
        "char": Text,
        "varchar": Text,
        "text": Text,
        "nchar": Text,
        "nvarchar": Text,
        "ntext": Text,
        # Boolean
        # Lowercase like every other key, so a reported "bit" matches; the
        # uppercase spelling is kept for backward compatibility.
        "bit": Boolean,
        "BIT": Boolean,
    }

# T-SQL has EXPLAIN for Azure SQL Data Warehouse, but not (yet) for the
# regular SQL Server RDBMS, so the plan is requested through SHOWPLAN_ALL.
def explain_as_text(self, query: str) -> str:
    """Return a T-SQL script that shows the execution plan of ``query``.

    The script turns SHOWPLAN_ALL on, runs ``query``, then turns
    SHOWPLAN_ALL back off so later statements execute normally.

    NOTE(review): ``GO`` is a batch separator understood by client tools
    (sqlcmd/SSMS), not a T-SQL statement -- confirm this script works when
    sent through the driver as-is.

    Args:
        query: The SQL text to explain.
    """
    script_lines = [
        "SET SHOWPLAN_ALL ON;",
        "GO",
        query,
        "GO",
        # Was "ON" again, which left plan-only mode enabled after the query.
        "SET SHOWPLAN_ALL OFF;",
        "GO",
    ]
    return "\n".join(script_lines)

def quote(self, s: str):
    """Wrap identifier ``s`` in double quotes."""
    return '"{}"'.format(s)

def to_string(self, s: str):
    """Return a T-SQL expression converting ``s`` to ``VARCHAR(MAX)``."""
    target_type = "VARCHAR(MAX)"
    return f"CONVERT({target_type}, {s})"

def concat(self, items: List[str]) -> str:
    """Return a ``CONCAT(...)`` call over the given SQL expressions."""
    return "CONCAT(" + ", ".join(items) + ")"

def _convert_db_precision_to_digits(self, p: int) -> int:
    """Return the parent class's digit count for precision ``p``, minus 2."""
    base_digits = super()._convert_db_precision_to_digits(p)
    return base_digits - 2

def set_timezone_to_utc(self) -> str:
    """Return the SQL that switches the session to UTC: none for MsSQL.

    An empty string is returned, i.e. no statement is produced. Per the
    original author's note, MsSQL stores datetimes as UTC by default and
    offers no way to enforce a time zone for a session.

    NOTE(review): confirm the "stored as UTC" assumption holds for data
    written by other clients.
    """
    no_statement = ""
    return no_statement
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it always UTC? If so, worth documenting.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is; I have added a code comment. Is there anywhere else this should be documented?


def current_timestamp(self) -> str:
    """Return the T-SQL expression used for the current timestamp."""
    utc_now_expr = "SYSUTCDATETIME()"
    return utc_now_expr

def type_repr(self, t) -> str:
    """Return the SQL Server type name used to declare a column of type ``t``.

    Args:
        t: A ``TimestampTZ`` instance, a ``ForeignKey`` (resolved through
            its ``type`` attribute), a Python type, or anything the parent
            class's ``type_repr`` accepts.

    Returns:
        The type name; anything not handled here is delegated to the
        parent implementation.
    """
    if isinstance(t, TimestampTZ):
        return "datetimeoffset"
    if isinstance(t, ForeignKey):
        return self.type_repr(t.type)
    if isinstance(t, type):
        mssql_name = {
            str: "NVARCHAR(MAX)",
            bool: "BIT",
            datetime: "datetime2",
        }.get(t)
        if mssql_name is not None:
            return mssql_name

    # The original called the parent here without `return`, so every value
    # that reached this point produced None.
    return super().type_repr(t)

class MsSQL(ThreadedDatabase):
    """Microsoft SQL Server (AKA sql-server), connected through ``pymssql``."""

    dialect = MsSQLDialect()
    SUPPORTS_ALPHANUMS = False
    SUPPORTS_UNIQUE_CONSTAINT = True
    CONNECT_URI_HELP = "pymssql://<user>:<password>@<host>:<port>/<database>"
    CONNECT_URI_PARAMS = ["database"]

    def __init__(self, host, port, user, password, *, database, thread_count, **kw):
        """Store the connection arguments and start the thread pool.

        Arguments whose value is ``None`` are dropped, so they are not
        passed to the driver at all. Extra keyword arguments are forwarded
        to ``pymssql.connect`` unchanged.
        """
        args = dict(
            server=host,
            port=port,
            database=database,
            user=user,
            password=password,
            # The dialect quotes identifiers with double quotes.
            conn_properties=["SET QUOTED_IDENTIFIER ON;"],
            **kw,
        )
        self._args = {k: v for k, v in args.items() if v is not None}

        super().__init__(thread_count=thread_count)

    def create_connection(self):
        """Open and return a new ``pymssql`` connection.

        Raises:
            ConnectError: If the driver reports an error while connecting.
        """
        self.mssql = import_mssql()
        try:
            return self.mssql.connect(**self._args)
        except self.mssql.Error as e:
            raise ConnectError(*e.args) from e

    def _normalize_table_path(self, path: DbPath) -> DbPath:
        """Expand ``path`` to a ``(database, schema, table)`` triple.

        Raises:
            ValueError: If ``path`` does not have 1, 2 or 3 parts.
        """
        if len(path) == 1:
            return None, self.default_schema, path[0]
        if len(path) == 2:
            return None, path[0], path[1]
        if len(path) == 3:
            return path

        raise ValueError(
            f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table"
        )

    @staticmethod
    def _escape_literal(text: str) -> str:
        """Double single quotes so ``text`` can sit inside a T-SQL string literal."""
        return text.replace("'", "''")

    def select_table_schema(self, path: DbPath) -> str:
        """Provide SQL for selecting the table schema.

        The selected columns are (column_name, data_type,
        datetime_precision, numeric_precision, numeric_scale). The schema
        filter is omitted when the normalized path has no schema.
        """
        database, schema, name = self._normalize_table_path(path)

        info_schema_path = ["information_schema", "COLUMNS"]
        if database:
            info_schema_path.insert(0, database)

        conditions = [f"table_name = '{self._escape_literal(name)}'"]
        if schema is not None:
            conditions.append(f"table_schema = '{self._escape_literal(schema)}'")

        return (
            "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
            f"FROM {'.'.join(info_schema_path)} "
            f"WHERE {' AND '.join(conditions)}"
        )

    def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
        """Query the column metadata of the table at ``path``.

        Returns:
            A dict mapping each column name to its full metadata row.

        Raises:
            RuntimeError: If the query returns no rows.
        """
        rows = self.query(self.select_table_schema(path), list)
        if not rows:
            raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")

        d = {r[0]: r for r in rows}
        # Each column name must appear exactly once in the result.
        assert len(d) == len(rows)
        return d
2 changes: 2 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
# vertica uri provided for docker - "vertica://vertica:Password1@localhost:5433/vertica"
TEST_VERTICA_CONN_STRING: str = os.environ.get("VERTICA_URI")
TEST_DUCKDB_CONN_STRING: str = "duckdb://main:@:memory:"
TEST_MSSQL_CONN_STRING: str = "pymssql://sa:<password!CAP>@localhost:8020/master"


DEFAULT_N_SAMPLES = 50
Expand Down Expand Up @@ -74,6 +75,7 @@ def get_git_revision_short_hash() -> str:
db.Clickhouse: TEST_CLICKHOUSE_CONN_STRING,
db.Vertica: TEST_VERTICA_CONN_STRING,
db.DuckDB: TEST_DUCKDB_CONN_STRING,
db.MsSQL: TEST_MSSQL_CONN_STRING
}

_database_instances = {}
Expand Down
7 changes: 4 additions & 3 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
dbs.Presto,
dbs.Trino,
dbs.Vertica,
dbs.MsSQL,
}

test_each_database: Callable = make_test_each_database_in_list(TEST_DATABASES)
Expand Down Expand Up @@ -163,13 +164,13 @@ def test_foreign_key(self):
@test_each_database
class TestThreePartIds(unittest.TestCase):
def test_three_part_support(self):
if self.db_cls not in [dbs.PostgreSQL, dbs.Redshift, dbs.Snowflake, dbs.DuckDB]:
if self.db_cls not in [dbs.PostgreSQL, dbs.Redshift, dbs.Snowflake, dbs.DuckDB, dbs.MsSQL]:
self.skipTest("Limited support for 3 part ids")

table_name = "tbl_" + random_table_suffix()
db = get_conn(self.db_cls)
db_res = db.query("SELECT CURRENT_DATABASE()")
schema_res = db.query("SELECT CURRENT_SCHEMA()")
db_res = db.query("SELECT CURRENT_DATABASE()") if self.db_cls != dbs.MsSQL else db.query("SELECT DB_NAME() AS [Current Database]")
schema_res = db.query("SELECT CURRENT_SCHEMA()") if self.db_cls != dbs.MsSQL else db.query("SELECT SCHEMA_NAME()")
db_name = db_res.rows[0][0]
schema_name = schema_res.rows[0][0]

Expand Down
Loading