159 changes: 138 additions & 21 deletions datafusion/physical-plan/src/limit.rs
@@ -188,21 +188,11 @@ impl ExecutionPlan for GlobalLimitExec {
     fn statistics(&self) -> Result<Statistics> {

Review comment (Contributor): Thanks for noticing the error and fixing it. I think the new logic here is quite reasonable.

         let input_stats = self.input.statistics()?;
         let skip = self.skip;
-        // the maximum row number needs to be fetched
-        let max_row_num = self
-            .fetch
-            .map(|fetch| {
-                if fetch >= usize::MAX - skip {
-                    usize::MAX
-                } else {
-                    fetch + skip
-                }
-            })
-            .unwrap_or(usize::MAX);
         let col_stats = Statistics::unknown_column(&self.schema());
+        let fetch = self.fetch.unwrap_or(usize::MAX);

-        let fetched_row_number_stats = Statistics {
-            num_rows: Precision::Exact(max_row_num),
+        let mut fetched_row_number_stats = Statistics {
+            num_rows: Precision::Exact(fetch),
             column_statistics: col_stats.clone(),
             total_byte_size: Precision::Absent,
         };
@@ -218,23 +208,55 @@ impl ExecutionPlan for GlobalLimitExec {
             } => {
                 if nr <= skip {
                     // if all input data will be skipped, return 0
-                    Statistics {
+                    let mut skip_all_rows_stats = Statistics {
                         num_rows: Precision::Exact(0),
                         column_statistics: col_stats,
                         total_byte_size: Precision::Absent,
-                    }
-                } else if nr <= max_row_num {
-                    // if the input does not reach the "fetch" globally, return input stats
+                    };
+                    if !input_stats.num_rows.is_exact().unwrap_or(false) {
+                        // The input stats are inexact, so the output stats must be too.
+                        skip_all_rows_stats = skip_all_rows_stats.into_inexact();
+                    }
+                    skip_all_rows_stats
+                } else if nr <= fetch && self.skip == 0 {
+                    // if the input does not reach the "fetch" globally, and "skip" is zero
+                    // (meaning the input and output are identical), return input stats.
+                    // Can input_stats still be used, but adjusted, in the "skip != 0" case?
                     input_stats
+                } else if nr - skip <= fetch {
+                    // after "skip" input rows are skipped, the remaining rows are less than or equal to the
+                    // "fetch" values, so `num_rows` must equal the remaining rows
+                    let remaining_rows: usize = nr - skip;
+                    let mut skip_some_rows_stats = Statistics {
+                        num_rows: Precision::Exact(remaining_rows),
+                        column_statistics: col_stats.clone(),
+                        total_byte_size: Precision::Absent,
+                    };
+                    if !input_stats.num_rows.is_exact().unwrap_or(false) {
+                        // The input stats are inexact, so the output stats must be too.
+                        skip_some_rows_stats = skip_some_rows_stats.into_inexact();
+                    }
+                    skip_some_rows_stats
                 } else {
-                    // if the input is greater than the "fetch", the num_row will be the "fetch",
+                    // if the input is greater than "fetch+skip", the num_rows will be the "fetch",
                     // but we won't be able to predict the other statistics
+                    if !input_stats.num_rows.is_exact().unwrap_or(false)
+                        || self.fetch.is_none()
+                    {
+                        // If the input stats are inexact, the output stats must be too.
+                        // If the fetch value is `usize::MAX` because no LIMIT was specified,
+                        // we also can't represent it as an exact value.
+                        fetched_row_number_stats =
+                            fetched_row_number_stats.into_inexact();
+                    }
                     fetched_row_number_stats
                 }
             }
             _ => {
-                // the result output row number will always be no greater than the limit number
-                fetched_row_number_stats
+                // The result output `num_rows` will always be no greater than the limit number.
+                // Should `num_rows` be marked as `Absent` here when the `fetch` value is large,
+                // as the actual `num_rows` may be far away from the `fetch` value?
+                fetched_row_number_stats.into_inexact()
             }
         };
         Ok(stats)
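Taken together, the new branches implement a small decision table. The following is a minimal, self-contained sketch of that logic for the row-count estimate only, under stated assumptions: `RowEstimate` is a stand-in for DataFusion's `Precision<usize>`, and the input row count `nr` is assumed known; the real method also handles the unknown-count case (the `_ =>` arm) and the other statistics fields.

#[derive(Debug, PartialEq)]
enum RowEstimate {
    Exact(usize),
    Inexact(usize),
}

/// Toy model of the num_rows logic in GlobalLimitExec::statistics().
/// `nr`: input row count, `input_exact`: whether that count is exact,
/// `skip`: OFFSET, `fetch`: LIMIT (None means no LIMIT was given).
fn limit_num_rows(
    nr: usize,
    input_exact: bool,
    skip: usize,
    fetch: Option<usize>,
) -> RowEstimate {
    let f = fetch.unwrap_or(usize::MAX);
    let (rows, exact) = if nr <= skip {
        // all input rows are skipped
        (0, input_exact)
    } else if nr <= f && skip == 0 {
        // the limit never truncates: output == input
        (nr, input_exact)
    } else if nr - skip <= f {
        // after skipping, no more than `fetch` rows remain
        (nr - skip, input_exact)
    } else {
        // output is capped at `fetch`; only exact if the input count
        // was exact AND an explicit LIMIT was given
        (f, input_exact && fetch.is_some())
    };
    if exact {
        RowEstimate::Exact(rows)
    } else {
        RowEstimate::Inexact(rows)
    }
}

fn main() {
    // Mirrors a few of the new unit tests below (400 input rows):
    assert_eq!(limit_num_rows(400, true, 398, Some(10)), RowEstimate::Exact(2));
    assert_eq!(limit_num_rows(400, true, 400, Some(10)), RowEstimate::Exact(0));
    assert_eq!(limit_num_rows(400, true, 0, Some(usize::MAX)), RowEstimate::Exact(400));
    assert_eq!(limit_num_rows(400, false, 0, Some(10)), RowEstimate::Inexact(10));
    println!("sketch matches the test expectations");
}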
@@ -552,7 +574,10 @@ mod tests {
     use crate::common::collect;
     use crate::{common, test};

+    use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
     use arrow_schema::Schema;
+    use datafusion_physical_expr::expressions::col;
+    use datafusion_physical_expr::PhysicalExpr;

     #[tokio::test]
     async fn limit() -> Result<()> {
Expand Down Expand Up @@ -712,7 +737,7 @@ mod tests {
}

#[tokio::test]
async fn skip_3_fetch_10() -> Result<()> {
async fn skip_3_fetch_10_stats() -> Result<()> {
// there are total of 100 rows, we skipped 3 rows (offset = 3)
let row_count = skip_and_fetch(3, Some(10)).await?;
assert_eq!(row_count, 10);
Expand Down Expand Up @@ -748,7 +773,58 @@ mod tests {
assert_eq!(row_count, Precision::Exact(10));

let row_count = row_number_statistics_for_global_limit(5, Some(10)).await?;
assert_eq!(row_count, Precision::Exact(15));
assert_eq!(row_count, Precision::Exact(10));

let row_count = row_number_statistics_for_global_limit(400, Some(10)).await?;
assert_eq!(row_count, Precision::Exact(0));

let row_count = row_number_statistics_for_global_limit(398, Some(10)).await?;
assert_eq!(row_count, Precision::Exact(2));

let row_count = row_number_statistics_for_global_limit(398, Some(1)).await?;
assert_eq!(row_count, Precision::Exact(1));

let row_count = row_number_statistics_for_global_limit(398, None).await?;
assert_eq!(row_count, Precision::Exact(2));

let row_count =
row_number_statistics_for_global_limit(0, Some(usize::MAX)).await?;
assert_eq!(row_count, Precision::Exact(400));

let row_count =
row_number_statistics_for_global_limit(398, Some(usize::MAX)).await?;
assert_eq!(row_count, Precision::Exact(2));

let row_count =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let row_count =
// test inexact input statistics
let row_count =

This threw me for quite a while while trying to figure out why a fetch of 10 resulted in inexact statistics. Now that I see the difference is the input statistics are inexact it makes much more sense

row_number_inexact_statistics_for_global_limit(0, Some(10)).await?;
assert_eq!(row_count, Precision::Inexact(10));

let row_count =
row_number_inexact_statistics_for_global_limit(5, Some(10)).await?;
assert_eq!(row_count, Precision::Inexact(10));

let row_count =
row_number_inexact_statistics_for_global_limit(400, Some(10)).await?;
assert_eq!(row_count, Precision::Inexact(0));

let row_count =
row_number_inexact_statistics_for_global_limit(398, Some(10)).await?;
assert_eq!(row_count, Precision::Inexact(2));

let row_count =
row_number_inexact_statistics_for_global_limit(398, Some(1)).await?;
assert_eq!(row_count, Precision::Inexact(1));

let row_count = row_number_inexact_statistics_for_global_limit(398, None).await?;
assert_eq!(row_count, Precision::Inexact(2));

let row_count =
row_number_inexact_statistics_for_global_limit(0, Some(usize::MAX)).await?;
assert_eq!(row_count, Precision::Inexact(400));

let row_count =
row_number_inexact_statistics_for_global_limit(398, Some(usize::MAX)).await?;
assert_eq!(row_count, Precision::Inexact(2));

Ok(())
}
@@ -776,6 +852,47 @@ mod tests {
         Ok(offset.statistics()?.num_rows)
     }

+    pub fn build_group_by(
+        input_schema: &SchemaRef,
+        columns: Vec<String>,
+    ) -> PhysicalGroupBy {
+        let mut group_by_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
+        for column in columns.iter() {
+            group_by_expr.push((col(column, input_schema).unwrap(), column.to_string()));
+        }
+        PhysicalGroupBy::new_single(group_by_expr.clone())
+    }
+
+    async fn row_number_inexact_statistics_for_global_limit(
+        skip: usize,
+        fetch: Option<usize>,
+    ) -> Result<Precision<usize>> {
+        let num_partitions = 4;
+        let csv = test::scan_partitioned(num_partitions);
+
+        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);
+
+        // Adding a "GROUP BY i" changes the input stats from Exact to Inexact.
+        let agg = AggregateExec::try_new(
+            AggregateMode::Final,
+            build_group_by(&csv.schema().clone(), vec!["i".to_string()]),
+            vec![],
+            vec![None],
+            vec![None],
+            csv.clone(),
+            csv.schema().clone(),
+        )?;
+        let agg_exec: Arc<dyn ExecutionPlan> = Arc::new(agg);
+
+        let offset = GlobalLimitExec::new(
+            Arc::new(CoalescePartitionsExec::new(agg_exec)),
+            skip,
+            fetch,
+        );
+
+        Ok(offset.statistics()?.num_rows)
+    }
+
     async fn row_number_statistics_for_local_limit(
         num_partitions: usize,
         fetch: usize,
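Why the GROUP BY helper works: an aggregation's output row count is the number of distinct groups, which cannot be derived from the input statistics alone, so the AggregateExec reports an inexact row count and the limit's inexact code paths get exercised.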
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/explain.slt
@@ -273,7 +273,7 @@ query TT
 EXPLAIN SELECT a, b, c FROM simple_explain_test limit 10;
 ----
 physical_plan
-GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(10), Bytes=Absent]
+GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Inexact(10), Bytes=Absent]

Review comment (Contributor): The old version is definitely misleading, nice improvement 👍

 --CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], limit=10, has_header=true, statistics=[Rows=Absent, Bytes=Absent]

 # Parquet scan with statistics collected
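The Inexact(10) annotation is the new catch-all arm at work: the CsvExec below the limit reports Rows=Absent, so fetch=10 is only an upper-bound estimate for the limit's output, not an exact row count.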
85 changes: 85 additions & 0 deletions datafusion/sqllogictest/test_files/limit.slt
@@ -294,6 +294,91 @@ query T
SELECT c1 FROM aggregate_test_100 LIMIT 1 OFFSET 101
----

#
# global limit statistics test
#

statement ok
CREATE TABLE IF NOT EXISTS t1 (a INT) AS VALUES(1),(2),(3),(4),(5),(6),(7),(8),(9),(10);

# The aggregate does not need to be computed because the input statistics are exact and
# the number of rows is less than the skip value (OFFSET).
query TT
EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11);
----
logical_plan
Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
--Limit: skip=11, fetch=3
----TableScan: t1 projection=[], fetch=14
physical_plan
ProjectionExec: expr=[0 as COUNT(*)]
--EmptyExec: produce_one_row=true

query I
SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11);
----
0

# The aggregate does not need to be computed because the input statistics are exact and
# the number of rows is less than or equal to the "fetch+skip" value (LIMIT+OFFSET).
query TT
EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8);
----
logical_plan
Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
--Limit: skip=8, fetch=3
----TableScan: t1 projection=[], fetch=11
physical_plan
ProjectionExec: expr=[2 as COUNT(*)]
--EmptyExec: produce_one_row=true

query I
SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8);
----
2

# The aggregate does not need to be computed because the input statistics are exact and
# an OFFSET, but no LIMIT, is specified.
query TT
EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 OFFSET 8);
----
logical_plan
Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
--Limit: skip=8, fetch=None
----TableScan: t1 projection=[]
physical_plan
ProjectionExec: expr=[2 as COUNT(*)]
--EmptyExec: produce_one_row=true

query I
SELECT COUNT(*) FROM (SELECT a FROM t1 OFFSET 8);
----
2

# The aggregate needs to be computed because the input statistics are inexact.
query TT
EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6);
----
logical_plan
Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
--Limit: skip=6, fetch=3
----Filter: t1.a > Int32(3)
------TableScan: t1 projection=[a]
physical_plan
AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)]
--CoalescePartitionsExec
----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)]
------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
--------GlobalLimitExec: skip=6, fetch=3
----------CoalesceBatchesExec: target_batch_size=8192
------------FilterExec: a@0 > 3
--------------MemoryExec: partitions=1, partition_sizes=[1]

query I
SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6);
----
1
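Why the result is 1: seven rows satisfy a > 3 (the values 4 through 10), OFFSET 6 skips six of them, and the single remaining row is below the LIMIT of 3. Because the filter makes the row-count statistics inexact, the aggregate must actually run instead of being folded into a constant as in the exact-statistics cases above.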

Review comment (Contributor): 👍

########
# Clean up after the test
########
12 changes: 8 additions & 4 deletions datafusion/sqllogictest/test_files/window.slt
@@ -2010,10 +2010,14 @@ Projection: ARRAY_AGG(aggregate_test_100.c13) AS array_agg1
 --------TableScan: aggregate_test_100 projection=[c13]
 physical_plan
 ProjectionExec: expr=[ARRAY_AGG(aggregate_test_100.c13)@0 as array_agg1]
---AggregateExec: mode=Single, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)]
-----GlobalLimitExec: skip=0, fetch=1
-------SortExec: TopK(fetch=1), expr=[c13@0 ASC NULLS LAST]
---------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c13], has_header=true
+--AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)]
+----CoalescePartitionsExec
+------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)]
+--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
+----------GlobalLimitExec: skip=0, fetch=1
+------------SortExec: TopK(fetch=1), expr=[c13@0 ASC NULLS LAST]
+--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c13], has_header=true


query ?
SELECT ARRAY_AGG(c13) as array_agg1 FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 1)
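A note on the window.slt change above: with the limit's row-count statistic now reported as inexact, the planner presumably can no longer treat the one-row limit as a guaranteed single-row input, so the mode=Single aggregate is replaced by the generic Partial/Final aggregation pair with repartitioning.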