Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer committed Dec 23, 2024
1 parent 9d53968 commit 2b34e81
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 7 deletions.
5 changes: 4 additions & 1 deletion opteryx/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ def __init__(
right_column: Optional[str] = None,
left_node: Optional[Any] = None,
right_node: Optional[Any] = None,
message: Optional[str] = None,
):
def _format_col(_type, _node, _name):
if _node.node_type == 42:
Expand All @@ -317,7 +318,9 @@ def _format_col(_type, _node, _name):
self.column = column
self.left_column = left_column
self.right_column = right_column
if self.column:
if message:
super().__init__(message)
elif self.column:
super().__init__(
f"Incompatible types for column '{column}': {left_type} and {right_type}"
)
Expand Down
4 changes: 2 additions & 2 deletions opteryx/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,8 @@ def apply_function(function: str = None, *parameters):
if null_positions.all():
return numpy.array([None] * morsel_size)

if null_positions.any():
# if we have nulls and both columns are numpy arrays, we can speed things
if null_positions.any() and all(isinstance(arr, numpy.ndarray) for arr in parameters):
# if we have nulls and all of the parameters are numpy arrays, we can speed things
# up by removing the nulls from the calculations, we add the rows back in
# later
valid_positions = ~null_positions
Expand Down
30 changes: 26 additions & 4 deletions opteryx/functions/other_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import simdjson
from pyarrow import compute

from opteryx.exceptions import IncompatibleTypesError
from opteryx.exceptions import SqlError


Expand Down Expand Up @@ -150,13 +151,34 @@ def null_if(col1, col2):
An array where elements from col1 are replaced with None if they match the corresponding elements in col2.
"""
if isinstance(col1, pyarrow.Array):
values = values.to_numpy(False)
col1 = col1.to_numpy(False)
if isinstance(col1, list):
values = numpy.array(values)
col1 = col1.array(col1)
if isinstance(col2, pyarrow.Array):
values = values.to_numpy(False)
col2 = col2.to_numpy(False)
if isinstance(col2, list):
values = numpy.array(values)
col2 = col2.array(col2)

from orso.types import PYTHON_TO_ORSO_MAP
from orso.types import OrsoTypes

def get_first_non_null_type(array):
for item in array:
if item is not None:
return PYTHON_TO_ORSO_MAP.get(type(item), OrsoTypes._MISSING_TYPE)
return OrsoTypes.NULL

col1_type = get_first_non_null_type(col1.tolist())
col2_type = get_first_non_null_type(col2.tolist())

if col1_type != col2_type:
print(col1_type, col2_type)

raise IncompatibleTypesError(
left_type=col1_type,
right_type=col2_type,
message=f"`NULLIF` called with input arrays of different types, {col1_type} and {col2_type}.",
)

# Create a mask where elements in col1 are equal to col2
mask = col1 == col2
Expand Down
4 changes: 4 additions & 0 deletions tests/sql_battery/test_shapes_and_errors_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2185,6 +2185,10 @@
("SELECT DISTINCT l FROM (SELECT split('a b c d e f g h i j', ' ') as letters) as plet CROSS JOIN UNNEST (letters) as l", 10, 1, None),
# 2112
("SELECT id FROM $planets WHERE surface_pressure / surface_pressure is null", 5, 1, None),
# 2144
("SELECT town, LENGTH(NULLIF(town, 'Inglewood')) FROM (SELECT birth_place->'town' AS town FROM $astronauts) AS T", 357, 2, None),
("SELECT town, LENGTH(NULLIF(town, b'Inglewood')) FROM (SELECT birth_place->>'town' AS town FROM $astronauts) AS T", 357, 2, None),
("SELECT town, LENGTH(NULLIF(town, 'Inglewood')) FROM (SELECT birth_place->>'town' AS town FROM $astronauts) AS T", None, None, IncompatibleTypesError),
]
# fmt:on

Expand Down

0 comments on commit 2b34e81

Please sign in to comment.