Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions native-engine/auron-serde/proto/auron.proto
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,8 @@ message SortExecNode {

message FetchLimit {
// wrap into a message to make it optional
uint64 limit = 1;
uint32 limit = 1;
uint32 offset = 2;
}

message PhysicalRepartition {
Expand Down Expand Up @@ -705,7 +706,8 @@ enum AggMode {

message LimitExecNode {
PhysicalPlanNode input = 1;
uint64 limit = 2;
uint32 limit = 2;
uint32 offset = 3;
}

message FFIReaderExecNode {
Expand Down
22 changes: 15 additions & 7 deletions native-engine/auron-serde/src/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,12 +315,12 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
panic!("Failed to parse physical sort expressions: {}", e);
});

let fetch = sort.fetch_limit.as_ref();
let limit = fetch.map(|f| f.limit as usize);
let offset = fetch.map(|f| f.offset as usize).unwrap_or(0);

// always preserve partitioning
Ok(Arc::new(SortExec::new(
input,
exprs,
sort.fetch_limit.as_ref().map(|limit| limit.limit as usize),
)))
Ok(Arc::new(SortExec::new(input, exprs, limit, offset)))
}
PhysicalPlanType::BroadcastJoinBuildHashMap(bhm) => {
let input: Arc<dyn ExecutionPlan> = convert_box_required!(bhm.input)?;
Expand Down Expand Up @@ -501,7 +501,11 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
}
PhysicalPlanType::Limit(limit) => {
let input: Arc<dyn ExecutionPlan> = convert_box_required!(limit.input)?;
Ok(Arc::new(LimitExec::new(input, limit.limit)))
Ok(Arc::new(LimitExec::new(
input,
limit.limit as usize,
limit.offset as usize,
)))
}
PhysicalPlanType::FfiReader(ffi_reader) => {
let schema = Arc::new(convert_required!(ffi_reader.schema)?);
Expand All @@ -513,7 +517,11 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
}
PhysicalPlanType::CoalesceBatches(coalesce_batches) => {
let input: Arc<dyn ExecutionPlan> = convert_box_required!(coalesce_batches.input)?;
Ok(Arc::new(LimitExec::new(input, coalesce_batches.batch_size)))
Ok(Arc::new(LimitExec::new(
input,
coalesce_batches.batch_size as usize,
0,
)))
}
PhysicalPlanType::Expand(expand) => {
let schema = Arc::new(convert_required!(expand.schema)?);
Expand Down
99 changes: 87 additions & 12 deletions native-engine/datafusion-ext-plans/src/limit_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,18 @@ use crate::common::execution_context::ExecutionContext;
#[derive(Debug)]
pub struct LimitExec {
input: Arc<dyn ExecutionPlan>,
limit: u64,
limit: usize,
offset: usize,
pub metrics: ExecutionPlanMetricsSet,
props: OnceCell<PlanProperties>,
}

impl LimitExec {
pub fn new(input: Arc<dyn ExecutionPlan>, limit: u64) -> Self {
pub fn new(input: Arc<dyn ExecutionPlan>, limit: usize, offset: usize) -> Self {
Self {
input,
limit,
offset,
metrics: ExecutionPlanMetricsSet::new(),
props: OnceCell::new(),
}
Expand All @@ -59,7 +61,7 @@ impl LimitExec {

impl DisplayAs for LimitExec {
fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
write!(f, "LimitExec(limit={})", self.limit)
write!(f, "LimitExec(limit={},offset={})", self.limit, self.offset)
}
}

Expand Down Expand Up @@ -95,7 +97,11 @@ impl ExecutionPlan for LimitExec {
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(Self::new(children[0].clone(), self.limit)))
Ok(Arc::new(Self::new(
children[0].clone(),
self.limit,
self.offset,
)))
}

fn execute(
Expand All @@ -105,23 +111,27 @@ impl ExecutionPlan for LimitExec {
) -> Result<SendableRecordBatchStream> {
let exec_ctx = ExecutionContext::new(context, partition, self.schema(), &self.metrics);
let input = exec_ctx.execute_with_input_stats(&self.input)?;
execute_limit(input, self.limit, exec_ctx)
if self.offset == 0 {
execute_limit(input, self.limit, exec_ctx)
} else {
execute_limit_with_offset(input, self.limit, self.offset, exec_ctx)
}
}

fn statistics(&self) -> Result<Statistics> {
Statistics::with_fetch(
self.input.statistics()?,
self.schema(),
Some(self.limit as usize),
0,
Some(self.limit),
self.offset,
1,
)
}
}

fn execute_limit(
mut input: SendableRecordBatchStream,
limit: u64,
limit: usize,
exec_ctx: Arc<ExecutionContext>,
) -> Result<SendableRecordBatchStream> {
Ok(exec_ctx
Expand All @@ -131,11 +141,49 @@ fn execute_limit(
while remaining > 0
&& let Some(mut batch) = input.next().await.transpose()?
{
if remaining < batch.num_rows() as u64 {
batch = batch.slice(0, remaining as usize);
if remaining < batch.num_rows() {
batch = batch.slice(0, remaining);
remaining = 0;
} else {
remaining -= batch.num_rows();
}
exec_ctx.baseline_metrics().record_output(batch.num_rows());
sender.send(batch).await;
}
Ok(())
}))
}

fn execute_limit_with_offset(
mut input: SendableRecordBatchStream,
limit: usize,
offset: usize,
exec_ctx: Arc<ExecutionContext>,
) -> Result<SendableRecordBatchStream> {
Ok(exec_ctx
.clone()
.output_with_sender("Limit", move |sender| async move {
let mut skip = offset;
let mut remaining = limit - skip;
while remaining > 0
&& let Some(mut batch) = input.next().await.transpose()?
{
if skip > 0 {
let rows = batch.num_rows();
if skip >= rows {
skip -= rows;
continue;
}

batch = batch.slice(skip, rows - skip);
skip = 0;
}

if remaining < batch.num_rows() {
batch = batch.slice(0, remaining);
remaining = 0;
} else {
remaining -= batch.num_rows() as u64;
remaining -= batch.num_rows();
}
exec_ctx.baseline_metrics().record_output(batch.num_rows());
sender.send(batch).await;
Expand Down Expand Up @@ -203,7 +251,7 @@ mod test {
("b", &vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
("c", &vec![5, 6, 7, 8, 9, 0, 1, 2, 3, 4]),
);
let limit_exec = LimitExec::new(input, 2_u64);
let limit_exec = LimitExec::new(input, 2, 0);
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let output = limit_exec.execute(0, task_ctx).unwrap();
Expand All @@ -222,4 +270,31 @@ mod test {
assert_eq!(row_count, Precision::Exact(2));
Ok(())
}

#[tokio::test]
async fn test_limit_with_offset() -> Result<()> {
let input = build_table(
("a", &vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
("b", &vec![9, 8, 7, 6, 5, 4, 3, 2, 1, 0]),
("c", &vec![5, 6, 7, 8, 9, 0, 1, 2, 3, 4]),
);
let limit_exec = LimitExec::new(input, 7, 5);
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let output = limit_exec.execute(0, task_ctx).unwrap();
let batches = common::collect(output).await?;
let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();

let expected = vec![
"+---+---+---+",
"| a | b | c |",
"+---+---+---+",
"| 5 | 4 | 0 |",
"| 6 | 3 | 1 |",
"+---+---+---+",
];
assert_batches_eq!(expected, &batches);
assert_eq!(row_count, 2);
Ok(())
}
}
Loading
Loading