Skip to content

Commit

Permalink
feat(python): additional cursor-level dtype inference/refinement for …
Browse files Browse the repository at this point in the history
…`read_database`
  • Loading branch information
alexander-beedie committed Mar 18, 2024
1 parent f8ade71 commit 66690b9
Show file tree
Hide file tree
Showing 3 changed files with 362 additions and 21 deletions.
191 changes: 186 additions & 5 deletions py-polars/polars/datatypes/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,191 @@ def _map_py_type_to_dtype(
dtype if nested is None else dtype(_map_py_type_to_dtype(nested)) # type: ignore[operator]
)

msg = "invalid type"
msg = f"unrecognised Python type: {python_dtype!r}"
raise TypeError(msg)


def _timeunit_from_precision(precision: int | str | None) -> str | None:
"""Return `time_unit` from integer precision value."""
from math import ceil

if not precision:
return None
elif isinstance(precision, str):
if precision.isdigit():
precision = int(precision)
elif (precision := precision.lower()) in ("s", "ms", "us", "ns"):
return "ms" if precision == "s" else precision
try:
n = min(max(3, int(ceil(precision / 3)) * 3), 9) # type: ignore[operator]
return {3: "ms", 6: "us", 9: "ns"}.get(n)
except TypeError:
return None


def _infer_dtype_from_database_typename(
    value: str,
    *,
    raise_unmatched: bool = True,
) -> PolarsDataType | None:
    """
    Attempt to infer a Polars dtype from a database cursor `type_code` string value.

    Parameters
    ----------
    value
        A database type name, eg: "VARCHAR(64)", "TIMESTAMP(6)", "ARRAY<INT>".
    raise_unmatched
        If True (the default), raise ValueError when no dtype can be inferred;
        otherwise return None.
    """
    dtype: PolarsDataType | None = None

    # normalise string name (eg: 'IntegerType' -> 'INTEGER')
    original_value = value
    value = value.upper().removesuffix("TYPE")

    # extract optional type modifier (eg: 'VARCHAR(64)' -> '64')
    if re.search(r"\([\w,: ]+\)$", value):
        modifier = value[value.find("(") + 1 : -1]
        value = value.split("(")[0]
    elif (
        not value.startswith(("<", ">")) and re.search(r"\[[\w,\]\[: ]+]$", value)
    ) or value.endswith(("[S]", "[MS]", "[US]", "[NS]")):
        modifier = value[value.find("[") + 1 : -1]
        value = value.split("[")[0]
    else:
        modifier = ""

    # array dtypes
    array_aliases = ("ARRAY", "LIST", "[]")
    if value.endswith(array_aliases) or value.startswith(array_aliases):
        for a in array_aliases:
            value = value.replace(a, "", 1) if value else ""

        nested: PolarsDataType | None = None
        if not value:
            # fix: a bare alias (eg: "LIST", "ARRAY") with no modifier left
            # `value` empty and previously raised IndexError on `value[0]`
            # below; now we fall through to the unmatched path instead
            if modifier:
                nested = _infer_dtype_from_database_typename(
                    value=modifier,
                    raise_unmatched=False,
                )
        else:
            # infer the inner type from "<...>" syntax or remaining words,
            # falling back to the modifier (eg: "ARRAY[INT]")
            if inner_value := _infer_dtype_from_database_typename(
                value[1:-1]
                if (value[0], value[-1]) == ("<", ">")
                else re.sub(r"\W", "", re.sub(r"\WOF\W", "", value)),
                raise_unmatched=False,
            ):
                nested = inner_value
            elif modifier:
                nested = _infer_dtype_from_database_typename(
                    value=modifier,
                    raise_unmatched=False,
                )
        if nested:
            dtype = List(nested)

    # float dtypes
    elif value.startswith("FLOAT") or ("DOUBLE" in value) or (value == "REAL"):
        dtype = (
            Float32
            if value == "FLOAT4"
            or (value.endswith(("16", "32")) or (modifier in ("16", "32")))
            else Float64
        )

    # integer dtypes
    elif ("INTERVAL" not in value) and (
        value.startswith(("INT", "UINT", "UNSIGNED"))
        or value.endswith(("INT", "SERIAL"))
        or ("INTEGER" in value)
        or value == "ROWID"
    ):
        sz: Any
        if "LARGE" in value or value.startswith("BIG") or value == "INT8":
            sz = 64
        elif "MEDIUM" in value or value in ("INT4", "SERIAL"):
            sz = 32
        elif "SMALL" in value or value == "INT2":
            sz = 16
        elif "TINY" in value:
            sz = 8
        else:
            sz = None

        # fall back to an explicit size modifier (eg: "INT(16)")
        sz = modifier if (not sz and modifier) else sz
        if not isinstance(sz, int):
            sz = int(sz) if isinstance(sz, str) and sz.isdigit() else None
        if (
            ("U" in value and "MEDIUM" not in value)
            or ("UNSIGNED" in value)
            or value == "ROWID"
        ):
            dtype = _integer_dtype_from_nbits(sz, unsigned=True, default=UInt64)
        else:
            dtype = _integer_dtype_from_nbits(sz, unsigned=False, default=Int64)

    # decimal dtypes
    elif (is_dec := ("DECIMAL" in value)) or ("NUMERIC" in value):
        if "," in modifier:
            prec, scale = modifier.split(",")
            dtype = Decimal(int(prec), int(scale))
        else:
            dtype = Decimal if is_dec else Float64

    # string dtypes
    elif (
        any(tp in value for tp in ("VARCHAR", "STRING", "TEXT", "UNICODE"))
        or value.startswith(("STR", "CHAR", "NCHAR", "UTF"))
        or value.endswith(("_UTF8", "_UTF16", "_UTF32"))
    ):
        dtype = String

    # binary dtypes
    elif value in ("BYTEA", "BYTES", "BLOB", "CLOB", "BINARY"):
        dtype = Binary

    # boolean dtypes
    elif value.startswith("BOOL"):
        dtype = Boolean

    # temporal dtypes
    elif value.startswith(("DATETIME", "TIMESTAMP")) and not (value.endswith("[D]")):
        if any((tz in value.replace(" ", "")) for tz in ("TZ", "TIMEZONE")):
            if "WITHOUT" not in value:
                return None  # there's a timezone, but we don't know what it is
        unit = _timeunit_from_precision(modifier) if modifier else "us"
        dtype = Datetime(time_unit=(unit or "us"))  # type: ignore[arg-type]

    elif re.sub(r"\d", "", value) in ("INTERVAL", "TIMEDELTA"):
        dtype = Duration

    elif value in ("DATE", "DATE32", "DATE64"):
        dtype = Date

    elif value in ("TIME", "TIME32", "TIME64"):
        dtype = Time

    if not dtype and raise_unmatched:
        msg = f"cannot infer dtype from {original_value!r} string value"
        raise ValueError(msg)

    return dtype


