Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): improved dtype inference/refinement for read_database results #15126

Merged
merged 1 commit into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 186 additions & 5 deletions py-polars/polars/datatypes/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,191 @@ def _map_py_type_to_dtype(
dtype if nested is None else dtype(_map_py_type_to_dtype(nested)) # type: ignore[operator]
)

msg = "invalid type"
msg = f"unrecognised Python type: {python_dtype!r}"
raise TypeError(msg)


def _timeunit_from_precision(precision: int | str | None) -> str | None:
"""Return `time_unit` from integer precision value."""
from math import ceil

if not precision:
return None
elif isinstance(precision, str):
if precision.isdigit():
precision = int(precision)
elif (precision := precision.lower()) in ("s", "ms", "us", "ns"):
return "ms" if precision == "s" else precision
try:
n = min(max(3, int(ceil(precision / 3)) * 3), 9) # type: ignore[operator]
return {3: "ms", 6: "us", 9: "ns"}.get(n)
except TypeError:
return None


def _infer_dtype_from_database_typename(
    value: str,
    *,
    raise_unmatched: bool = True,
) -> PolarsDataType | None:
    """
    Attempt to infer Polars dtype from database cursor `type_code` string value.

    Parameters
    ----------
    value
        A database/SQL type name string, eg: "VARCHAR(64)", "NUMERIC(10,2)",
        "TIMESTAMP(6)", "ARRAY<INT>", "INT8[]".
    raise_unmatched
        If True (the default), raise ValueError when no dtype can be inferred;
        otherwise return None for unmatched names.

    Returns
    -------
    PolarsDataType | None
        The inferred dtype, or None when unmatched and `raise_unmatched` is False.
    """
    dtype: PolarsDataType | None = None

    # normalise string name/case (eg: 'IntegerType' -> 'INTEGER')
    original_value = value
    value = value.upper().replace("TYPE", "")

    # extract optional type modifier (eg: 'VARCHAR(64)' -> '64')
    if re.search(r"\([\w,: ]+\)$", value):
        # parenthesised modifier, eg: 'NUMERIC(10,2)' -> modifier '10,2'
        modifier = value[value.find("(") + 1 : -1]
        value = value.split("(")[0]
    elif (
        not value.startswith(("<", ">")) and re.search(r"\[[\w,\]\[: ]+]$", value)
    ) or value.endswith(("[S]", "[MS]", "[US]", "[NS]")):
        # bracketed modifier, eg: 'TIMESTAMP[US]' -> modifier 'US'
        # (skipped for '<'/'>'-delimited nested types, eg: 'ARRAY<...>')
        modifier = value[value.find("[") + 1 : -1]
        value = value.split("[")[0]
    else:
        modifier = ""

    # array dtypes
    array_aliases = ("ARRAY", "LIST", "[]")
    if value.endswith(array_aliases) or value.startswith(array_aliases):
        # strip one array wrapper, then recurse to infer the element dtype
        for a in array_aliases:
            value = value.replace(a, "", 1) if value else ""

        nested: PolarsDataType | None = None
        if not value and modifier:
            # no inner name remains; the element type lives in the modifier,
            # eg: 'ARRAY(INT)' -> infer from 'INT'
            nested = _infer_dtype_from_database_typename(
                value=modifier,
                raise_unmatched=False,
            )
        else:
            # inner name present: unwrap '<...>' if delimited, otherwise strip
            # an 'OF'-connective and non-word chars (eg: 'ARRAY OF INT')
            if inner_value := _infer_dtype_from_database_typename(
                value[1:-1]
                if (value[0], value[-1]) == ("<", ">")
                else re.sub(r"\W", "", re.sub(r"\WOF\W", "", value)),
                raise_unmatched=False,
            ):
                nested = inner_value
            elif modifier:
                nested = _infer_dtype_from_database_typename(
                    value=modifier,
                    raise_unmatched=False,
                )
        if nested:
            dtype = List(nested)

    # float dtypes
    elif value.startswith("FLOAT") or ("DOUBLE" in value) or (value == "REAL"):
        # 16-bit floats also map to Float32 (no Float16 dtype available)
        dtype = (
            Float32
            if value == "FLOAT4"
            or (value.endswith(("16", "32")) or (modifier in ("16", "32")))
            else Float64
        )

    # integer dtypes
    elif ("INTERVAL" not in value) and (
        value.startswith(("INT", "UINT", "UNSIGNED"))
        or value.endswith(("INT", "SERIAL"))
        or ("INTEGER" in value)
        or value == "ROWID"
    ):
        sz: Any
        # 'INT8'/'INT4'/'INT2' follow the byte-width naming convention here
        # (eg: as used by postgres), not bit-width — TODO: confirm for other drivers
        if "LARGE" in value or value.startswith("BIG") or value == "INT8":
            sz = 64
        elif "MEDIUM" in value or value in ("INT4", "SERIAL"):
            sz = 32
        elif "SMALL" in value or value == "INT2":
            sz = 16
        elif "TINY" in value:
            sz = 8
        else:
            sz = None

        # fall back to the numeric modifier (eg: 'INT(32)') for the bit width
        sz = modifier if (not sz and modifier) else sz
        if not isinstance(sz, int):
            sz = int(sz) if isinstance(sz, str) and sz.isdigit() else None
        # note: 'MEDIUM' itself contains a 'U', so it is excluded explicitly
        # from the simple "has a 'U'" unsigned check
        if (
            ("U" in value and "MEDIUM" not in value)
            or ("UNSIGNED" in value)
            or value == "ROWID"
        ):
            dtype = _integer_dtype_from_nbits(sz, unsigned=True, default=UInt64)
        else:
            dtype = _integer_dtype_from_nbits(sz, unsigned=False, default=Int64)

    # decimal dtypes
    elif (is_dec := ("DECIMAL" in value)) or ("NUMERIC" in value):
        if "," in modifier:
            # explicit precision/scale, eg: 'DECIMAL(10,2)'
            prec, scale = modifier.split(",")
            dtype = Decimal(int(prec), int(scale))
        else:
            # without precision/scale, NUMERIC degrades to Float64
            dtype = Decimal if is_dec else Float64

    # string dtypes
    elif (
        any(tp in value for tp in ("VARCHAR", "STRING", "TEXT", "UNICODE"))
        or value.startswith(("STR", "CHAR", "NCHAR", "UTF"))
        or value.endswith(("_UTF8", "_UTF16", "_UTF32"))
    ):
        dtype = String

    # binary dtypes
    elif value in ("BYTEA", "BYTES", "BLOB", "CLOB", "BINARY"):
        dtype = Binary

    # boolean dtypes
    elif value.startswith("BOOL"):
        dtype = Boolean

    # temporal dtypes
    elif value.startswith(("DATETIME", "TIMESTAMP")) and not (value.endswith("[D]")):
        if any((tz in value.replace(" ", "")) for tz in ("TZ", "TIMEZONE")):
            if "WITHOUT" not in value:
                return None  # there's a timezone, but we don't know what it is
        # modifier (if any) holds the fractional-second precision; default "us"
        unit = _timeunit_from_precision(modifier) if modifier else "us"
        dtype = Datetime(time_unit=(unit or "us"))  # type: ignore[arg-type]

    elif re.sub(r"\d", "", value) in ("INTERVAL", "TIMEDELTA"):
        dtype = Duration

    elif value in ("DATE", "DATE32", "DATE64"):
        dtype = Date

    elif value in ("TIME", "TIME32", "TIME64"):
        dtype = Time

    if not dtype and raise_unmatched:
        msg = f"cannot infer dtype from {original_value!r} string value"
        raise ValueError(msg)

    return dtype


