Skip to content

[PECOBLR-330] Support for complex params #559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/databricks/sql/parameters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@
TimestampNTZParameter,
TinyIntParameter,
DecimalParameter,
MapParameter,
ArrayParameter,
)
136 changes: 125 additions & 11 deletions src/databricks/sql/parameters/native.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import datetime
import decimal
from enum import Enum, auto
from typing import Optional, Sequence
from typing import Optional, Sequence, Any

from databricks.sql.exc import NotSupportedError
from databricks.sql.thrift_api.TCLIService.ttypes import (
TSparkParameter,
TSparkParameterValue,
TSparkParameterValueArg,
)

import datetime
Expand Down Expand Up @@ -54,7 +55,17 @@ class DatabricksSupportedType(Enum):


TAllowedParameterValue = Union[
str, int, float, datetime.datetime, datetime.date, bool, decimal.Decimal, None
str,
int,
float,
datetime.datetime,
datetime.date,
bool,
decimal.Decimal,
None,
list,
dict,
tuple,
]


Expand Down Expand Up @@ -82,6 +93,7 @@ class DbsqlParameterBase:

CAST_EXPR: str
name: Optional[str]
value: Any

def as_tspark_param(self, named: bool) -> TSparkParameter:
"""Returns a TSparkParameter object that can be passed to the DBR thrift server."""
Expand All @@ -98,6 +110,10 @@ def as_tspark_param(self, named: bool) -> TSparkParameter:
def _tspark_param_value(self):
return TSparkParameterValue(stringValue=str(self.value))

def _tspark_value_arg(self):
"""Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
return TSparkParameterValueArg(value=str(self.value), type=self._cast_expr())

def _cast_expr(self):
return self.CAST_EXPR

Expand Down Expand Up @@ -428,6 +444,99 @@ def __init__(self, value: int, name: Optional[str] = None):
CAST_EXPR = DatabricksSupportedType.TINYINT.name


class ArrayParameter(DbsqlParameterBase):
    """Wrap a Python `Sequence` that will be bound to a Databricks SQL ARRAY type."""

    def __init__(self, value: Sequence[Any], name: Optional[str] = None):
        """
        :value:
            The value to bind for this parameter. It will be cast to an ARRAY.
        :name:
            If None, your query must contain a `?` marker. Like:

            ```sql
            SELECT * FROM table WHERE field = ?
            ```
            If not None, your query should contain a named parameter marker. Like:
            ```sql
            SELECT * FROM table WHERE field = :my_param
            ```

            The `name` argument to this function would be `my_param`.
        """
        self.name = name
        # Wrap every element so nested lists/dicts recurse into
        # Array/Map parameters and scalars get their inferred type.
        self.value = [dbsql_parameter_from_primitive(element) for element in value]

    def as_tspark_param(self, named: bool = False) -> TSparkParameter:
        """Returns a TSparkParameter object that can be passed to the DBR thrift server."""
        param = TSparkParameter(type=self._cast_expr())
        param.arguments = [element._tspark_value_arg() for element in self.value]

        if named:
            param.name = self.name
            param.ordinal = False
        else:
            param.ordinal = True
        return param

    def _tspark_value_arg(self):
        """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
        arg = TSparkParameterValueArg(type=self._cast_expr())
        arg.arguments = [element._tspark_value_arg() for element in self.value]
        return arg

    CAST_EXPR = DatabricksSupportedType.ARRAY.name


