Skip to content

Commit

Permalink
Merge pull request #2156 from mabel-dev/#2147
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer authored Dec 27, 2024
2 parents 065b389 + feb6e53 commit e56c694
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 93 deletions.
2 changes: 1 addition & 1 deletion opteryx/__version__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__build__ = 921
__build__ = 922

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
3 changes: 3 additions & 0 deletions opteryx/compiled/structures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from .hash_table import HashSet
from .hash_table import HashTable
from .hash_table import anti_join
from .hash_table import distinct
from .hash_table import filter_join_set
from .hash_table import list_distinct
from .hash_table import semi_join
from .memory_pool import MemoryPool
from .node import Node
103 changes: 94 additions & 9 deletions opteryx/compiled/structures/hash_table.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ cdef inline object recast_column(column):



@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cpdef tuple distinct(table, HashSet seen_hashes=None, list columns=None):
"""
Perform a distinct operation on the given table using an external HashSet.
Expand Down Expand Up @@ -142,9 +140,8 @@ cpdef tuple distinct(table, HashSet seen_hashes=None, list columns=None):

return keep, seen_hashes

@cython.boundscheck(False)
@cython.wraparound(False)
cdef void compute_float_hashes(cnp.ndarray[cnp.float64_t] data, int64_t null_hash, cnp.ndarray[int64_t] hashes):
cdef void compute_float_hashes(cnp.ndarray[cnp.float64_t] data, int64_t null_hash, int64_t[:] hashes):
cdef Py_ssize_t i, n = data.shape[0]
cdef cnp.float64_t value
for i in range(n):
Expand All @@ -154,9 +151,8 @@ cdef void compute_float_hashes(cnp.ndarray[cnp.float64_t] data, int64_t null_has
else:
hashes[i] = hash(value)

@cython.boundscheck(False)
@cython.wraparound(False)
cdef void compute_int_hashes(cnp.ndarray[cnp.int64_t] data, int64_t null_hash, cnp.ndarray[int64_t] hashes):
cdef void compute_int_hashes(cnp.ndarray[cnp.int64_t] data, int64_t null_hash, int64_t[:] hashes):
cdef Py_ssize_t i, n = data.shape[0]
cdef cnp.int64_t value
for i in range(n):
Expand All @@ -168,9 +164,8 @@ cdef void compute_int_hashes(cnp.ndarray[cnp.int64_t] data, int64_t null_hash, c
else:
hashes[i] = value # Hash of int is the int itself in Python 3

@cython.boundscheck(False)
@cython.wraparound(False)
cdef void compute_object_hashes(cnp.ndarray data, int64_t null_hash, cnp.ndarray[int64_t] hashes):
cdef void compute_object_hashes(cnp.ndarray data, int64_t null_hash, int64_t[:] hashes):
cdef Py_ssize_t i, n = data.shape[0]
cdef object value
for i in range(n):
Expand Down Expand Up @@ -209,7 +204,6 @@ cpdef tuple list_distinct(cnp.ndarray values, cnp.int32_t[::1] indices, HashSet



@cython.boundscheck(False)
@cython.wraparound(False)
cpdef HashTable hash_join_map(relation, list join_columns):
"""
Expand Down Expand Up @@ -276,3 +270,94 @@ cpdef HashTable hash_join_map(relation, list join_columns):
ht.insert(hash_value, non_null_indices[i])

return ht


cpdef filter_join_set(relation, list join_columns, HashSet seen_hashes):
    """
    Collect the hashes of the join-key values of `relation` into a HashSet,
    for later use by semi/anti join filtering.

    Parameters:
        relation: table-like object exposing select()/drop_null()/itercolumns()
            (presumably a pyarrow.Table — confirm with callers).
        join_columns: list of column names forming the join key.
        seen_hashes: HashSet to extend, or None to start a fresh one.

    Returns:
        The populated HashSet (the same object passed in when not None).
    """

    cdef int64_t num_columns = len(join_columns)

    if seen_hashes is None:
        seen_hashes = HashSet()

    # Memory view for the values array (for the join columns).
    # drop_null() removes rows containing any null join key, so null keys
    # never contribute a hash to the set.
    # Resulting shape is (num_columns, num_rows): one row per join column.
    cdef object[:, ::1] values_array = numpy.array(list(relation.select(join_columns).drop_null().itercolumns()), dtype=object)

    cdef int64_t hash_value, i

    if num_columns == 1:
        # Single-column fast path: hash each key value directly.
        col = values_array[0, :]
        for i in range(len(col)):
            hash_value = <int64_t>hash(col[i])
            seen_hashes.insert(hash_value)
    else:
        # Multi-column key: iterate row-wise (axis 1 is rows).
        for i in range(values_array.shape[1]):
            # Combine the hashes of each value in the row
            # (31-multiplier rolling combine; must match anti_join/semi_join).
            hash_value = 0
            for value in values_array[:, i]:
                hash_value = <int64_t>(hash_value * 31 + hash(value))
            seen_hashes.insert(hash_value)

    return seen_hashes

