Skip to content

Commit 06fd26b

Browse files
msirekMark Sirek
andauthored
Fix incorrect results in COUNT(*) queries with LIMIT (#8049)
Co-authored-by: Mark Sirek <sirek@cockroachlabs.com>
1 parent 07c08a3 commit 06fd26b

File tree

4 files changed

+232
-26
lines changed

4 files changed

+232
-26
lines changed

datafusion/physical-plan/src/limit.rs

Lines changed: 138 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -188,21 +188,11 @@ impl ExecutionPlan for GlobalLimitExec {
188188
fn statistics(&self) -> Result<Statistics> {
189189
let input_stats = self.input.statistics()?;
190190
let skip = self.skip;
191-
// the maximum row number needs to be fetched
192-
let max_row_num = self
193-
.fetch
194-
.map(|fetch| {
195-
if fetch >= usize::MAX - skip {
196-
usize::MAX
197-
} else {
198-
fetch + skip
199-
}
200-
})
201-
.unwrap_or(usize::MAX);
202191
let col_stats = Statistics::unknown_column(&self.schema());
192+
let fetch = self.fetch.unwrap_or(usize::MAX);
203193

204-
let fetched_row_number_stats = Statistics {
205-
num_rows: Precision::Exact(max_row_num),
194+
let mut fetched_row_number_stats = Statistics {
195+
num_rows: Precision::Exact(fetch),
206196
column_statistics: col_stats.clone(),
207197
total_byte_size: Precision::Absent,
208198
};
@@ -218,23 +208,55 @@ impl ExecutionPlan for GlobalLimitExec {
218208
} => {
219209
if nr <= skip {
220210
// if all input data will be skipped, return 0
221-
Statistics {
211+
let mut skip_all_rows_stats = Statistics {
222212
num_rows: Precision::Exact(0),
223213
column_statistics: col_stats,
224214
total_byte_size: Precision::Absent,
215+
};
216+
if !input_stats.num_rows.is_exact().unwrap_or(false) {
217+
// The input stats are inexact, so the output stats must be too.
218+
skip_all_rows_stats = skip_all_rows_stats.into_inexact();
225219
}
226-
} else if nr <= max_row_num {
227-
// if the input does not reach the "fetch" globally, return input stats
220+
skip_all_rows_stats
221+
} else if nr <= fetch && self.skip == 0 {
222+
// if the input does not reach the "fetch" globally, and "skip" is zero
223+
// (meaning the input and output are identical), return input stats.
224+
// Can input_stats still be used, but adjusted, in the "skip != 0" case?
228225
input_stats
226+
} else if nr - skip <= fetch {
227+
// after "skip" input rows are skipped, the remaining rows are less than or equal to the
228+
// "fetch" values, so `num_rows` must equal the remaining rows
229+
let remaining_rows: usize = nr - skip;
230+
let mut skip_some_rows_stats = Statistics {
231+
num_rows: Precision::Exact(remaining_rows),
232+
column_statistics: col_stats.clone(),
233+
total_byte_size: Precision::Absent,
234+
};
235+
if !input_stats.num_rows.is_exact().unwrap_or(false) {
236+
// The input stats are inexact, so the output stats must be too.
237+
skip_some_rows_stats = skip_some_rows_stats.into_inexact();
238+
}
239+
skip_some_rows_stats
229240
} else {
230-
// if the input is greater than the "fetch", the num_row will be the "fetch",
241+
// if the input is greater than "fetch+skip", the num_rows will be the "fetch",
231242
// but we won't be able to predict the other statistics
243+
if !input_stats.num_rows.is_exact().unwrap_or(false)
244+
|| self.fetch.is_none()
245+
{
246+
// If the input stats are inexact, the output stats must be too.
247+
// If the fetch value is `usize::MAX` because no LIMIT was specified,
248+
// we also can't represent it as an exact value.
249+
fetched_row_number_stats =
250+
fetched_row_number_stats.into_inexact();
251+
}
232252
fetched_row_number_stats
233253
}
234254
}
235255
_ => {
236-
// the result output row number will always be no greater than the limit number
237-
fetched_row_number_stats
256+
// The result output `num_rows` will always be no greater than the limit number.
257+
// Should `num_rows` be marked as `Absent` here when the `fetch` value is large,
258+
// as the actual `num_rows` may be far away from the `fetch` value?
259+
fetched_row_number_stats.into_inexact()
238260
}
239261
};
240262
Ok(stats)
@@ -552,7 +574,10 @@ mod tests {
552574
use crate::common::collect;
553575
use crate::{common, test};
554576

577+
use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
555578
use arrow_schema::Schema;
579+
use datafusion_physical_expr::expressions::col;
580+
use datafusion_physical_expr::PhysicalExpr;
556581

557582
#[tokio::test]
558583
async fn limit() -> Result<()> {
@@ -712,7 +737,7 @@ mod tests {
712737
}
713738

714739
#[tokio::test]
715-
async fn skip_3_fetch_10() -> Result<()> {
740+
async fn skip_3_fetch_10_stats() -> Result<()> {
716741
// there are total of 100 rows, we skipped 3 rows (offset = 3)
717742
let row_count = skip_and_fetch(3, Some(10)).await?;
718743
assert_eq!(row_count, 10);
@@ -748,7 +773,58 @@ mod tests {
748773
assert_eq!(row_count, Precision::Exact(10));
749774

750775
let row_count = row_number_statistics_for_global_limit(5, Some(10)).await?;
751-
assert_eq!(row_count, Precision::Exact(15));
776+
assert_eq!(row_count, Precision::Exact(10));
777+
778+
let row_count = row_number_statistics_for_global_limit(400, Some(10)).await?;
779+
assert_eq!(row_count, Precision::Exact(0));
780+
781+
let row_count = row_number_statistics_for_global_limit(398, Some(10)).await?;
782+
assert_eq!(row_count, Precision::Exact(2));
783+
784+
let row_count = row_number_statistics_for_global_limit(398, Some(1)).await?;
785+
assert_eq!(row_count, Precision::Exact(1));
786+
787+
let row_count = row_number_statistics_for_global_limit(398, None).await?;
788+
assert_eq!(row_count, Precision::Exact(2));
789+
790+
let row_count =
791+
row_number_statistics_for_global_limit(0, Some(usize::MAX)).await?;
792+
assert_eq!(row_count, Precision::Exact(400));
793+
794+
let row_count =
795+
row_number_statistics_for_global_limit(398, Some(usize::MAX)).await?;
796+
assert_eq!(row_count, Precision::Exact(2));
797+
798+
let row_count =
799+
row_number_inexact_statistics_for_global_limit(0, Some(10)).await?;
800+
assert_eq!(row_count, Precision::Inexact(10));
801+
802+
let row_count =
803+
row_number_inexact_statistics_for_global_limit(5, Some(10)).await?;
804+
assert_eq!(row_count, Precision::Inexact(10));
805+
806+
let row_count =
807+
row_number_inexact_statistics_for_global_limit(400, Some(10)).await?;
808+
assert_eq!(row_count, Precision::Inexact(0));
809+
810+
let row_count =
811+
row_number_inexact_statistics_for_global_limit(398, Some(10)).await?;
812+
assert_eq!(row_count, Precision::Inexact(2));
813+
814+
let row_count =
815+
row_number_inexact_statistics_for_global_limit(398, Some(1)).await?;
816+
assert_eq!(row_count, Precision::Inexact(1));
817+
818+
let row_count = row_number_inexact_statistics_for_global_limit(398, None).await?;
819+
assert_eq!(row_count, Precision::Inexact(2));
820+
821+
let row_count =
822+
row_number_inexact_statistics_for_global_limit(0, Some(usize::MAX)).await?;
823+
assert_eq!(row_count, Precision::Inexact(400));
824+
825+
let row_count =
826+
row_number_inexact_statistics_for_global_limit(398, Some(usize::MAX)).await?;
827+
assert_eq!(row_count, Precision::Inexact(2));
752828

753829
Ok(())
754830
}
@@ -776,6 +852,47 @@ mod tests {
776852
Ok(offset.statistics()?.num_rows)
777853
}
778854

855+
pub fn build_group_by(
856+
input_schema: &SchemaRef,
857+
columns: Vec<String>,
858+
) -> PhysicalGroupBy {
859+
let mut group_by_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
860+
for column in columns.iter() {
861+
group_by_expr.push((col(column, input_schema).unwrap(), column.to_string()));
862+
}
863+
PhysicalGroupBy::new_single(group_by_expr.clone())
864+
}
865+
866+
async fn row_number_inexact_statistics_for_global_limit(
867+
skip: usize,
868+
fetch: Option<usize>,
869+
) -> Result<Precision<usize>> {
870+
let num_partitions = 4;
871+
let csv = test::scan_partitioned(num_partitions);
872+
873+
assert_eq!(csv.output_partitioning().partition_count(), num_partitions);
874+
875+
// Adding a "GROUP BY i" changes the input stats from Exact to Inexact.
876+
let agg = AggregateExec::try_new(
877+
AggregateMode::Final,
878+
build_group_by(&csv.schema().clone(), vec!["i".to_string()]),
879+
vec![],
880+
vec![None],
881+
vec![None],
882+
csv.clone(),
883+
csv.schema().clone(),
884+
)?;
885+
let agg_exec: Arc<dyn ExecutionPlan> = Arc::new(agg);
886+
887+
let offset = GlobalLimitExec::new(
888+
Arc::new(CoalescePartitionsExec::new(agg_exec)),
889+
skip,
890+
fetch,
891+
);
892+
893+
Ok(offset.statistics()?.num_rows)
894+
}
895+
779896
async fn row_number_statistics_for_local_limit(
780897
num_partitions: usize,
781898
fetch: usize,

datafusion/sqllogictest/test_files/explain.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ query TT
273273
EXPLAIN SELECT a, b, c FROM simple_explain_test limit 10;
274274
----
275275
physical_plan
276-
GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(10), Bytes=Absent]
276+
GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Inexact(10), Bytes=Absent]
277277
--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], limit=10, has_header=true, statistics=[Rows=Absent, Bytes=Absent]
278278

279279
# Parquet scan with statistics collected

datafusion/sqllogictest/test_files/limit.slt

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,91 @@ query T
294294
SELECT c1 FROM aggregate_test_100 LIMIT 1 OFFSET 101
295295
----
296296

297+
#
298+
# global limit statistics test
299+
#
300+
301+
statement ok
302+
CREATE TABLE IF NOT EXISTS t1 (a INT) AS VALUES(1),(2),(3),(4),(5),(6),(7),(8),(9),(10);
303+
304+
# The aggregate does not need to be computed because the input statistics are exact and
305+
# the number of rows is less than the skip value (OFFSET).
306+
query TT
307+
EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11);
308+
----
309+
logical_plan
310+
Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
311+
--Limit: skip=11, fetch=3
312+
----TableScan: t1 projection=[], fetch=14
313+
physical_plan
314+
ProjectionExec: expr=[0 as COUNT(*)]
315+
--EmptyExec: produce_one_row=true
316+
317+
query I
318+
SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11);
319+
----
320+
0
321+
322+
# The aggregate does not need to be computed because the input statistics are exact and
323+
# the number of rows is less than or equal to the the "fetch+skip" value (LIMIT+OFFSET).
324+
query TT
325+
EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8);
326+
----
327+
logical_plan
328+
Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
329+
--Limit: skip=8, fetch=3
330+
----TableScan: t1 projection=[], fetch=11
331+
physical_plan
332+
ProjectionExec: expr=[2 as COUNT(*)]
333+
--EmptyExec: produce_one_row=true
334+
335+
query I
336+
SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8);
337+
----
338+
2
339+
340+
# The aggregate does not need to be computed because the input statistics are exact and
341+
# an OFFSET, but no LIMIT, is specified.
342+
query TT
343+
EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 OFFSET 8);
344+
----
345+
logical_plan
346+
Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
347+
--Limit: skip=8, fetch=None
348+
----TableScan: t1 projection=[]
349+
physical_plan
350+
ProjectionExec: expr=[2 as COUNT(*)]
351+
--EmptyExec: produce_one_row=true
352+
353+
query I
354+
SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8);
355+
----
356+
2
357+
358+
# The aggregate needs to be computed because the input statistics are inexact.
359+
query TT
360+
EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6);
361+
----
362+
logical_plan
363+
Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
364+
--Limit: skip=6, fetch=3
365+
----Filter: t1.a > Int32(3)
366+
------TableScan: t1 projection=[a]
367+
physical_plan
368+
AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)]
369+
--CoalescePartitionsExec
370+
----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)]
371+
------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
372+
--------GlobalLimitExec: skip=6, fetch=3
373+
----------CoalesceBatchesExec: target_batch_size=8192
374+
------------FilterExec: a@0 > 3
375+
--------------MemoryExec: partitions=1, partition_sizes=[1]
376+
377+
query I
378+
SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6);
379+
----
380+
1
381+
297382
########
298383
# Clean up after the test
299384
########

datafusion/sqllogictest/test_files/window.slt

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2010,10 +2010,14 @@ Projection: ARRAY_AGG(aggregate_test_100.c13) AS array_agg1
20102010
--------TableScan: aggregate_test_100 projection=[c13]
20112011
physical_plan
20122012
ProjectionExec: expr=[ARRAY_AGG(aggregate_test_100.c13)@0 as array_agg1]
2013-
--AggregateExec: mode=Single, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)]
2014-
----GlobalLimitExec: skip=0, fetch=1
2015-
------SortExec: TopK(fetch=1), expr=[c13@0 ASC NULLS LAST]
2016-
--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c13], has_header=true
2013+
--AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)]
2014+
----CoalescePartitionsExec
2015+
------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)]
2016+
--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
2017+
----------GlobalLimitExec: skip=0, fetch=1
2018+
------------SortExec: TopK(fetch=1), expr=[c13@0 ASC NULLS LAST]
2019+
--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c13], has_header=true
2020+
20172021

20182022
query ?
20192023
SELECT ARRAY_AGG(c13) as array_agg1 FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 1)

0 commit comments

Comments
 (0)