class MapParameter(DbsqlParameterBase):
    """Wrap a Python `dict` that will be bound to a Databricks SQL MAP type."""

    def __init__(self, value: dict, name: Optional[str] = None):
        """
        :value:
            The value to bind for this parameter. It will be cast to a MAP.
        :name:
            If None, your query must contain a `?` marker. Like:

            ```sql
            SELECT * FROM table WHERE field = ?
            ```
            If not None, your query should contain a named parameter marker. Like:
            ```sql
            SELECT * FROM table WHERE field = :my_param
            ```

            The `name` argument to this function would be `my_param`.
        """
        self.name = name
        # Flatten the mapping to [k0, v0, k1, v1, ...]; the server expects
        # the MAP's keys and values as one alternating argument list.
        flattened = []
        for key, val in value.items():
            flattened.append(dbsql_parameter_from_primitive(key))
            flattened.append(dbsql_parameter_from_primitive(val))
        self.value = flattened

    def as_tspark_param(self, named: bool = False) -> TSparkParameter:
        """Returns a TSparkParameter object that can be passed to the DBR thrift server."""
        param = TSparkParameter(type=self._cast_expr())
        param.arguments = [element._tspark_value_arg() for element in self.value]
        if named:
            param.name = self.name
            param.ordinal = False
        else:
            param.ordinal = True
        return param

    def _tspark_value_arg(self):
        """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
        arg = TSparkParameterValueArg(type=self._cast_expr())
        arg.arguments = [element._tspark_value_arg() for element in self.value]
        return arg

    CAST_EXPR = DatabricksSupportedType.MAP.name


class DecimalParameter(DbsqlParameterBase):
"""Wrap a Python `Decimal` that will be bound to a Databricks SQL DECIMAL type."""

Expand Down Expand Up @@ -543,23 +652,26 @@ def dbsql_parameter_from_primitive(
# havoc. We can't use TYPE_INFERRENCE_MAP because mypy doesn't trust
# its logic

if type(value) is int:
if isinstance(value, bool):
return BooleanParameter(value=value, name=name)
elif isinstance(value, int):
return dbsql_parameter_from_int(value, name=name)
elif type(value) is str:
elif isinstance(value, str):
return StringParameter(value=value, name=name)
elif type(value) is float:
elif isinstance(value, float):
return FloatParameter(value=value, name=name)
elif type(value) is datetime.datetime:
elif isinstance(value, datetime.datetime):
return TimestampParameter(value=value, name=name)
elif type(value) is datetime.date:
elif isinstance(value, datetime.date):
return DateParameter(value=value, name=name)
elif type(value) is bool:
return BooleanParameter(value=value, name=name)
elif type(value) is decimal.Decimal:
elif isinstance(value, decimal.Decimal):
return DecimalParameter(value=value, name=name)
elif isinstance(value, dict):
return MapParameter(value=value, name=name)
elif isinstance(value, Sequence) and not isinstance(value, str):
return ArrayParameter(value=value, name=name)
elif value is None:
return VoidParameter(value=value, name=name)

else:
raise NotSupportedError(
f"Could not infer parameter type from value: {value} - {type(value)} \n"
Expand All @@ -581,6 +693,8 @@ def dbsql_parameter_from_primitive(
TimestampNTZParameter,
TinyIntParameter,
DecimalParameter,
ArrayParameter,
MapParameter,
]


Expand Down
27 changes: 19 additions & 8 deletions src/databricks/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import decimal
from abc import ABC, abstractmethod
from collections import OrderedDict, namedtuple
from collections.abc import Iterable
from collections.abc import Mapping
from decimal import Decimal
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, Sequence
import re

import lz4.frame
Expand Down Expand Up @@ -429,7 +429,7 @@ def user_friendly_error_message(self, no_retry_reason, attempt, elapsed):
# Taken from PyHive
class ParamEscaper:
_DATE_FORMAT = "%Y-%m-%d"
_TIME_FORMAT = "%H:%M:%S.%f"
_TIME_FORMAT = "%H:%M:%S.%f %z"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Timezone will not be there for TIMESTAMP_NTZ param.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, would like to know how we've accounted/tested for NTZ

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the timezone is not present, it is rendered as an empty space; there is already a test suite that inserts NTZ and non-NTZ values and reads them back to verify equality.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vikrantpuppala @shivam2680 There are already existing tests that insert NTZ and non NTZ values and reads back from table to ensure everything is working as expected -

def test_dbsqlparameter_single(

_DATETIME_FORMAT = "{} {}".format(_DATE_FORMAT, _TIME_FORMAT)

def escape_args(self, parameters):
Expand Down Expand Up @@ -458,13 +458,22 @@ def escape_string(self, item):
return "'{}'".format(item.replace("\\", "\\\\").replace("'", "\\'"))

def escape_sequence(self, item):
l = map(str, map(self.escape_item, item))
return "(" + ",".join(l) + ")"
l = map(self.escape_item, item)
l = list(map(str, l))
return "ARRAY(" + ",".join(l) + ")"

def escape_mapping(self, item):
    """Render a mapping as a SQL MAP(...) literal with escaped keys and values."""
    # Interleave keys and values (k0, v0, k1, v1, ...) — MAP() takes them
    # as one flat, alternating argument list.
    rendered = [
        str(self.escape_item(element))
        for pair in item.items()
        for element in pair
    ]
    return "MAP(" + ",".join(rendered) + ")"

def escape_datetime(self, item, format, cutoff=0):
dt_str = item.strftime(format)
formatted = dt_str[:-cutoff] if cutoff and format.endswith(".%f") else dt_str
return "'{}'".format(formatted)
return "'{}'".format(formatted.strip())

def escape_decimal(self, item):
return str(item)
Expand All @@ -476,14 +485,16 @@ def escape_item(self, item):
return self.escape_number(item)
elif isinstance(item, str):
return self.escape_string(item)
elif isinstance(item, Iterable):
return self.escape_sequence(item)
elif isinstance(item, datetime.datetime):
return self.escape_datetime(item, self._DATETIME_FORMAT)
elif isinstance(item, datetime.date):
return self.escape_datetime(item, self._DATE_FORMAT)
elif isinstance(item, decimal.Decimal):
return self.escape_decimal(item)
elif isinstance(item, Sequence):
return self.escape_sequence(item)
elif isinstance(item, Mapping):
return self.escape_mapping(item)
else:
raise exc.ProgrammingError("Unsupported object {}".format(item))

Expand Down
44 changes: 34 additions & 10 deletions tests/e2e/test_complex_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from numpy import ndarray
from typing import Sequence

from tests.e2e.test_driver import PySQLPytestTestCase

Expand All @@ -14,50 +15,73 @@ def table_fixture(self, connection_details):
# Create the table
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS pysql_e2e_test_complex_types_table (
CREATE TABLE IF NOT EXISTS pysql_test_complex_types_table (
array_col ARRAY<STRING>,
map_col MAP<STRING, INTEGER>,
struct_col STRUCT<field1: STRING, field2: INTEGER>
)
struct_col STRUCT<field1: STRING, field2: INTEGER>,
array_array_col ARRAY<ARRAY<STRING>>,
array_map_col ARRAY<MAP<STRING, INTEGER>>,
map_array_col MAP<STRING, ARRAY<STRING>>
) USING DELTA
"""
)
# Insert a record
cursor.execute(
"""
INSERT INTO pysql_e2e_test_complex_types_table
INSERT INTO pysql_test_complex_types_table
VALUES (
ARRAY('a', 'b', 'c'),
MAP('a', 1, 'b', 2, 'c', 3),
NAMED_STRUCT('field1', 'a', 'field2', 1)
NAMED_STRUCT('field1', 'a', 'field2', 1),
ARRAY(ARRAY('a','b','c')),
ARRAY(MAP('a', 1, 'b', 2, 'c', 3)),
MAP('a', ARRAY('a', 'b', 'c'), 'b', ARRAY('d', 'e'))
)
"""
)
yield
# Clean up the table after the test
cursor.execute("DROP TABLE IF EXISTS pysql_e2e_test_complex_types_table")
cursor.execute("DELETE FROM pysql_test_complex_types_table")