cpdef anti_join(relation, list join_columns, HashSet seen_hashes):
    """
    LEFT ANTI JOIN filter: return the rows of `relation` whose join-key hash
    is NOT present in `seen_hashes`.

    Parameters:
        relation: table-like object exposing select()/drop_null()/
            itercolumns()/take()/slice()/shape (presumably pyarrow.Table —
            confirm with callers).
        join_columns: list of column names forming the join key.
        seen_hashes: HashSet of right-side join-key hashes (see
            filter_join_set; the hash scheme must match).

    Returns:
        A table of the same type as `relation` (possibly empty).
    """
    cdef int64_t num_columns = len(join_columns)
    cdef int64_t num_rows = relation.shape[0]  # row count BEFORE null-dropping
    cdef int64_t hash_value, i
    # Buffer sized to the full table: an upper bound on surviving rows.
    cdef cnp.ndarray[int64_t, ndim=1] index_buffer = numpy.empty(num_rows, dtype=numpy.int64)
    cdef int64_t idx_count = 0

    # NOTE(review): drop_null() removes rows whose join key contains a null,
    # so `i` below indexes the FILTERED key table, yet it is later passed to
    # relation.take() on the UNFILTERED relation. When any join key is null
    # the indices are misaligned and the wrong rows are returned; null-keyed
    # rows are also silently excluded from the anti-join result. Confirm
    # against callers and fix.
    cdef object[:, ::1] values_array = numpy.array(list(relation.select(join_columns).drop_null().itercolumns()), dtype=object)

    if num_columns == 1:
        # Single-column fast path: hash each key value directly.
        col = values_array[0, :]
        for i in range(len(col)):
            hash_value = <int64_t>hash(col[i])
            if not seen_hashes.contains(hash_value):
                index_buffer[idx_count] = i
                idx_count += 1
    else:
        for i in range(values_array.shape[1]):
            # Combine the hashes of each value in the row
            # (must match the combine used by filter_join_set).
            hash_value = 0
            for value in values_array[:, i]:
                hash_value = <int64_t>(hash_value * 31 + hash(value))
            if not seen_hashes.contains(hash_value):
                index_buffer[idx_count] = i
                idx_count += 1

    if idx_count > 0:
        return relation.take(index_buffer[:idx_count])
    else:
        # No surviving rows: return an empty table with the same schema.
        return relation.slice(0, 0)


cpdef semi_join(relation, list join_columns, HashSet seen_hashes):
    """
    LEFT SEMI JOIN filter: return the rows of `relation` whose join-key hash
    IS present in `seen_hashes`.

    Parameters:
        relation: table-like object exposing select()/drop_null()/
            itercolumns()/take()/slice()/shape (presumably pyarrow.Table —
            confirm with callers).
        join_columns: list of column names forming the join key.
        seen_hashes: HashSet of right-side join-key hashes (see
            filter_join_set; the hash scheme must match).

    Returns:
        A table of the same type as `relation` (possibly empty).
    """
    cdef int64_t num_columns = len(join_columns)
    cdef int64_t num_rows = relation.shape[0]  # row count BEFORE null-dropping
    cdef int64_t hash_value, i
    # Buffer sized to the full table: an upper bound on surviving rows.
    cdef cnp.ndarray[int64_t, ndim=1] index_buffer = numpy.empty(num_rows, dtype=numpy.int64)
    cdef int64_t idx_count = 0

    # NOTE(review): same index-misalignment concern as anti_join — `i` indexes
    # the drop_null()-filtered key table but is used with relation.take() on
    # the UNFILTERED relation. When any join key is null the wrong rows are
    # returned. Confirm and fix.
    cdef object[:, ::1] values_array = numpy.array(list(relation.select(join_columns).drop_null().itercolumns()), dtype=object)

    if num_columns == 1:
        # Single-column fast path: hash each key value directly.
        col = values_array[0, :]
        for i in range(len(col)):
            hash_value = <int64_t>hash(col[i])
            if seen_hashes.contains(hash_value):
                index_buffer[idx_count] = i
                idx_count += 1
    else:
        for i in range(values_array.shape[1]):
            # Combine the hashes of each value in the row
            # (must match the combine used by filter_join_set).
            hash_value = 0
            for value in values_array[:, i]:
                hash_value = <int64_t>(hash_value * 31 + hash(value))
            if seen_hashes.contains(hash_value):
                index_buffer[idx_count] = i
                idx_count += 1

    if idx_count > 0:
        return relation.take(index_buffer[:idx_count])
    else:
        # No surviving rows: return an empty table with the same schema.
        return relation.slice(0, 0)
