Skip to content

Commit 249cce6

Browse files
committed
Support dynamic filters via remote scanner
1 parent 6d25cde commit 249cce6

File tree

5 files changed

+110
-5
lines changed

5 files changed

+110
-5
lines changed

crates/storage-query-datafusion/src/remote_query_scanner_client.rs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use async_trait::async_trait;
1616
use datafusion::arrow::datatypes::SchemaRef;
1717
use datafusion::error::DataFusionError;
1818
use datafusion::execution::SendableRecordBatchStream;
19+
use datafusion::physical_expr_common::physical_expr::snapshot_generation;
1920
use datafusion::physical_plan::PhysicalExpr;
2021
use datafusion::physical_plan::stream::RecordBatchReceiverStream;
2122
use tracing::debug;
@@ -46,7 +47,10 @@ impl RemoteScanner {
4647
}
4748
}
4849

49-
async fn next_batch(&self) -> Result<RemoteQueryScannerNextResult, DataFusionError> {
50+
async fn next_batch(
51+
&self,
52+
next_predicate: Option<RemoteQueryScannerPredicate>,
53+
) -> Result<RemoteQueryScannerNextResult, DataFusionError> {
5054
let Some(ref connection) = self.connection else {
5155
return Err(DataFusionError::Internal(
5256
"connection used after forget()".to_string(),
@@ -67,6 +71,7 @@ impl RemoteScanner {
6771
.send_rpc(
6872
RemoteQueryScannerNext {
6973
scanner_id: self.scanner_id,
74+
next_predicate,
7075
},
7176
None,
7277
)
@@ -150,6 +155,12 @@ pub fn remote_scan_as_datafusion_stream(
150155
let tx = builder.tx();
151156

152157
let task = async move {
158+
// snapshot the generation of the initial predicate so later changes can be detected
159+
let mut predicate_generation = predicate
160+
.as_ref()
161+
.map(|expr| snapshot_generation(expr))
162+
.unwrap_or(0);
163+
153164
let initial_predicate = match &predicate {
154165
Some(predicate) => Some(RemoteQueryScannerPredicate {
155166
serialized_physical_expression: encode_expr(predicate)?,
@@ -176,7 +187,26 @@ pub fn remote_scan_as_datafusion_stream(
176187
// loop while we have record_batch coming in
177188
//
178189
loop {
179-
let batch = match remote_scanner.next_batch().await {
190+
let next_predicate = if predicate_generation != 0 {
191+
// generation 0 means the predicate is static (or we never had one)
192+
let predicate = predicate
193+
.as_ref()
194+
.expect("must have a predicate if generation != 0");
195+
let current_predicate_generation = snapshot_generation(predicate);
196+
197+
if current_predicate_generation != predicate_generation {
198+
predicate_generation = current_predicate_generation;
199+
Some(RemoteQueryScannerPredicate {
200+
serialized_physical_expression: encode_expr(predicate)?,
201+
})
202+
} else {
203+
None
204+
}
205+
} else {
206+
None
207+
};
208+
209+
let batch = match remote_scanner.next_batch(next_predicate).await {
180210
Err(e) => {
181211
return Err(e);
182212
}

crates/storage-query-datafusion/src/remote_query_scanner_server.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,10 @@ impl RemoteQueryScannerServer {
135135
// do that again here. If we do, we might end up dead-locking the map because we are holding a
136136
// reference into it (scanner).
137137
if let Err(mpsc::error::SendError(request)) =
138-
scanner.send(super::scanner_task::NextRequest { reciprocal })
138+
scanner.send(super::scanner_task::NextRequest {
139+
reciprocal,
140+
next_predicate: req.next_predicate,
141+
})
139142
{
140143
tracing::info!(
141144
"No such scanner {}. This could be an expired scanner due to a slow scan with no activity.",

crates/storage-query-datafusion/src/scanner_task.rs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ use std::sync::{Arc, Weak};
1212
use std::time::Duration;
1313

1414
use anyhow::Context;
15-
use datafusion::execution::SendableRecordBatchStream;
15+
use datafusion::arrow::datatypes::SchemaRef;
16+
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
1617
use datafusion::physical_plan::PhysicalExpr;
1718
use datafusion::prelude::SessionContext;
1819
use tokio::sync::mpsc;
@@ -23,7 +24,8 @@ use restate_core::network::{Oneshot, Reciprocal};
2324
use restate_core::{TaskCenter, TaskKind};
2425
use restate_types::GenerationalNodeId;
2526
use restate_types::net::remote_query_scanner::{
26-
RemoteQueryScannerNextResult, RemoteQueryScannerOpen, ScannerBatch, ScannerFailure, ScannerId,
27+
RemoteQueryScannerNextResult, RemoteQueryScannerOpen, RemoteQueryScannerPredicate,
28+
ScannerBatch, ScannerFailure, ScannerId,
2729
};
2830

2931
use crate::remote_query_scanner_manager::RemoteScannerManager;
@@ -34,6 +36,7 @@ const SCANNER_EXPIRATION: Duration = Duration::from_secs(60);
3436

3537
pub(crate) struct NextRequest {
3638
pub reciprocal: Reciprocal<Oneshot<RemoteQueryScannerNextResult>>,
39+
pub next_predicate: Option<RemoteQueryScannerPredicate>,
3740
}
3841

3942
pub(crate) type ScannerHandle = mpsc::UnboundedSender<NextRequest>;
@@ -45,6 +48,8 @@ pub(crate) struct ScannerTask {
4548
stream: SendableRecordBatchStream,
4649
rx: mpsc::UnboundedReceiver<NextRequest>,
4750
scanners: Weak<ScannerMap>,
51+
ctx: Arc<TaskContext>,
52+
schema: SchemaRef,
4853
predicate: Option<Arc<dyn PhysicalExpr>>,
4954
}
5055

@@ -88,6 +93,8 @@ impl ScannerTask {
8893
stream,
8994
rx,
9095
scanners: Arc::downgrade(scanners),
96+
ctx: SessionContext::new().task_ctx(),
97+
schema,
9198
predicate,
9299
};
93100

@@ -133,6 +140,21 @@ impl ScannerTask {
133140
}
134141
};
135142

143+
if let Some(next_predicate) = request.next_predicate {
144+
match decode_expr(
145+
&self.ctx,
146+
&self.schema,
147+
&next_predicate.serialized_physical_expression,
148+
) {
149+
// for now, we are not updating the predicate being passed to ScanPartition,
150+
// so we rely on the filtering below to apply dynamic filters
151+
Ok(next_predicate) => self.predicate = Some(next_predicate),
152+
Err(e) => {
153+
warn!("Failed to decode next predicate: {e}")
154+
}
155+
}
156+
}
157+
136158
let record_batch = loop {
137159
// connection/request has been closed, don't bother with driving the stream.
138160
// The scanner will be dropped because we want to make sure that we don't get spurious

crates/storage-query-datafusion/src/table_providers.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ use datafusion::execution::context::TaskContext;
2020
use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
2121
use datafusion::physical_expr::EquivalenceProperties;
2222
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
23+
use datafusion::physical_plan::filter_pushdown::{
24+
FilterPushdownPhase, FilterPushdownPropagation, PushedDown,
25+
};
2326
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
2427
use datafusion::physical_plan::{
2528
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, PlanProperties,
@@ -242,6 +245,10 @@ where
242245
) -> datafusion::common::Result<Vec<TableProviderFilterPushDown>> {
243246
let res = filters
244247
.iter()
248+
// If we set this to `Exact`, we might be able to remove a FilterExec higher up the plan.
249+
// However, it means that fields we filter on won't end up in our projection, meaning we
250+
// have to manage a projected schema and a filter schema - defer this complexity for
251+
// future optimization.
245252
.map(|_| TableProviderFilterPushDown::Inexact)
246253
.collect();
247254

@@ -347,6 +354,46 @@ where
347354
sequential_scanners_stream,
348355
)))
349356
}
357+
358+
fn handle_child_pushdown_result(
359+
&self,
360+
phase: datafusion::physical_plan::filter_pushdown::FilterPushdownPhase,
361+
child_pushdown_result: datafusion::physical_plan::filter_pushdown::ChildPushdownResult,
362+
_config: &datafusion::config::ConfigOptions,
363+
) -> datafusion::error::Result<
364+
datafusion::physical_plan::filter_pushdown::FilterPushdownPropagation<
365+
Arc<dyn ExecutionPlan>,
366+
>,
367+
> {
368+
if !matches!(phase, FilterPushdownPhase::Post) {
369+
return Ok(FilterPushdownPropagation::if_all(child_pushdown_result));
370+
}
371+
372+
let filters = child_pushdown_result
373+
.parent_filters
374+
.iter()
375+
.map(|f| f.filter.clone());
376+
377+
let predicate = match &self.predicate {
378+
Some(predicate) => datafusion::physical_expr::conjunction(
379+
std::iter::once(predicate.clone()).chain(filters),
380+
),
381+
None => datafusion::physical_expr::conjunction(filters),
382+
};
383+
384+
let mut plan = self.clone();
385+
plan.predicate = Some(predicate);
386+
387+
Ok(FilterPushdownPropagation {
388+
// We report every filter as not pushed down (`PushedDown::No`): we cannot guarantee they are applied exactly, because there can be a delay before new filters are used.
389+
filters: child_pushdown_result
390+
.parent_filters
391+
.iter()
392+
.map(|_| PushedDown::No)
393+
.collect(),
394+
updated_node: Some(Arc::new(plan)),
395+
})
396+
}
350397
}
351398

352399
impl<T> DisplayAs for PartitionedExecutionPlan<T>

crates/types/src/net/remote_query_scanner.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ pub enum RemoteQueryScannerOpened {
9595
pub struct RemoteQueryScannerNext {
9696
#[bilrost(1)]
9797
pub scanner_id: ScannerId,
98+
#[bilrost(tag(2))]
99+
#[serde(default)]
100+
pub next_predicate: Option<RemoteQueryScannerPredicate>,
98101
}
99102

100103
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, bilrost::Message)]

0 commit comments

Comments
 (0)