diff --git a/opteryx/exceptions.py b/opteryx/exceptions.py index 19d51834..570464bc 100644 --- a/opteryx/exceptions.py +++ b/opteryx/exceptions.py @@ -304,6 +304,7 @@ def __init__( right_column: Optional[str] = None, left_node: Optional[Any] = None, right_node: Optional[Any] = None, + message: Optional[str] = None, ): def _format_col(_type, _node, _name): if _node.node_type == 42: @@ -317,7 +318,9 @@ def _format_col(_type, _node, _name): self.column = column self.left_column = left_column self.right_column = right_column - if self.column: + if message: + super().__init__(message) + elif self.column: super().__init__( f"Incompatible types for column '{column}': {left_type} and {right_type}" ) diff --git a/opteryx/functions/__init__.py b/opteryx/functions/__init__.py index de093890..eecd8b0d 100644 --- a/opteryx/functions/__init__.py +++ b/opteryx/functions/__init__.py @@ -529,8 +529,8 @@ def apply_function(function: str = None, *parameters): if null_positions.all(): return numpy.array([None] * morsel_size) - if null_positions.any(): - # if we have nulls and both columns are numpy arrays, we can speed things + if null_positions.any() and all(isinstance(arr, numpy.ndarray) for arr in parameters): + # if we have nulls and the value array is a numpy arrays, we can speed things # up by removing the nulls from the calculations, we add the rows back in # later valid_positions = ~null_positions diff --git a/opteryx/functions/other_functions.py b/opteryx/functions/other_functions.py index eda4da41..6822f659 100644 --- a/opteryx/functions/other_functions.py +++ b/opteryx/functions/other_functions.py @@ -18,6 +18,7 @@ import simdjson from pyarrow import compute +from opteryx.exceptions import IncompatibleTypesError from opteryx.exceptions import SqlError @@ -150,13 +151,34 @@ def null_if(col1, col2): An array where elements from col1 are replaced with None if they match the corresponding elements in col2. """ if isinstance(col1, pyarrow.Array): - values = values.to_numpy(False) + col1 = col1.to_numpy(False) if isinstance(col1, list): - values = numpy.array(values) + col1 = col1.array(col1) if isinstance(col2, pyarrow.Array): - values = values.to_numpy(False) + col2 = col2.to_numpy(False) if isinstance(col2, list): - values = numpy.array(values) + col2 = col2.array(col2) + + from orso.types import PYTHON_TO_ORSO_MAP + from orso.types import OrsoTypes + + def get_first_non_null_type(array): + for item in array: + if item is not None: + return PYTHON_TO_ORSO_MAP.get(type(item), OrsoTypes._MISSING_TYPE) + return OrsoTypes.NULL + + col1_type = get_first_non_null_type(col1.tolist()) + col2_type = get_first_non_null_type(col2.tolist()) + + if col1_type != col2_type: + print(col1_type, col2_type) + + raise IncompatibleTypesError( + left_type=col1_type, + right_type=col2_type, + message=f"`NULLIF` called with input arrays of different types, {col1_type} and {col2_type}.", + ) # Create a mask where elements in col1 are equal to col2 mask = col1 == col2 diff --git a/tests/sql_battery/test_shapes_and_errors_battery.py b/tests/sql_battery/test_shapes_and_errors_battery.py index 08f6ed5b..05811599 100644 --- a/tests/sql_battery/test_shapes_and_errors_battery.py +++ b/tests/sql_battery/test_shapes_and_errors_battery.py @@ -2185,6 +2185,10 @@ ("SELECT DISTINCT l FROM (SELECT split('a b c d e f g h i j', ' ') as letters) as plet CROSS JOIN UNNEST (letters) as l", 10, 1, None), # 2112 ("SELECT id FROM $planets WHERE surface_pressure / surface_pressure is null", 5, 1, None), + #2144 + ("SELECT town, LENGTH(NULLIF(town, 'Inglewood')) FROM (SELECT birth_place->'town' AS town FROM $astronauts) AS T", 357, 2, None), + ("SELECT town, LENGTH(NULLIF(town, b'Inglewood')) FROM (SELECT birth_place->>'town' AS town FROM $astronauts) AS T", 357, 2, None), + ("SELECT town, LENGTH(NULLIF(town, 'Inglewood')) FROM (SELECT birth_place->>'town' AS town FROM $astronauts) AS T", None, None, IncompatibleTypesError), ] # fmt:on