96 changes: 13 additions & 83 deletions opteryx/operators/filter_join_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,78 +13,17 @@
presence or absence of matching rows in the right table.
"""

from typing import List
from typing import Set

import pyarrow

from opteryx import EOS
from opteryx.compiled.structures import anti_join
from opteryx.compiled.structures import filter_join_set
from opteryx.compiled.structures import semi_join
from opteryx.models import QueryProperties

from . import JoinNode


def left_anti_join(left_relation, left_columns: List[str], right_hash_set: Set[str]):
    """
    Perform a LEFT ANTI JOIN.

    Returns the rows of the left table whose join key has no match in the
    right table. Rows whose join key contains a null are excluded (nulls
    never match, mirroring the previous drop_null() behaviour).

    Parameters:
        left_relation (pyarrow.Table): The left pyarrow.Table to join.
        left_columns (list of str): Column names from the left table to join on.
        right_hash_set (set of int): Hashes of the right table's join-key tuples.

    Returns:
        A pyarrow.Table containing the result of the LEFT ANTI JOIN operation.
    """
    # Fix: the previous implementation enumerated rows of
    # select(...).drop_null() and used those positions with take() on the
    # UNFILTERED table, so indices were misaligned whenever a join key was
    # null. We now iterate the unfiltered key columns so `i` is always a
    # valid row index of left_relation, and skip null keys explicitly.
    key_columns = list(left_relation.select(left_columns).itercolumns())

    left_indexes = []
    for i, value_tuple in enumerate(zip(*key_columns)):
        if any(not value.is_valid for value in value_tuple):
            continue  # null join keys can never match
        if value_tuple not in right_hash_set if False else hash(value_tuple) not in right_hash_set:
            left_indexes.append(i)

    # Filter the left relation based on the anti join condition
    if left_indexes:
        return left_relation.take(left_indexes)
    return left_relation.slice(0, 0)


def left_semi_join(left_relation, left_columns: List[str], right_hash_set: Set[str]):
    """
    Perform a LEFT SEMI JOIN.

    Returns the rows of the left table whose join key has a match in the
    right table. Rows whose join key contains a null are excluded (nulls
    never match, mirroring the previous drop_null() behaviour).

    Parameters:
        left_relation (pyarrow.Table): The left pyarrow.Table to join.
        left_columns (list of str): Column names from the left table to join on.
        right_hash_set (set of int): Hashes of the right table's join-key tuples.

    Returns:
        A pyarrow.Table containing the result of the LEFT SEMI JOIN operation.
    """
    # Fix: the previous implementation enumerated rows of
    # select(...).drop_null() and used those positions with take() on the
    # UNFILTERED table, so indices were misaligned whenever a join key was
    # null. We now iterate the unfiltered key columns so `i` is always a
    # valid row index of left_relation, and skip null keys explicitly.
    key_columns = list(left_relation.select(left_columns).itercolumns())

    left_indexes = []
    for i, value_tuple in enumerate(zip(*key_columns)):
        if any(not value.is_valid for value in value_tuple):
            continue  # null join keys can never match
        if hash(value_tuple) in right_hash_set:
            left_indexes.append(i)

    # Filter the left relation based on the semi join condition
    if left_indexes:
        return left_relation.take(left_indexes)
    return left_relation.slice(0, 0)


class FilterJoinNode(JoinNode):
def __init__(self, properties: QueryProperties, **parameters):
JoinNode.__init__(self, properties=properties, **parameters)
Expand All @@ -98,16 +37,15 @@ def __init__(self, properties: QueryProperties, **parameters):
self.right_columns = parameters.get("right_columns")
self.right_readers = parameters.get("right_readers")

self.right_buffer = []
self.right_hash_set = set()
self.right_hash_set = None

@classmethod
def from_json(cls, json_obj: str) -> "BasePlanNode": # pragma: no cover
raise NotImplementedError()

@property
def name(self): # pragma: no cover
return self.join_type
return self.join_type.replace(" ", "_")

@property
def config(self) -> str: # pragma: no cover
Expand All @@ -126,24 +64,16 @@ def execute(self, morsel: pyarrow.Table, join_leg: str) -> pyarrow.Table:
else:
join_provider = providers.get(self.join_type)
yield join_provider(
left_relation=morsel,
left_columns=self.left_columns,
right_hash_set=self.right_hash_set,
)
if join_leg == "right":
if morsel == EOS:
right_relation = pyarrow.concat_tables(self.right_buffer, promote_options="none")
self.right_buffer.clear()
non_null_right_values = (
right_relation.select(self.right_columns).drop_null().itercolumns()
relation=morsel,
join_columns=self.left_columns,
seen_hashes=self.right_hash_set,
)
self.right_hash_set = set(map(hash, zip(*non_null_right_values)))
else:
self.right_buffer.append(morsel)
yield None
if join_leg == "right" and morsel != EOS:
self.right_hash_set = filter_join_set(morsel, self.right_columns, self.right_hash_set)
yield None


providers = {
"left anti": left_anti_join,
"left semi": left_semi_join,
"left anti": anti_join,
"left semi": semi_join,
}

0 comments on commit e56c694

Please sign in to comment.