@pytest.mark.parametrize(
"field,expected_type",
[("array_col", ndarray), ("map_col", list), ("struct_col", dict)],
[
("array_col", ndarray),
("map_col", list),
("struct_col", dict),
("array_array_col", ndarray),
("array_map_col", ndarray),
("map_array_col", list),
],
)
def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture):
"""Confirms the return types of a complex type field when reading as arrow"""

with self.cursor() as cursor:
result = cursor.execute(
"SELECT * FROM pysql_e2e_test_complex_types_table LIMIT 1"
"SELECT * FROM pysql_test_complex_types_table LIMIT 1"
).fetchone()

assert isinstance(result[field], expected_type)

@pytest.mark.parametrize("field", [("array_col"), ("map_col"), ("struct_col")])
@pytest.mark.parametrize(
"field",
[
("array_col"),
("map_col"),
("struct_col"),
("array_array_col"),
("array_map_col"),
("map_array_col"),
],
)
def test_read_complex_types_as_string(self, field, table_fixture):
"""Confirms the return type of a complex type that is returned as a string"""
with self.cursor(
extra_params={"_use_arrow_native_complex_types": False}
) as cursor:
result = cursor.execute(
"SELECT * FROM pysql_e2e_test_complex_types_table LIMIT 1"
"SELECT * FROM pysql_test_complex_types_table LIMIT 1"
).fetchone()

assert isinstance(result[field], str)
8 changes: 6 additions & 2 deletions tests/e2e/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,9 @@ def test_closing_a_closed_connection_doesnt_fail(self, caplog):
raise KeyboardInterrupt("Simulated interrupt")
finally:
if conn is not None:
assert not conn.open, "Connection should be closed after KeyboardInterrupt"
assert (
not conn.open
), "Connection should be closed after KeyboardInterrupt"

def test_cursor_close_properly_closes_operation(self):
"""Test that Cursor.close() properly closes the active operation handle on the server."""
Expand All @@ -883,7 +885,9 @@ def test_cursor_close_properly_closes_operation(self):
raise KeyboardInterrupt("Simulated interrupt")
finally:
if cursor is not None:
assert not cursor.open, "Cursor should be closed after KeyboardInterrupt"
assert (
not cursor.open
), "Cursor should be closed after KeyboardInterrupt"

def test_nested_cursor_context_managers(self):
"""Test that nested cursor context managers properly close operations on the server."""
Expand Down
Loading
Loading