@functools.lru_cache(8)
def _integer_dtype_from_nbits(
    bits: int,
    *,
    unsigned: bool,
    default: PolarsDataType | None = None,
) -> PolarsDataType | None:
    """
    Look up the Polars integer dtype matching the given bit width and signedness.

    Returns `default` (which itself defaults to None) when no dtype matches
    the (bits, unsigned) combination.
    """
    if unsigned:
        lookup: dict[int, PolarsDataType] = {8: UInt8, 16: UInt16, 32: UInt32, 64: UInt64}
    else:
        lookup = {8: Int8, 16: Int16, 32: Int32, 64: Int64}

    matched = lookup.get(bits)
    return default if matched is None else matched


def is_polars_dtype(dtype: Any, *, include_unknown: bool = False) -> bool:
"""Indicate whether the given input is a Polars dtype, or dtype specialization."""
try:
Expand Down Expand Up @@ -415,10 +596,10 @@ def py_type_to_dtype(
try:
return _map_py_type_to_dtype(data_type)
except (KeyError, TypeError): # pragma: no cover
if not raise_unmatched:
return None
msg = f"cannot infer dtype from {data_type!r} (type: {type(data_type).__name__!r})"
raise ValueError(msg) from None
if raise_unmatched:
msg = f"cannot infer dtype from {data_type!r} (type: {type(data_type).__name__!r})"
raise ValueError(msg) from None
return None


def py_type_to_arrow_type(dtype: PythonDataType) -> pa.lib.DataType:
Expand Down
110 changes: 94 additions & 16 deletions py-polars/polars/io/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,26 @@

import re
import sys
from contextlib import suppress
from importlib import import_module
from inspect import Parameter, signature
from inspect import Parameter, isclass, signature
from typing import TYPE_CHECKING, Any, Iterable, Literal, Sequence, TypedDict, overload

from polars._utils.deprecation import issue_deprecation_warning
from polars.convert import from_arrow
from polars.datatypes import N_INFER_DEFAULT
from polars.datatypes import (
INTEGER_DTYPES,
N_INFER_DEFAULT,
UNSIGNED_INTEGER_DTYPES,
Decimal,
Float32,
Float64,
)
from polars.datatypes.convert import (
_infer_dtype_from_database_typename,
_integer_dtype_from_nbits,
_map_py_type_to_dtype,
)
from polars.exceptions import InvalidOperationError, UnsuitableSQLError

if TYPE_CHECKING:
Expand All @@ -26,6 +39,7 @@
from typing_extensions import Self

from polars import DataFrame
from polars.datatypes import PolarsDataType
from polars.type_aliases import ConnectionOrCursor, Cursor, DbReadEngine, SchemaDict

try:
Expand Down Expand Up @@ -295,17 +309,19 @@ def _from_rows(
if hasattr(self.result, "fetchall"):
if self.driver_name == "sqlalchemy":
if hasattr(self.result, "cursor"):
cursor_desc = {d[0]: d[1] for d in self.result.cursor.description}
cursor_desc = {d[0]: d[1:] for d in self.result.cursor.description}
elif hasattr(self.result, "_metadata"):
cursor_desc = {k: None for k in self.result._metadata.keys}
else:
msg = f"Unable to determine metadata from query result; {self.result!r}"
raise ValueError(msg)
else:
cursor_desc = {d[0]: d[1] for d in self.result.description}
cursor_desc = {d[0]: d[1:] for d in self.result.description}

# TODO: refine types based on the cursor description's type_code,
# if/where available? (for now, we just read the column names)
schema_overrides = self._inject_type_overrides(
description=cursor_desc,
schema_overrides=(schema_overrides or {}),
)
result_columns = list(cursor_desc)
frames = (
DataFrame(
Expand All @@ -324,17 +340,79 @@ def _from_rows(
return frames if iter_batches else next(frames) # type: ignore[arg-type]
return None

def _normalise_cursor(self, conn: ConnectionOrCursor) -> Cursor:
def _inject_type_overrides(
    self,
    description: dict[str, Any],
    schema_overrides: SchemaDict,
) -> SchemaDict:
    """
    Attempt basic dtype inference from a cursor description.

    Parameters
    ----------
    description
        Mapping of column name to the tail of its DB-API cursor description
        tuple, ie: (type_code, display_size, internal_size, precision, scale,
        null_ok), or None when no metadata is available for that column.
    schema_overrides
        User-supplied dtype overrides; existing entries always take
        precedence. The mapping is updated in place with inferred dtypes
        and returned.

    Notes
    -----
    This is limited; the `type_code` property may contain almost anything,
    from strings or python types to driver-specific codes, classes, enums,
    etc. Currently we only do additional inference from string/python type
    values (further refinement requires per-driver module knowledge/lookups).
    """
    for nm, desc in description.items():
        if desc is None or nm in schema_overrides:
            # no metadata for this column, or an explicit user override exists
            continue

        # reset per column: a previously-inferred dtype must not carry over
        # to a column whose own type_code cannot be matched
        dtype: PolarsDataType | None = None
        type_code, _disp_size, internal_size, prec, scale, _null_ok = desc

        if isclass(type_code):
            # python types, eg: int, float, str, etc
            with suppress(TypeError):
                dtype = _map_py_type_to_dtype(type_code)  # type: ignore[arg-type]

        elif isinstance(type_code, str):
            # database/sql type names, eg: "VARCHAR", "NUMERIC", "BLOB", etc
            dtype = _infer_dtype_from_database_typename(
                value=type_code,
                raise_unmatched=False,
            )

        if dtype is not None:
            # check additional cursor information to improve dtype inference
            if dtype == Float64 and internal_size == 4:
                dtype = Float32

            elif dtype in INTEGER_DTYPES and internal_size in (2, 4, 8):
                bits = internal_size * 8
                dtype = _integer_dtype_from_nbits(
                    bits,
                    unsigned=(dtype in UNSIGNED_INTEGER_DTYPES),
                    default=dtype,
                )
            elif (
                dtype == Decimal
                and isinstance(prec, int)
                and isinstance(scale, int)
                and prec <= 38
                and scale <= 38
            ):
                dtype = Decimal(prec, scale)

            schema_overrides[nm] = dtype  # type: ignore[index]

    return schema_overrides

def _normalise_cursor(self, conn: Any) -> Cursor:
"""Normalise a connection object such that we have the query executor."""
if self.driver_name == "sqlalchemy" and type(conn).__name__ == "Engine":
self.can_close_cursor = True
if conn.driver == "databricks-sql-python": # type: ignore[union-attr]
# take advantage of the raw connection to get arrow integration
self.driver_name = "databricks"
return conn.raw_connection().cursor() # type: ignore[union-attr, return-value]
if self.driver_name == "sqlalchemy":
self.can_close_cursor = (conn_type := type(conn).__name__) == "Engine"
if conn_type == "Session":
return conn
else:
# sqlalchemy engine; direct use is deprecated, so prefer the connection
return conn.connect() # type: ignore[union-attr, return-value]
# where possible, use the raw connection to access arrow integration
if conn.engine.driver == "databricks-sql-python":
self.driver_name = "databricks"
return conn.engine.raw_connection().cursor()
elif conn.engine.driver == "duckdb_engine":
self.driver_name = "duckdb"
return conn.engine.raw_connection().driver_connection.c
elif conn_type == "Engine":
return conn.connect()
else:
return conn

elif hasattr(conn, "cursor"):
# connection has a dedicated cursor; prefer over direct execute
Expand All @@ -344,7 +422,7 @@ def _normalise_cursor(self, conn: ConnectionOrCursor) -> Cursor:

elif hasattr(conn, "execute"):
# can execute directly (given cursor, sqlalchemy connection, etc)
return conn # type: ignore[return-value]
return conn

msg = f"Unrecognised connection {conn!r}; unable to find 'execute' method"
raise TypeError(msg)
Expand Down
Loading
Loading