Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#2066 #2067

Merged
merged 4 commits into from
Oct 14, 2024
Merged

#2066 #2067

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion opteryx/__version__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__build__ = 827
__build__ = 830

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
6 changes: 3 additions & 3 deletions opteryx/compiled/list_ops/cython_list_ops.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -287,16 +287,16 @@ cpdef cnp.ndarray[cnp.uint8_t, ndim=1] list_contains_any(cnp.ndarray array, cnp.
Cython optimized version that works with object arrays.
"""
cdef set items_set = set(items[0])
cdef Py_ssize_t size = array.size
cdef Py_ssize_t size = array.shape[0]
cdef cnp.ndarray[cnp.uint8_t, ndim=1] res = numpy.zeros(size, dtype=numpy.uint8)
cdef Py_ssize_t i
cdef cnp.ndarray test_set

for i in range(size):
test_set = array[i]
if test_set is not None:
if not(test_set is None or test_set.shape[0] == 0):
for el in test_set:
if el in items_set:
res[i] = True
res[i] = 1
break
return res
4 changes: 2 additions & 2 deletions opteryx/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,8 @@ def sleep(x):
"ARRAY_CONTAINS": _iterate_double_parameter(other_functions.list_contains),
"LIST_CONTAINS_ANY": list_contains_any,
"ARRAY_CONTAINS_ANY": list_contains_any,
"LIST_CONTAINS_ALL": _iterate_double_parameter(other_functions.list_contains_all),
"ARRAY_CONTAINS_ALL": _iterate_double_parameter(other_functions.list_contains_all),
"LIST_CONTAINS_ALL": other_functions.list_contains_all,
"ARRAY_CONTAINS_ALL": other_functions.list_contains_all,
"SEARCH": other_functions.search,
"COALESCE": _coalesce,
"IFNULL": other_functions.if_null,
Expand Down
4 changes: 3 additions & 1 deletion opteryx/functions/other_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def list_contains_all(array, items):
"""
if array is None:
return False
return set(array).issuperset(items)
required_items = set(items[0]) # Convert items[0] to a set once for efficient lookups
return [None if a is None else set(a).issuperset(required_items) for a in array]


def search(array, item, ignore_case: Optional[List[bool]] = None):
Expand Down Expand Up @@ -198,6 +199,7 @@ def cosine_similarity(
def jsonb_object_keys(arr):
if len(arr) == 0:
return []
result = []
if isinstance(arr[0], dict):
result = [[str(key) for key in row] for row in arr]
if isinstance(arr[0], (str, bytes)):
Expand Down
4 changes: 2 additions & 2 deletions opteryx/managers/expression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ def _inner_evaluate(root: Node, table: Table):
# if it's a literal value, return it once for every value in the table
literal_type = root.type
if literal_type == OrsoTypes.ARRAY:
# this isn't as fast as .full - but lists and strings are problematic
return numpy.array([root.value] * table.num_rows)
# creating ARRAY columns is expensive, so we don't create one full length
return numpy.array([root.value])
if literal_type == OrsoTypes.VARCHAR:
return numpy.array([root.value] * table.num_rows, dtype=numpy.unicode_)
if literal_type == OrsoTypes.BLOB:
Expand Down
1 change: 1 addition & 0 deletions opteryx/managers/expression/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def format_expression(root, qualify: bool = False):
"Arrow": "->",
"LongArrow": "->>",
"AtQuestion": "@?",
"AtArrow": "@>",
}
return f"{format_expression(root.left, qualify)} {_map.get(root.value, root.value).upper()} {format_expression(root.right, qualify)}"
if node_type == NodeType.EXPRESSION_LIST:
Expand Down
6 changes: 6 additions & 0 deletions opteryx/managers/expression/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def filter_operations(arr, left_type, operator, value, right_type):
"AnyOpLtEq",
"AllOpEq",
"AllOpNotEq",
"AtArrow",
):
# compressing ARRAY columns is VERY SLOW
morsel_size = len(arr)
Expand Down Expand Up @@ -183,4 +184,9 @@ def _inner_filter_operations(arr, operator, value):
type=pyarrow.bool_(), # type:ignore
)

if operator == "AtArrow":
from opteryx.compiled.list_ops import list_contains_any