@functools.lru_cache(8)
def _integer_dtype_from_nbits(
    bits: int,
    *,
    unsigned: bool,
    default: PolarsDataType | None = None,
) -> PolarsDataType | None:
    """
    Return the Polars integer dtype matching the given bit-width and signedness.

    Falls back to `default` (which may be None) when `bits` is not one of
    8, 16, 32, or 64.
    """
    lookup: dict[int, PolarsDataType] = (
        {8: UInt8, 16: UInt16, 32: UInt32, 64: UInt64}
        if unsigned
        else {8: Int8, 16: Int16, 32: Int32, 64: Int64}
    )
    matched = lookup.get(bits)
    return default if matched is None else matched


def is_polars_dtype(dtype: Any, *, include_unknown: bool = False) -> bool:
"""Indicate whether the given input is a Polars dtype, or dtype specialization."""
try:
Expand Down Expand Up @@ -415,10 +596,10 @@ def py_type_to_dtype(
try:
return _map_py_type_to_dtype(data_type)
except (KeyError, TypeError): # pragma: no cover
if not raise_unmatched:
return None
msg = f"cannot infer dtype from {data_type!r} (type: {type(data_type).__name__!r})"
raise ValueError(msg) from None
if raise_unmatched:
msg = f"cannot infer dtype from {data_type!r} (type: {type(data_type).__name__!r})"
raise ValueError(msg) from None
return None


def py_type_to_arrow_type(dtype: PythonDataType) -> pa.lib.DataType:
Expand Down
109 changes: 93 additions & 16 deletions py-polars/polars/io/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,26 @@

import re
import sys
from contextlib import suppress
from importlib import import_module
from inspect import Parameter, signature
from inspect import Parameter, isclass, signature
from typing import TYPE_CHECKING, Any, Iterable, Literal, Sequence, TypedDict, overload

from polars._utils.deprecation import issue_deprecation_warning
from polars.convert import from_arrow
from polars.datatypes import N_INFER_DEFAULT
from polars.datatypes import (
INTEGER_DTYPES,
N_INFER_DEFAULT,
UNSIGNED_INTEGER_DTYPES,
Decimal,
Float32,
Float64,
)
from polars.datatypes.convert import (
_infer_dtype_from_database_typename,
_integer_dtype_from_nbits,
_map_py_type_to_dtype,
)
from polars.exceptions import InvalidOperationError, UnsuitableSQLError

if TYPE_CHECKING:
Expand All @@ -26,6 +39,7 @@
from typing_extensions import Self

from polars import DataFrame
from polars.datatypes import PolarsDataType
from polars.type_aliases import ConnectionOrCursor, Cursor, DbReadEngine, SchemaDict

try:
Expand Down Expand Up @@ -295,17 +309,19 @@ def _from_rows(
if hasattr(self.result, "fetchall"):
if self.driver_name == "sqlalchemy":
if hasattr(self.result, "cursor"):
cursor_desc = {d[0]: d[1] for d in self.result.cursor.description}
cursor_desc = {d[0]: d[1:] for d in self.result.cursor.description}
elif hasattr(self.result, "_metadata"):
cursor_desc = {k: None for k in self.result._metadata.keys}
else:
msg = f"Unable to determine metadata from query result; {self.result!r}"
raise ValueError(msg)
else:
cursor_desc = {d[0]: d[1] for d in self.result.description}
cursor_desc = {d[0]: d[1:] for d in self.result.description}

# TODO: refine types based on the cursor description's type_code,
# if/where available? (for now, we just read the column names)
schema_overrides = self._inject_type_overrides(
description=cursor_desc,
schema_overrides=(schema_overrides or {}),
)
result_columns = list(cursor_desc)
frames = (
DataFrame(
Expand All @@ -324,17 +340,78 @@ def _from_rows(
return frames if iter_batches else next(frames) # type: ignore[arg-type]
return None

def _normalise_cursor(self, conn: ConnectionOrCursor) -> Cursor:
def _inject_type_overrides(
self,
description: dict[str, Any],
schema_overrides: SchemaDict,
) -> SchemaDict:
"""Attempt basic dtype inference from a cursor description."""
# note: this is limited; the `type_code` property may contain almost anything,
# from strings or python types to driver-specific codes, classes, enums, etc.
# currently we only do additional inference from string/python type values.
# (further refinement requires per-driver module knowledge and lookups).

dtype: PolarsDataType | None = None
for nm, desc in description.items():
if desc is None:
continue
elif not schema_overrides.get(nm, None):
type_code, _disp_size, internal_size, prec, scale, _null_ok = desc

if isclass(type_code):
# python types, eg: int, float, str, etc
with suppress(TypeError):
dtype = _map_py_type_to_dtype(type_code) # type: ignore[arg-type]

elif isinstance(type_code, str):
# database/sql type names, eg: "VARCHAR", "NUMERIC", "BLOB", etc
dtype = _infer_dtype_from_database_typename(
value=type_code,
raise_unmatched=False,
)

if dtype is not None and internal_size and internal_size > 0:
# refine numeric dtypes from additional cursor information
if internal_size == 4 and dtype == Float64:
dtype = Float32
elif dtype in INTEGER_DTYPES and internal_size in (2, 4, 8):
bits = internal_size * 8
dtype = _integer_dtype_from_nbits(
bits,
unsigned=(dtype in UNSIGNED_INTEGER_DTYPES),
default=dtype,
)
elif (
dtype == Decimal
and isinstance(prec, int)
and isinstance(scale, int)
and (prec + scale) <= 38
):
dtype = Decimal(prec, scale)

if dtype is not None:
schema_overrides[nm] = dtype # type: ignore[index]

return schema_overrides

def _normalise_cursor(self, conn: Any) -> Cursor:
"""Normalise a connection object such that we have the query executor."""
if self.driver_name == "sqlalchemy" and type(conn).__name__ == "Engine":
self.can_close_cursor = True
if conn.driver == "databricks-sql-python": # type: ignore[union-attr]
# take advantage of the raw connection to get arrow integration
self.driver_name = "databricks"
return conn.raw_connection().cursor() # type: ignore[union-attr, return-value]
if self.driver_name == "sqlalchemy":
self.can_close_cursor = (conn_type := type(conn).__name__) == "Engine"
if conn_type == "Session":
return conn
else:
# sqlalchemy engine; direct use is deprecated, so prefer the connection
return conn.connect() # type: ignore[union-attr, return-value]
# where possible, use the raw connection to access arrow integration
if conn.engine.driver == "databricks-sql-python":
self.driver_name = "databricks"
return conn.engine.raw_connection().cursor()
elif conn.engine.driver == "duckdb_engine":
self.driver_name = "duckdb"
return conn.engine.raw_connection().driver_connection.c
elif conn_type == "Engine":
return conn.connect()
else:
return conn

elif hasattr(conn, "cursor"):
# connection has a dedicated cursor; prefer over direct execute
Expand All @@ -344,7 +421,7 @@ def _normalise_cursor(self, conn: ConnectionOrCursor) -> Cursor:

elif hasattr(conn, "execute"):
# can execute directly (given cursor, sqlalchemy connection, etc)
return conn # type: ignore[return-value]
return conn

msg = f"Unrecognised connection {conn!r}; unable to find 'execute' method"
raise TypeError(msg)
Expand Down
Loading

0 comments on commit 66690b9

Please sign in to comment.