Fix incorrect results in COUNT(*) queries with LIMIT #8049
@@ -188,21 +188,11 @@ impl ExecutionPlan for GlobalLimitExec {
     fn statistics(&self) -> Result<Statistics> {
         let input_stats = self.input.statistics()?;
         let skip = self.skip;
-        // the maximum row number needs to be fetched
-        let max_row_num = self
-            .fetch
-            .map(|fetch| {
-                if fetch >= usize::MAX - skip {
-                    usize::MAX
-                } else {
-                    fetch + skip
-                }
-            })
-            .unwrap_or(usize::MAX);
         let col_stats = Statistics::unknown_column(&self.schema());
+        let fetch = self.fetch.unwrap_or(usize::MAX);

-        let fetched_row_number_stats = Statistics {
-            num_rows: Precision::Exact(max_row_num),
+        let mut fetched_row_number_stats = Statistics {
+            num_rows: Precision::Exact(fetch),
             column_statistics: col_stats.clone(),
             total_byte_size: Precision::Absent,
         };
@@ -218,23 +208,55 @@ impl ExecutionPlan for GlobalLimitExec {
                 } => {
                     if nr <= skip {
                         // if all input data will be skipped, return 0
-                        Statistics {
+                        let mut skip_all_rows_stats = Statistics {
                             num_rows: Precision::Exact(0),
                             column_statistics: col_stats,
                             total_byte_size: Precision::Absent,
-                        }
-                    } else if nr <= max_row_num {
-                        // if the input does not reach the "fetch" globally, return input stats
+                        };
+                        if !input_stats.num_rows.is_exact().unwrap_or(false) {
+                            // The input stats are inexact, so the output stats must be too.
+                            skip_all_rows_stats = skip_all_rows_stats.into_inexact();
+                        }
+                        skip_all_rows_stats
+                    } else if nr <= fetch && self.skip == 0 {
+                        // if the input does not reach the "fetch" globally, and "skip" is zero
+                        // (meaning the input and output are identical), return input stats.
+                        // Can input_stats still be used, but adjusted, in the "skip != 0" case?
                         input_stats
+                    } else if nr - skip <= fetch {
+                        // after "skip" input rows are skipped, the remaining rows are less than
+                        // or equal to the "fetch" value, so `num_rows` must equal the remaining rows
+                        let remaining_rows: usize = nr - skip;
+                        let mut skip_some_rows_stats = Statistics {
+                            num_rows: Precision::Exact(remaining_rows),
+                            column_statistics: col_stats.clone(),
+                            total_byte_size: Precision::Absent,
+                        };
+                        if !input_stats.num_rows.is_exact().unwrap_or(false) {
+                            // The input stats are inexact, so the output stats must be too.
+                            skip_some_rows_stats = skip_some_rows_stats.into_inexact();
+                        }
+                        skip_some_rows_stats
                     } else {
-                        // if the input is greater than the "fetch", the num_row will be the "fetch",
+                        // if the input is greater than "fetch+skip", the num_rows will be the "fetch",
                         // but we won't be able to predict the other statistics
+                        if !input_stats.num_rows.is_exact().unwrap_or(false)
+                            || self.fetch.is_none()
+                        {
+                            // If the input stats are inexact, the output stats must be too.
+                            // If the fetch value is `usize::MAX` because no LIMIT was specified,
+                            // we also can't represent it as an exact value.
+                            fetched_row_number_stats =
+                                fetched_row_number_stats.into_inexact();
+                        }
                         fetched_row_number_stats
                     }
                 }
                 _ => {
-                    // the result output row number will always be no greater than the limit number
-                    fetched_row_number_stats
+                    // The result output `num_rows` will always be no greater than the limit number.
+                    // Should `num_rows` be marked as `Absent` here when the `fetch` value is large,
+                    // as the actual `num_rows` may be far away from the `fetch` value?
+                    fetched_row_number_stats.into_inexact()
                 }
             };
             Ok(stats)
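The removed code reported the limit's output as `Exact(fetch + skip)` rows, over-counting by the number of skipped rows; downstream optimizations that trust exact statistics could then fold a wrong constant into COUNT(*) results. To make the corrected three-way case analysis easier to follow, here is a compact, self-contained model of the row-count logic (a hedged sketch: this `Precision` and `limit_num_rows` are simplified stand-ins of my own, not DataFusion's actual `Precision<usize>` or the method above, which also tracks column and byte-size statistics):

```rust
/// Simplified stand-in for DataFusion's `Precision<usize>`.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Precision {
    Exact(usize),
    Inexact(usize),
    Absent,
}

/// Model of the corrected row-count logic: given the input's row-count
/// statistic, a skip (OFFSET), and an optional fetch (LIMIT), compute the
/// output row-count statistic.
fn limit_num_rows(input: Precision, skip: usize, fetch: Option<usize>) -> Precision {
    let fetch_val = fetch.unwrap_or(usize::MAX);
    let (nr, exact) = match input {
        Precision::Exact(n) => (n, true),
        Precision::Inexact(n) => (n, false),
        // Unknown input size: all we can say is "at most `fetch` rows".
        Precision::Absent => return Precision::Inexact(fetch_val),
    };
    let n = if nr <= skip {
        0 // every input row is skipped
    } else if nr - skip <= fetch_val {
        nr - skip // the input runs out before the fetch limit is reached
    } else {
        fetch_val // the output is truncated at the fetch limit
    };
    // The output can never be more exact than the input.
    if exact {
        Precision::Exact(n)
    } else {
        Precision::Inexact(n)
    }
}

fn main() {
    // Mirrors the updated unit tests: 400 input rows with exact statistics.
    assert_eq!(limit_num_rows(Precision::Exact(400), 5, Some(10)), Precision::Exact(10));
    assert_eq!(limit_num_rows(Precision::Exact(400), 400, Some(10)), Precision::Exact(0));
    assert_eq!(limit_num_rows(Precision::Exact(400), 398, None), Precision::Exact(2));
    // Inexact input statistics propagate to the output.
    assert_eq!(limit_num_rows(Precision::Inexact(400), 398, Some(10)), Precision::Inexact(2));
    // Absent input statistics (e.g. a CSV scan): at most `fetch` rows.
    assert_eq!(limit_num_rows(Precision::Absent, 0, Some(10)), Precision::Inexact(10));
}
```

Note that the old `(5, Some(10))` case returned `Exact(15)` where only 10 rows can ever be produced; the model above returns `Exact(10)`, matching the corrected test expectations below.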
@@ -552,7 +574,10 @@ mod tests {
     use crate::common::collect;
     use crate::{common, test};

+    use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
+    use arrow_schema::Schema;
+    use datafusion_physical_expr::expressions::col;
+    use datafusion_physical_expr::PhysicalExpr;
+
     #[tokio::test]
     async fn limit() -> Result<()> {
@@ -712,7 +737,7 @@ mod tests {
     }

     #[tokio::test]
-    async fn skip_3_fetch_10() -> Result<()> {
+    async fn skip_3_fetch_10_stats() -> Result<()> {
         // there are total of 100 rows, we skipped 3 rows (offset = 3)
         let row_count = skip_and_fetch(3, Some(10)).await?;
         assert_eq!(row_count, 10);
@@ -748,7 +773,58 @@ mod tests {
         assert_eq!(row_count, Precision::Exact(10));

         let row_count = row_number_statistics_for_global_limit(5, Some(10)).await?;
-        assert_eq!(row_count, Precision::Exact(15));
+        assert_eq!(row_count, Precision::Exact(10));
+
+        let row_count = row_number_statistics_for_global_limit(400, Some(10)).await?;
+        assert_eq!(row_count, Precision::Exact(0));
+
+        let row_count = row_number_statistics_for_global_limit(398, Some(10)).await?;
+        assert_eq!(row_count, Precision::Exact(2));
+
+        let row_count = row_number_statistics_for_global_limit(398, Some(1)).await?;
+        assert_eq!(row_count, Precision::Exact(1));
+
+        let row_count = row_number_statistics_for_global_limit(398, None).await?;
+        assert_eq!(row_count, Precision::Exact(2));
+
+        let row_count =
+            row_number_statistics_for_global_limit(0, Some(usize::MAX)).await?;
+        assert_eq!(row_count, Precision::Exact(400));
+
+        let row_count =
+            row_number_statistics_for_global_limit(398, Some(usize::MAX)).await?;
+        assert_eq!(row_count, Precision::Exact(2));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(0, Some(10)).await?;
+        assert_eq!(row_count, Precision::Inexact(10));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(5, Some(10)).await?;
+        assert_eq!(row_count, Precision::Inexact(10));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(400, Some(10)).await?;
+        assert_eq!(row_count, Precision::Inexact(0));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(398, Some(10)).await?;
+        assert_eq!(row_count, Precision::Inexact(2));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(398, Some(1)).await?;
+        assert_eq!(row_count, Precision::Inexact(1));
+
+        let row_count = row_number_inexact_statistics_for_global_limit(398, None).await?;
+        assert_eq!(row_count, Precision::Inexact(2));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(0, Some(usize::MAX)).await?;
+        assert_eq!(row_count, Precision::Inexact(400));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(398, Some(usize::MAX)).await?;
+        assert_eq!(row_count, Precision::Inexact(2));
+
         Ok(())
     }

Contributor (on the first `row_number_inexact_statistics_for_global_limit` assertion): This threw me for quite a while while trying to figure out why a fetch of 10 resulted in inexact statistics. Now that I see the difference is the input statistics are inexact it makes much more sense.
@@ -776,6 +852,47 @@ mod tests {
         Ok(offset.statistics()?.num_rows)
     }

+    pub fn build_group_by(
+        input_schema: &SchemaRef,
+        columns: Vec<String>,
+    ) -> PhysicalGroupBy {
+        let mut group_by_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
+        for column in columns.iter() {
+            group_by_expr.push((col(column, input_schema).unwrap(), column.to_string()));
+        }
+        PhysicalGroupBy::new_single(group_by_expr.clone())
+    }
+
+    async fn row_number_inexact_statistics_for_global_limit(
+        skip: usize,
+        fetch: Option<usize>,
+    ) -> Result<Precision<usize>> {
+        let num_partitions = 4;
+        let csv = test::scan_partitioned(num_partitions);
+
+        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);
+
+        // Adding a "GROUP BY i" changes the input stats from Exact to Inexact.
+        let agg = AggregateExec::try_new(
+            AggregateMode::Final,
+            build_group_by(&csv.schema().clone(), vec!["i".to_string()]),
+            vec![],
+            vec![None],
+            vec![None],
+            csv.clone(),
+            csv.schema().clone(),
+        )?;
+        let agg_exec: Arc<dyn ExecutionPlan> = Arc::new(agg);
+
+        let offset = GlobalLimitExec::new(
+            Arc::new(CoalescePartitionsExec::new(agg_exec)),
+            skip,
+            fetch,
+        );
+
+        Ok(offset.statistics()?.num_rows)
+    }
+
     async fn row_number_statistics_for_local_limit(
         num_partitions: usize,
         fetch: usize,
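The new helper makes the input statistics inexact by routing the scan through a `GROUP BY`, since the number of distinct groups cannot be known before execution. The exact-to-inexact downgrade it exercises is the `into_inexact` conversion used throughout this fix; below is a minimal sketch of that rule with a simplified stand-in `Precision` type, not DataFusion's actual implementation:

```rust
/// Simplified stand-in for DataFusion's `Precision<usize>`.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Precision {
    Exact(usize),
    Inexact(usize),
    Absent,
}

impl Precision {
    /// Downgrade an exact value to an estimate of the same magnitude;
    /// estimates and absent values pass through unchanged.
    fn into_inexact(self) -> Self {
        match self {
            Precision::Exact(n) => Precision::Inexact(n),
            other => other,
        }
    }
}

fn main() {
    // A grouped input turns a known row count into an estimate ...
    assert_eq!(Precision::Exact(10).into_inexact(), Precision::Inexact(10));
    // ... while weaker knowledge is never upgraded.
    assert_eq!(Precision::Inexact(2).into_inexact(), Precision::Inexact(2));
    assert_eq!(Precision::Absent.into_inexact(), Precision::Absent);
}
```

This is why the paired tests above assert `Inexact(n)` with the same `n` as their exact counterparts: the downgrade preserves the estimate and only weakens the guarantee.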
@@ -273,7 +273,7 @@ query TT
 EXPLAIN SELECT a, b, c FROM simple_explain_test limit 10;
 ----
 physical_plan
-GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(10), Bytes=Absent]
+GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Inexact(10), Bytes=Absent]
 --CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], limit=10, has_header=true, statistics=[Rows=Absent, Bytes=Absent]

 # Parquet scan with statistics collected

Contributor: The old version is definitely misleading, nice improvement 👍
@@ -294,6 +294,91 @@ query T
 SELECT c1 FROM aggregate_test_100 LIMIT 1 OFFSET 101
 ----

+#
+# global limit statistics test
+#
+
+statement ok
+CREATE TABLE IF NOT EXISTS t1 (a INT) AS VALUES(1),(2),(3),(4),(5),(6),(7),(8),(9),(10);
+
+# The aggregate does not need to be computed because the input statistics are exact and
+# the number of rows is less than the skip value (OFFSET).
+query TT
+EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11);
+----
+logical_plan
+Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
+--Limit: skip=11, fetch=3
+----TableScan: t1 projection=[], fetch=14
+physical_plan
+ProjectionExec: expr=[0 as COUNT(*)]
+--EmptyExec: produce_one_row=true
+
+query I
+SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11);
+----
+0
+
+# The aggregate does not need to be computed because the input statistics are exact and
+# the number of rows is less than or equal to the "fetch+skip" value (LIMIT+OFFSET).
+query TT
+EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8);
+----
+logical_plan
+Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
+--Limit: skip=8, fetch=3
+----TableScan: t1 projection=[], fetch=11
+physical_plan
+ProjectionExec: expr=[2 as COUNT(*)]
+--EmptyExec: produce_one_row=true
+
+query I
+SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8);
+----
+2
+
+# The aggregate does not need to be computed because the input statistics are exact and
+# an OFFSET, but no LIMIT, is specified.
+query TT
+EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 OFFSET 8);
+----
+logical_plan
+Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
+--Limit: skip=8, fetch=None
+----TableScan: t1 projection=[]
+physical_plan
+ProjectionExec: expr=[2 as COUNT(*)]
+--EmptyExec: produce_one_row=true
+
+query I
+SELECT COUNT(*) FROM (SELECT a FROM t1 OFFSET 8);
+----
+2
+
+# The aggregate needs to be computed because the input statistics are inexact.
+query TT
+EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6);
+----
+logical_plan
+Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
+--Limit: skip=6, fetch=3
+----Filter: t1.a > Int32(3)
+------TableScan: t1 projection=[a]
+physical_plan
+AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)]
+--CoalescePartitionsExec
+----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)]
+------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+--------GlobalLimitExec: skip=6, fetch=3
+----------CoalesceBatchesExec: target_batch_size=8192
+------------FilterExec: a@0 > 3
+--------------MemoryExec: partitions=1, partition_sizes=[1]
+
+query I
+SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6);
+----
+1
+
 ########
 # Clean up after the test
 ########

Contributor (on the new statistics tests): 👍
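The EXPLAIN plans above collapse the whole query to `ProjectionExec: expr=[N as COUNT(*)]` because a COUNT(*) over a plan with a provably exact row count needs no execution at all; as I understand it, this rewrite is DataFusion's `AggregateStatistics` physical-optimizer rule. A hedged sketch of that decision with a simplified stand-in type (this also shows why the old `Exact(fetch + skip)` statistic produced wrong COUNT(*) answers: the optimizer trusted it and folded the wrong constant):

```rust
/// Simplified stand-in for DataFusion's `Precision<usize>`.
enum Precision {
    Exact(usize),
    Inexact(usize),
    Absent,
}

/// Return the constant that may replace a COUNT(*) aggregate, if any.
/// Only a provably exact row count allows skipping execution entirely.
fn count_star_from_stats(num_rows: &Precision) -> Option<usize> {
    match num_rows {
        Precision::Exact(n) => Some(*n),
        // An estimate or unknown count is not enough: the query must run.
        Precision::Inexact(_) | Precision::Absent => None,
    }
}

fn main() {
    // LIMIT 3 OFFSET 8 over 10 rows with exact stats -> Exact(2),
    // so COUNT(*) folds to the literal 2, as in the plans above.
    assert_eq!(count_star_from_stats(&Precision::Exact(2)), Some(2));
    // With a filter in the pipeline the stats become inexact, so the
    // aggregate is actually computed (the last EXPLAIN above).
    assert_eq!(count_star_from_stats(&Precision::Inexact(2)), None);
    assert_eq!(count_star_from_stats(&Precision::Absent), None);
}
```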
Reviewer: Thanks for noticing the error and fixing it. I think the new logic here is quite reasonable.