Skip to content

Commit 82ab9df

Browse files
authored
Merge branch 'main' into fast_sort_with_inlined_fast_key
2 parents d6650a7 + 9bb309c commit 82ab9df

File tree

10 files changed

+162
-82
lines changed

10 files changed

+162
-82
lines changed

datafusion/common/src/scalar/mod.rs

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ use std::mem::{size_of, size_of_val};
3232
use std::str::FromStr;
3333
use std::sync::Arc;
3434

35-
use crate::arrow_datafusion_err;
3635
use crate::cast::{
3736
as_decimal128_array, as_decimal256_array, as_dictionary_array,
3837
as_fixed_size_binary_array, as_fixed_size_list_array,
@@ -41,6 +40,7 @@ use crate::error::{DataFusionError, Result, _exec_err, _internal_err, _not_impl_
4140
use crate::format::DEFAULT_CAST_OPTIONS;
4241
use crate::hash_utils::create_hashes;
4342
use crate::utils::SingleRowListArrayBuilder;
43+
use crate::{_internal_datafusion_err, arrow_datafusion_err};
4444
use arrow::array::{
4545
types::{IntervalDayTime, IntervalMonthDayNano},
4646
*,
@@ -1849,10 +1849,6 @@ impl ScalarValue {
18491849
/// Returns an error if the iterator is empty or if the
18501850
/// [`ScalarValue`]s are not all the same type
18511851
///
1852-
/// # Panics
1853-
///
1854-
/// Panics if `self` is a dictionary with invalid key type
1855-
///
18561852
/// # Example
18571853
/// ```
18581854
/// use datafusion_common::ScalarValue;
@@ -3343,6 +3339,16 @@ impl ScalarValue {
33433339
arr1 == &right
33443340
}
33453341

3342+
/// Compare `self` with `other` and return an `Ordering`.
3343+
///
3344+
/// This is the same as [`PartialOrd`] except that it returns
3345+
/// `Err` if the values cannot be compared, e.g., they have incompatible data types.
3346+
pub fn try_cmp(&self, other: &Self) -> Result<Ordering> {
3347+
self.partial_cmp(other).ok_or_else(|| {
3348+
_internal_datafusion_err!("Uncomparable values: {self:?}, {other:?}")
3349+
})
3350+
}
3351+
33463352
/// Estimate size if bytes including `Self`. For values with internal containers such as `String`
33473353
/// includes the allocated size (`capacity`) rather than the current length (`len`)
33483354
pub fn size(&self) -> usize {
@@ -4761,6 +4767,32 @@ mod tests {
47614767
Ok(())
47624768
}
47634769

4770+
#[test]
4771+
fn test_try_cmp() {
4772+
assert_eq!(
4773+
ScalarValue::try_cmp(
4774+
&ScalarValue::Int32(Some(1)),
4775+
&ScalarValue::Int32(Some(2))
4776+
)
4777+
.unwrap(),
4778+
Ordering::Less
4779+
);
4780+
assert_eq!(
4781+
ScalarValue::try_cmp(&ScalarValue::Int32(None), &ScalarValue::Int32(Some(2)))
4782+
.unwrap(),
4783+
Ordering::Less
4784+
);
4785+
assert_starts_with(
4786+
ScalarValue::try_cmp(
4787+
&ScalarValue::Int32(Some(1)),
4788+
&ScalarValue::Int64(Some(2)),
4789+
)
4790+
.unwrap_err()
4791+
.message(),
4792+
"Uncomparable values: Int32(1), Int64(2)",
4793+
);
4794+
}
4795+
47644796
#[test]
47654797
fn scalar_decimal_test() -> Result<()> {
47664798
let decimal_value = ScalarValue::Decimal128(Some(123), 10, 1);
@@ -7669,4 +7701,15 @@ mod tests {
76697701
];
76707702
assert!(scalars.iter().all(|s| s.is_null()));
76717703
}
7704+
7705+
// `err.to_string()` depends on backtrace being present (may have backtrace appended)
7706+
// `err.strip_backtrace()` also depends on backtrace being present (may have "This was likely caused by ..." stripped)
7707+
fn assert_starts_with(actual: impl AsRef<str>, expected_prefix: impl AsRef<str>) {
7708+
let actual = actual.as_ref();
7709+
let expected_prefix = expected_prefix.as_ref();
7710+
assert!(
7711+
actual.starts_with(expected_prefix),
7712+
"Expected '{actual}' to start with '{expected_prefix}'"
7713+
);
7714+
}
76727715
}

datafusion/common/src/utils/mod.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ pub mod memory;
2222
pub mod proxy;
2323
pub mod string_utils;
2424

25-
use crate::error::{_exec_datafusion_err, _internal_datafusion_err, _internal_err};
25+
use crate::error::{_exec_datafusion_err, _internal_err};
2626
use crate::{DataFusionError, Result, ScalarValue};
2727
use arrow::array::{
2828
cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray,
@@ -120,14 +120,13 @@ pub fn compare_rows(
120120
let result = match (lhs.is_null(), rhs.is_null(), sort_options.nulls_first) {
121121
(true, false, false) | (false, true, true) => Ordering::Greater,
122122
(true, false, true) | (false, true, false) => Ordering::Less,
123-
(false, false, _) => if sort_options.descending {
124-
rhs.partial_cmp(lhs)
125-
} else {
126-
lhs.partial_cmp(rhs)
123+
(false, false, _) => {
124+
if sort_options.descending {
125+
rhs.try_cmp(lhs)?
126+
} else {
127+
lhs.try_cmp(rhs)?
128+
}
127129
}
128-
.ok_or_else(|| {
129-
_internal_datafusion_err!("Column array shouldn't be empty")
130-
})?,
131130
(true, true, _) => continue,
132131
};
133132
if result != Ordering::Equal {

datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs

Lines changed: 7 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ use datafusion_execution::memory_pool::{
3131
};
3232
use datafusion_expr::display_schema;
3333
use datafusion_physical_plan::spill::get_record_batch_memory_size;
34-
use itertools::Itertools;
3534
use std::time::Duration;
3635

3736
use datafusion_execution::{memory_pool::FairSpillPool, runtime_env::RuntimeEnvBuilder};
@@ -73,43 +72,6 @@ async fn sort_query_fuzzer_runner() {
7372
fuzzer.run().await.unwrap();
7473
}
7574

76-
/// Reproduce the bug with specific seeds from the
77-
/// [failing test case](https://github.com/apache/datafusion/issues/16452).
78-
#[tokio::test(flavor = "multi_thread")]
79-
async fn test_reproduce_sort_query_issue_16452() {
80-
// Seeds from the failing test case
81-
let init_seed = 10313160656544581998u64;
82-
let query_seed = 15004039071976572201u64;
83-
let config_seed_1 = 11807432710583113300u64;
84-
let config_seed_2 = 759937414670321802u64;
85-
86-
let random_seed = 1u64; // Use a fixed seed to ensure consistent behavior
87-
88-
let mut test_generator = SortFuzzerTestGenerator::new(
89-
2000,
90-
3,
91-
"sort_fuzz_table".to_string(),
92-
get_supported_types_columns(random_seed),
93-
false,
94-
random_seed,
95-
);
96-
97-
let mut results = vec![];
98-
99-
for config_seed in [config_seed_1, config_seed_2] {
100-
let r = test_generator
101-
.fuzzer_run(init_seed, query_seed, config_seed)
102-
.await
103-
.unwrap();
104-
105-
results.push(r);
106-
}
107-
108-
for (lhs, rhs) in results.iter().tuple_windows() {
109-
check_equality_of_batches(lhs, rhs).unwrap();
110-
}
111-
}
112-
11375
/// SortQueryFuzzer holds the runner configuration for executing sort query fuzz tests. The fuzzing details are managed inside `SortFuzzerTestGenerator`.
11476
///
11577
/// It defines:
@@ -466,7 +428,7 @@ impl SortFuzzerTestGenerator {
466428
.collect();
467429

468430
let mut order_by_clauses = Vec::new();
469-
for col in selected_columns {
431+
for col in &selected_columns {
470432
let mut clause = col.name.clone();
471433
if rng.random_bool(0.5) {
472434
let order = if rng.random_bool(0.5) { "ASC" } else { "DESC" };
@@ -501,7 +463,12 @@ impl SortFuzzerTestGenerator {
501463
let limit_clause = limit.map_or(String::new(), |l| format!(" LIMIT {l}"));
502464

503465
let query = format!(
504-
"SELECT * FROM {} ORDER BY {}{}",
466+
"SELECT {} FROM {} ORDER BY {}{}",
467+
selected_columns
468+
.iter()
469+
.map(|col| col.name.clone())
470+
.collect::<Vec<_>>()
471+
.join(", "),
505472
self.table_name,
506473
order_by_clauses.join(", "),
507474
limit_clause

datafusion/core/tests/sql/aggregates.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
5252

5353
// workaround lack of Ord of ScalarValue
5454
let cmp = |a: &ScalarValue, b: &ScalarValue| {
55-
a.partial_cmp(b).expect("Can compare ScalarValues")
55+
a.try_cmp(b).expect("Can compare ScalarValues")
5656
};
5757
scalars.sort_by(cmp);
5858
assert_eq!(

datafusion/functions-aggregate-common/src/min_max.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -291,10 +291,9 @@ fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result<ScalarV
291291
extreme = current;
292292
continue;
293293
}
294-
if let Some(cmp) = extreme.partial_cmp(&current) {
295-
if cmp == ordering {
296-
extreme = current;
297-
}
294+
let cmp = extreme.try_cmp(&current)?;
295+
if cmp == ordering {
296+
extreme = current;
298297
}
299298
}
300299

datafusion/functions-aggregate/src/array_agg.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,7 @@ impl Accumulator for DistinctArrayAggAccumulator {
461461
}
462462

463463
if let Some(opts) = self.sort_options {
464+
let mut delayed_cmp_err = Ok(());
464465
values.sort_by(|a, b| {
465466
if a.is_null() {
466467
return match opts.nulls_first {
@@ -475,10 +476,15 @@ impl Accumulator for DistinctArrayAggAccumulator {
475476
};
476477
}
477478
match opts.descending {
478-
true => b.partial_cmp(a).unwrap_or(Ordering::Equal),
479-
false => a.partial_cmp(b).unwrap_or(Ordering::Equal),
479+
true => b.try_cmp(a),
480+
false => a.try_cmp(b),
480481
}
482+
.unwrap_or_else(|err| {
483+
delayed_cmp_err = Err(err);
484+
Ordering::Equal
485+
})
481486
});
487+
delayed_cmp_err?;
482488
};
483489

484490
let arr = ScalarValue::new_list(&values, &self.datatype, true);

datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -204,17 +204,16 @@ fn find_most_restrictive_predicate(
204204

205205
if let Some(scalar) = scalar_value {
206206
if let Some(current_best) = best_value {
207-
if let Some(comparison) = scalar.partial_cmp(current_best) {
208-
let is_better = if find_greater {
209-
comparison == std::cmp::Ordering::Greater
210-
} else {
211-
comparison == std::cmp::Ordering::Less
212-
};
213-
214-
if is_better {
215-
best_value = Some(scalar);
216-
most_restrictive_idx = idx;
217-
}
207+
let comparison = scalar.try_cmp(current_best)?;
208+
let is_better = if find_greater {
209+
comparison == std::cmp::Ordering::Greater
210+
} else {
211+
comparison == std::cmp::Ordering::Less
212+
};
213+
214+
if is_better {
215+
best_value = Some(scalar);
216+
most_restrictive_idx = idx;
218217
}
219218
} else {
220219
best_value = Some(scalar);

datafusion/physical-plan/src/topk/mod.rs

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
//! TopK: Combination of Sort / LIMIT
1919
2020
use arrow::{
21-
array::Array,
22-
compute::interleave_record_batch,
21+
array::{Array, AsArray},
22+
compute::{interleave_record_batch, prep_null_mask_filter, FilterBuilder},
2323
row::{RowConverter, Rows, SortField},
2424
};
2525
use datafusion_expr::{ColumnarValue, Operator};
@@ -203,7 +203,7 @@ impl TopK {
203203
let baseline = self.metrics.baseline.clone();
204204
let _timer = baseline.elapsed_compute().timer();
205205

206-
let sort_keys: Vec<ArrayRef> = self
206+
let mut sort_keys: Vec<ArrayRef> = self
207207
.expr
208208
.iter()
209209
.map(|expr| {
@@ -212,15 +212,56 @@ impl TopK {
212212
})
213213
.collect::<Result<Vec<_>>>()?;
214214

215+
let mut selected_rows = None;
216+
217+
if let Some(filter) = self.filter.as_ref() {
218+
// If a filter is provided, update it with the new rows
219+
let filter = filter.current()?;
220+
let filtered = filter.evaluate(&batch)?;
221+
let num_rows = batch.num_rows();
222+
let array = filtered.into_array(num_rows)?;
223+
let mut filter = array.as_boolean().clone();
224+
let true_count = filter.true_count();
225+
if true_count == 0 {
226+
// nothing to filter, so no need to update
227+
return Ok(());
228+
}
229+
// only update the keys / rows if the filter does not match all rows
230+
if true_count < num_rows {
231+
// Indices in `set_indices` should be correct if filter contains nulls
232+
// So we prepare the filter here. Note this is also done in the `FilterBuilder`
233+
// so there is no overhead to do this here.
234+
if filter.nulls().is_some() {
235+
filter = prep_null_mask_filter(&filter);
236+
}
237+
238+
let filter_predicate = FilterBuilder::new(&filter);
239+
let filter_predicate = if sort_keys.len() > 1 {
240+
// Optimize filter when it has multiple sort keys
241+
filter_predicate.optimize().build()
242+
} else {
243+
filter_predicate.build()
244+
};
245+
selected_rows = Some(filter);
246+
sort_keys = sort_keys
247+
.iter()
248+
.map(|key| filter_predicate.filter(key).map_err(|x| x.into()))
249+
.collect::<Result<Vec<_>>>()?;
250+
}
251+
};
215252
// reuse existing `Rows` to avoid reallocations
216253
let rows = &mut self.scratch_rows;
217254
rows.clear();
218255
self.row_converter.append(rows, &sort_keys)?;
219256

220257
let mut batch_entry = self.heap.register_batch(batch.clone());
221258

222-
let replacements =
223-
self.find_new_topk_items(0..sort_keys[0].len(), &mut batch_entry);
259+
let replacements = match selected_rows {
260+
Some(filter) => {
261+
self.find_new_topk_items(filter.values().set_indices(), &mut batch_entry)
262+
}
263+
None => self.find_new_topk_items(0..sort_keys[0].len(), &mut batch_entry),
264+
};
224265

225266
if replacements > 0 {
226267
self.metrics.row_replacements.add(replacements);

datafusion/sql/src/expr/mod.rs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ use datafusion_expr::planner::{
2121
};
2222
use sqlparser::ast::{
2323
AccessExpr, BinaryOperator, CastFormat, CastKind, DataType as SQLDataType,
24-
DictionaryField, Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias, MapEntry,
25-
StructField, Subscript, TrimWhereField, Value, ValueWithSpan,
24+
DictionaryField, Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias,
25+
FunctionArguments, MapEntry, StructField, Subscript, TrimWhereField, Value,
26+
ValueWithSpan,
2627
};
2728

2829
use datafusion_common::{
@@ -476,7 +477,21 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
476477
),
477478

478479
SQLExpr::Function(function) => {
479-
self.sql_function_to_expr(function, schema, planner_context)
480+
// workaround for https://github.com/apache/datafusion-sqlparser-rs/issues/1909
481+
if matches!(function.args, FunctionArguments::None)
482+
&& function.name.0.len() > 1
483+
&& function.name.0.iter().all(|part| part.as_ident().is_some())
484+
{
485+
let ids = function
486+
.name
487+
.0
488+
.iter()
489+
.map(|part| part.as_ident().expect("just checked").clone())
490+
.collect();
491+
self.sql_compound_identifier_to_expr(ids, schema, planner_context)
492+
} else {
493+
self.sql_function_to_expr(function, schema, planner_context)
494+
}
480495
}
481496

482497
SQLExpr::Rollup(exprs) => {

0 commit comments

Comments
 (0)