Skip to content

WIP Support for INTERVAL datatype to SQLAlchemy #184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tests/unit/sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pytest
from sqlalchemy.sql.sqltypes import ARRAY

from trino.sqlalchemy.datatype import MAP, ROW, SQLType
from trino.sqlalchemy.datatype import INTERVAL, MAP, ROW, SQLType


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -40,6 +40,11 @@ def _assert_sqltype(this: SQLType, that: SQLType):
for (this_attr, that_attr) in zip(this.attr_types, that.attr_types):
assert this_attr[0] == that_attr[0]
_assert_sqltype(this_attr[1], that_attr[1])

elif isinstance(this, INTERVAL):
assert this.precision == that.precision
assert this.fields == that.fields

else:
assert str(this) == str(that)

Expand Down
4 changes: 3 additions & 1 deletion tests/unit/sqlalchemy/test_datatype_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from sqlalchemy.sql.type_api import TypeEngine

from trino.sqlalchemy import datatype
from trino.sqlalchemy.datatype import MAP, ROW
from trino.sqlalchemy.datatype import INTERVAL, MAP, ROW


@pytest.mark.parametrize(
Expand Down Expand Up @@ -179,6 +179,8 @@ def test_parse_row(type_str: str, sql_type: ARRAY, assert_sqltype):
"time with time zone": TIME(timezone=True),
"timestamp": TIMESTAMP(),
"timestamp with time zone": TIMESTAMP(timezone=True),
"interval '6' month": INTERVAL(precision=6, fields="month"),
"interval '14' day": INTERVAL(precision=14, fields="day"),
}


Expand Down
12 changes: 12 additions & 0 deletions trino/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,18 @@ def visit_BLOB(self, type_, **kw):
def visit_DATETIME(self, type_, **kw):
return self.visit_TIMESTAMP(type_, **kw)

def visit_INTERVAL(self, type_, **kw):
text = "INTERVAL"
if type_.adapt_datatype:
if type_.fields in ("month", "year"):
return "INTERVAL YEAR TO MONTH"
elif type_.fields in ("second", "minute", "hour", "day"):
return "INTERVAL DAY TO SECOND"
if type_.precision is not None:
text += " '%d'" % type_.precision
if type_.fields is not None:
text += " " + type_.fields
return text

class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = RESERVED_WORDS
81 changes: 65 additions & 16 deletions trino/sqlalchemy/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import datetime as dt
from typing import Iterator, List, Optional, Tuple, Type, Union

from sqlalchemy import util
Expand All @@ -23,6 +24,44 @@ class DOUBLE(sqltypes.Float):
__visit_name__ = "DOUBLE"


class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
__visit_name__ = "INTERVAL"
native = True

def __init__(self, precision=None, fields=None, adapt_datatype: bool = False):
"""Construct an INTERVAL.

:param precision: integer precision value
:param fields: string fields specifier. allows storage of fields
to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"SECOND"``,
etc.
:param adapt_datatype: allows conversion from data type value to column data type
"""
self.precision = precision
self.fields = fields
self.adapt_datatype = adapt_datatype

def adapt_value_to_datatype(self):
return INTERVAL(
precision=self.precision,
fields=self.fields,
adapt_datatype=True)

@property
def _type_affinity(self):
return sqltypes.Interval

def as_generic(self, allow_nulltype=False):
return sqltypes.Interval(native=True, second_precision=self.precision)

@property
def python_type(self):
return dt.timedelta

def coerce_compared_value(self, op, value):
return self


class MAP(TypeEngine):
__visit_name__ = "MAP"

Expand Down Expand Up @@ -79,9 +118,7 @@ def python_type(self):
"date": sqltypes.DATE,
"time": sqltypes.TIME,
"timestamp": sqltypes.TIMESTAMP,
# 'interval year to month':
# 'interval day to second':
#
"interval": INTERVAL,
# === Structural ===
# 'array': ARRAY,
# 'map': MAP
Expand All @@ -108,13 +145,13 @@ def unquote(string: str, quote: str = '"', escape: str = "\\") -> str:


def aware_split(
string: str,
delimiter: str = ",",
maxsplit: int = -1,
quote: str = '"',
escaped_quote: str = r"\"",
open_bracket: str = "(",
close_bracket: str = ")",
string: str,
delimiter: str = ",",
maxsplit: int = -1,
quote: str = '"',
escaped_quote: str = r"\"",
open_bracket: str = "(",
close_bracket: str = ")",
) -> Iterator[str]:
"""
A split function that is aware of quotes and brackets/parentheses.
Expand Down Expand Up @@ -158,29 +195,30 @@ def aware_split(

def parse_sqltype(type_str: str) -> TypeEngine:
type_str = type_str.strip().lower()
match = re.match(r"^(?P<type>\w+)\s*(?:\((?P<options>.*)\))?", type_str)
match = re.match(r"^(?P<type>\w+)\s*(?:[\(|'](?P<precision>.*)[\)|'])?(?:[ ](?P<fields>.+))?", type_str)
if not match:
util.warn(f"Could not parse type name '{type_str}'")
return sqltypes.NULLTYPE
type_name = match.group("type")
type_opts = match.group("options")
type_precision = match.group("precision")
type_fields = match.group("fields")

if type_name == "array":
item_type = parse_sqltype(type_opts)
item_type = parse_sqltype(type_precision)
if isinstance(item_type, sqltypes.ARRAY):
# Multi-dimensions array is normalized in SQLAlchemy, e.g:
# `ARRAY(ARRAY(INT))` in Trino SQL will become `ARRAY(INT(), dimensions=2)` in SQLAlchemy
dimensions = (item_type.dimensions or 1) + 1
return sqltypes.ARRAY(item_type.item_type, dimensions=dimensions)
return sqltypes.ARRAY(item_type)
elif type_name == "map":
key_type_str, value_type_str = aware_split(type_opts)
key_type_str, value_type_str = aware_split(type_precision)
key_type = parse_sqltype(key_type_str)
value_type = parse_sqltype(value_type_str)
return MAP(key_type, value_type)
elif type_name == "row":
attr_types: List[Tuple[Optional[str], SQLType]] = []
for attr in aware_split(type_opts):
for attr in aware_split(type_precision):
attr_name, attr_type_str = aware_split(attr.strip(), delimiter=" ", maxsplit=1)
attr_name = unquote(attr_name)
attr_type = parse_sqltype(attr_type_str)
Expand All @@ -191,7 +229,18 @@ def parse_sqltype(type_str: str) -> TypeEngine:
util.warn(f"Did not recognize type '{type_name}'")
return sqltypes.NULLTYPE
type_class = _type_map[type_name]
type_args = [int(o.strip()) for o in type_opts.split(",")] if type_opts else []
type_args = [int(o.strip()) for o in type_precision.split(",")] if type_precision else []

if type_name == "interval":
if type_fields not in ("second", "minute", "hour", "day", "month", "year"):
util.warn(f"Did not recognize field type '{type_fields}'")
return sqltypes.NULLTYPE
type_kwargs: Dict[str, Any] = dict(
precision=int(type_precision),
fields=type_fields
)
return type_class(**type_kwargs)

if type_name in ("time", "timestamp"):
type_kwargs = dict(timezone=type_str.endswith("with time zone"))
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
Expand Down