return list_contains_any(arr, value)

raise NotImplementedError(f"Operator {operator} is not implemented!") # pragma: no cover
7 changes: 6 additions & 1 deletion opteryx/planner/ast_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,12 @@ def rewrite_json_accessors(node: Dict[str, Any]) -> Dict[str, Any]:
if not isinstance(node, dict):
return node

if "BinaryOp" in node and node["BinaryOp"].get("op") in ("Arrow", "LongArrow", "AtQuestion"):
if "BinaryOp" in node and node["BinaryOp"].get("op") in (
"Arrow",
"LongArrow",
"AtQuestion",
"AtArrow",
):
document = node["BinaryOp"]["left"]
accessor = node["BinaryOp"]["op"]
right_node = node["BinaryOp"]["right"]
Expand Down
1 change: 1 addition & 0 deletions opteryx/planner/binder/binder_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def get_mismatched_condition_column_types(node: Node, relaxed: bool = False) ->
"Arrow",
"LongArrow",
"AtQuestion",
"AtArrow",
) or node.value.startswith(("AllOp", "AnyOp")):
return None # Some ops are meant to have different types
left_type = node.left.schema_column.type if node.left.schema_column else None
Expand Down
1 change: 1 addition & 0 deletions opteryx/planner/binder/operator_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class OperatorMapType(NamedTuple):

# fmt: off
OPERATOR_MAP: Dict[Tuple[OrsoTypes, OrsoTypes, str], OperatorMapType] = {
(OrsoTypes.ARRAY, OrsoTypes.ARRAY, "AtArrow"): OperatorMapType(OrsoTypes.BOOLEAN, None, 100.0),
(OrsoTypes.BLOB, OrsoTypes.VARCHAR, "Eq"): OperatorMapType(OrsoTypes.BOOLEAN, None, 100.0),
(OrsoTypes.BLOB, OrsoTypes.VARCHAR, "NotEq"): OperatorMapType(OrsoTypes.BOOLEAN, None, 100.0),
(OrsoTypes.BLOB, OrsoTypes.VARCHAR, "Gt"): OperatorMapType(OrsoTypes.BOOLEAN, None, 100.0),
Expand Down
4 changes: 1 addition & 3 deletions tests/misc/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ def test_literals(node_type, value_type, value):
assert values.dtype == ORSO_TO_NUMPY_MAP[value_type], values
else:
assert type(values[0]) == numpy.ndarray, values[0]
assert len(values) == planets.num_rows

print(values[0])
# assert len(values) == planets.num_rows, f"{len(values)} != {planets.num_rows}"


def test_logical_expressions():
Expand Down
5 changes: 5 additions & 0 deletions tests/sql_battery/test_shapes_and_errors_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,11 @@
("SELECT missions FROM $astronauts WHERE LIST_CONTAINS_ANY(missions, ('Apollo 8', 'Apollo 13'))", 5, 1, None),
("SELECT missions FROM $astronauts WHERE LIST_CONTAINS_ALL(missions, ('Apollo 8', 'Gemini 7'))", 2, 1, None),
("SELECT missions FROM $astronauts WHERE LIST_CONTAINS_ALL(missions, ('Gemini 7', 'Apollo 8'))", 2, 1, None),
("SELECT missions FROM $astronauts WHERE ARRAY_CONTAINS(missions, 'Apollo 8')", 3, 1, None),
("SELECT missions FROM $astronauts WHERE ARRAY_CONTAINS_ANY(missions, ('Apollo 8', 'Apollo 13'))", 5, 1, None),
("SELECT missions FROM $astronauts WHERE ARRAY_CONTAINS_ALL(missions, ('Apollo 8', 'Gemini 7'))", 2, 1, None),
("SELECT missions FROM $astronauts WHERE ARRAY_CONTAINS_ALL(missions, ('Gemini 7', 'Apollo 8'))", 2, 1, None),
("SELECT missions FROM $astronauts WHERE missions @> ('Apollo 8', 'Apollo 13')", 5, 1, None),

("SELECT * FROM $astronauts WHERE 'Apollo 11' = any(missions)", 3, 19, None),
("SELECT * FROM $astronauts WHERE 'X' > any(alma_mater)", 3, 19, None),
Expand Down