Description
Is your feature request related to a problem or challenge?
Suppose you are building a distributed query engine on top of DataFusion and you want to run a query like

```sql
SELECT facts.fact_value, data.id, data.fact_id
FROM facts FULL OUTER JOIN data
ON data.fact_id = facts.id
```

where `facts` is a small "fact" table and `data` is some HUGE table (many, many TB, let's assume).
The optimal way to do this in a single-node execution is probably a `CollectLeft` join, since `facts` is small, but this doesn't really work in a distributed execution because `CollectLeft` joins rely on in-memory state. The correct way to do this in a distributed execution is to use a hash-partitioned join and repartition `data`, but this is a problem because `data` is huge and the repartition would require shuffling a potentially massive amount of data.
Describe the solution you'd like
Add a "hook" in `HashJoinExec` that would allow shared state to be managed in a distributed execution in a user-defined way. This might look something like:
```rust
pub struct DistributedJoinState {
    state_impl: Arc<dyn DistributedJoinStateImpl>,
}

impl DistributedJoinState {
    pub fn new(state_impl: Arc<dyn DistributedJoinStateImpl>) -> Self {
        Self { state_impl }
    }
}

pub enum DistributedProbeState {
    /// Probes are still running in other distributed tasks
    Continue,
    /// Current task is the last probe running, so emit unmatched rows
    /// if required by the join type
    Ready(BooleanBufferBuilder),
}

pub trait DistributedJoinStateImpl: Send + Sync + 'static {
    /// Poll the distributed state with the current task's build-side visited bit mask
    fn poll_probe_completed(
        &self,
        mask: &BooleanBufferBuilder,
        cx: &mut Context<'_>,
    ) -> Poll<Result<DistributedProbeState>>;
}
```
```rust
type SharedBitmapBuilder = Mutex<BooleanBufferBuilder>;

/// HashTable and input data for the left (build side) of a join
struct JoinLeftData {
    /// The hash table with indices into `batch`
    hash_map: JoinHashMap,
    /// The input rows for the build side
    batch: RecordBatch,
    /// Shared bitmap builder for visited left indices
    visited_indices_bitmap: SharedBitmapBuilder,
    /// Counter of running probe-threads, potentially
    /// able to update `visited_indices_bitmap`
    probe_threads_counter: AtomicUsize,
    /// Optional user-provided state for coordinating the visited
    /// bitmap across distributed tasks
    distributed_state: Option<Arc<DistributedJoinState>>,
    /// Memory reservation that tracks memory used by `hash_map` hash table
    /// `batch`. Cleared on drop.
    #[allow(dead_code)]
    reservation: MemoryReservation,
}
```
That is, `JoinLeftData` can have an optional `DistributedJoinState` that can be passed in through the `TaskContext` during execution. If it is not provided, everything works exactly as it does now. But if it is provided, then `HashJoinStream` can poll the distributed state when its last (local) probe task completes and, if it is the last global probe task, emit the unmatched rows based on the global bit mask.
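The intended semantics can be illustrated with a simplified, synchronous stand-in for the hook: a plain `Vec<bool>` replaces arrow's `BooleanBufferBuilder`, a blocking method replaces the `Poll`-based trait method, and all names are hypothetical. Each task reports its visited bitmap once; every task but the last sees `Continue`, and the last one gets the OR-merged global mask back so it can emit unmatched build-side rows:

```rust
use std::sync::Mutex;

/// Simplified stand-in for the proposed `DistributedProbeState`
#[derive(Debug, PartialEq)]
enum ProbeState {
    /// Probes are still running in other tasks
    Continue,
    /// This task was the last probe: emit unmatched rows using the
    /// globally merged visited bitmap (if required by the join type)
    Ready(Vec<bool>),
}

/// Simplified stand-in for the shared distributed state: merges each
/// task's visited bitmap and counts down outstanding probe tasks.
struct LocalJoinState {
    /// (merged visited bitmap, number of probe tasks still running)
    inner: Mutex<(Vec<bool>, usize)>,
}

impl LocalJoinState {
    fn new(build_rows: usize, num_tasks: usize) -> Self {
        Self {
            inner: Mutex::new((vec![false; build_rows], num_tasks)),
        }
    }

    /// Called once per task when its local probe finishes
    fn probe_completed(&self, mask: &[bool]) -> ProbeState {
        let mut guard = self.inner.lock().unwrap();
        // OR this task's visited bits into the global bitmap
        for (merged, &bit) in guard.0.iter_mut().zip(mask) {
            *merged |= bit;
        }
        guard.1 -= 1;
        if guard.1 == 0 {
            ProbeState::Ready(guard.0.clone())
        } else {
            ProbeState::Continue
        }
    }
}

fn main() {
    // Three tasks probe a 4-row build side; each matches different rows
    let state = LocalJoinState::new(4, 3);
    assert_eq!(state.probe_completed(&[true, false, false, false]), ProbeState::Continue);
    assert_eq!(state.probe_completed(&[false, true, false, false]), ProbeState::Continue);
    // Last task: the merged bitmap shows row 3 was never matched, so a
    // FULL OUTER join would emit it with nulls on the probe side
    assert_eq!(
        state.probe_completed(&[false, false, true, false]),
        ProbeState::Ready(vec![true, true, true, false])
    );
    println!("ok");
}
```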
Describe alternatives you've considered
Do nothing and rely only on hash-partitioned joins for distributed use cases.
Additional context
This sort of goes against the idea that DataFusion itself is not a library for distributed query execution. But given that many uses of DataFusion are in fact distributed, it might make sense to provide hooks for this directly in DataFusion, as long as they don't add any meaningful overhead to the single-node execution model.
If that is not the way we want to go then totally fine, just raising the question :)