
Commit 7d24567

Fix date32 and date64 parquet row group pruning, tests for same (#690)
1 parent: 63727df

3 files changed: +193 -23 lines


datafusion/src/physical_plan/expressions/binary.rs

Lines changed: 3 additions & 0 deletions
@@ -269,6 +269,9 @@ macro_rules! binary_array_op_scalar {
         DataType::Date32 => {
             compute_op_scalar!($LEFT, $RIGHT, $OP, Date32Array)
         }
+        DataType::Date64 => {
+            compute_op_scalar!($LEFT, $RIGHT, $OP, Date64Array)
+        }
         other => Err(DataFusionError::Internal(format!(
             "Data type {:?} not supported for scalar operation on dyn array",
             other
datafusion/src/scalar.rs

Lines changed: 3 additions & 0 deletions
@@ -900,6 +900,7 @@ impl TryFrom<ScalarValue> for i64 {
     fn try_from(value: ScalarValue) -> Result<Self> {
         match value {
             ScalarValue::Int64(Some(inner_value))
+            | ScalarValue::Date64(Some(inner_value))
             | ScalarValue::TimestampNanosecond(Some(inner_value))
             | ScalarValue::TimestampMicrosecond(Some(inner_value))
             | ScalarValue::TimestampMillisecond(Some(inner_value))
@@ -939,6 +940,8 @@ impl TryFrom<&DataType> for ScalarValue {
             DataType::UInt64 => ScalarValue::UInt64(None),
             DataType::Utf8 => ScalarValue::Utf8(None),
             DataType::LargeUtf8 => ScalarValue::LargeUtf8(None),
+            DataType::Date32 => ScalarValue::Date32(None),
+            DataType::Date64 => ScalarValue::Date64(None),
             DataType::Timestamp(TimeUnit::Second, _) => {
                 ScalarValue::TimestampSecond(None)
             }
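
These two additions give row group pruning the `ScalarValue` plumbing it needs: a `Date64` scalar can be unwrapped to its inner `i64`, and a typed null scalar can be built from the `Date32`/`Date64` data types. A small sketch of how the conversions behave after this change, assuming only the `TryFrom` impls shown in the diff; the literal values and the `main` wrapper are illustrative.

```rust
use std::convert::TryFrom;

use arrow::datatypes::DataType;
use datafusion::error::Result;
use datafusion::scalar::ScalarValue;

fn main() -> Result<()> {
    // A Date64 scalar (milliseconds since the UNIX epoch) can now be
    // unwrapped to its inner i64, just like the timestamp variants.
    let ms = i64::try_from(ScalarValue::Date64(Some(1_577_923_200_000)))?;
    assert_eq!(ms, 1_577_923_200_000);

    // A typed "null" scalar can now be built from the date data types,
    // which pruning needs when a row group statistic is absent.
    let null_date32 = ScalarValue::try_from(&DataType::Date32)?;
    assert_eq!(null_date32, ScalarValue::Date32(None));
    Ok(())
}
```
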

datafusion/tests/parquet_pruning.rs

Lines changed: 187 additions & 23 deletions
@@ -21,25 +21,28 @@ use std::sync::Arc;
 
 use arrow::{
     array::{
-        Array, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray,
-        TimestampNanosecondArray, TimestampSecondArray,
+        Array, Date32Array, Date64Array, StringArray, TimestampMicrosecondArray,
+        TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
     },
     datatypes::{Field, Schema},
     record_batch::RecordBatch,
     util::pretty::pretty_format_batches,
 };
-use chrono::Duration;
+use chrono::{Datelike, Duration};
 use datafusion::{
+    datasource::{parquet::ParquetTable, TableProvider},
+    logical_plan::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder},
     physical_plan::{plan_metrics, SQLMetric},
     prelude::ExecutionContext,
+    scalar::ScalarValue,
 };
 use hashbrown::HashMap;
 use parquet::{arrow::ArrowWriter, file::properties::WriterProperties};
 use tempfile::NamedTempFile;
 
 #[tokio::test]
 async fn prune_timestamps_nanos() {
-    let output = ContextWithParquet::new()
+    let output = ContextWithParquet::new(Scenario::Timestamps)
         .await
         .query("SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')")
         .await;
@@ -52,7 +55,7 @@ async fn prune_timestamps_nanos() {
 
 #[tokio::test]
 async fn prune_timestamps_micros() {
-    let output = ContextWithParquet::new()
+    let output = ContextWithParquet::new(Scenario::Timestamps)
         .await
         .query(
             "SELECT * FROM t where micros < to_timestamp_micros('2020-01-02 01:01:11Z')",
@@ -67,7 +70,7 @@ async fn prune_timestamps_micros() {
 
 #[tokio::test]
 async fn prune_timestamps_millis() {
-    let output = ContextWithParquet::new()
+    let output = ContextWithParquet::new(Scenario::Timestamps)
         .await
         .query(
             "SELECT * FROM t where millis < to_timestamp_millis('2020-01-02 01:01:11Z')",
@@ -82,7 +85,7 @@ async fn prune_timestamps_millis() {
 
 #[tokio::test]
 async fn prune_timestamps_seconds() {
-    let output = ContextWithParquet::new()
+    let output = ContextWithParquet::new(Scenario::Timestamps)
         .await
         .query(
             "SELECT * FROM t where seconds < to_timestamp_seconds('2020-01-02 01:01:11Z')",
@@ -95,15 +98,60 @@ async fn prune_timestamps_seconds() {
     assert_eq!(output.result_rows, 10, "{}", output.description());
 }
 
+#[tokio::test]
+async fn prune_date32() {
+    let output = ContextWithParquet::new(Scenario::Dates)
+        .await
+        .query("SELECT * FROM t where date32 < cast('2020-01-02' as date)")
+        .await;
+    println!("{}", output.description());
+    // This should prune out groups without error
+    assert_eq!(output.predicate_evaluation_errors(), Some(0));
+    assert_eq!(output.row_groups_pruned(), Some(3));
+    assert_eq!(output.result_rows, 1, "{}", output.description());
+}
+
+#[tokio::test]
+async fn prune_date64() {
+    // workaround for not being able to cast Date32 to Date64 automatically
+    let date = "2020-01-02"
+        .parse::<chrono::NaiveDate>()
+        .unwrap()
+        .and_time(chrono::NaiveTime::from_hms(0, 0, 0));
+    let date = ScalarValue::Date64(Some(date.timestamp_millis()));
+
+    let output = ContextWithParquet::new(Scenario::Dates)
+        .await
+        .query_with_expr(col("date64").lt(lit(date)))
+        // .query(
+        //     "SELECT * FROM t where date64 < cast('2020-01-02' as date)",
+        //     query results in Plan("'Date64 < Date32' can't be evaluated because there isn't a common type to coerce the types to")
+        // )
+        .await;
+
+    println!("{}", output.description());
+    // This should prune out groups without error
+    assert_eq!(output.predicate_evaluation_errors(), Some(0));
+    assert_eq!(output.row_groups_pruned(), Some(3));
+    assert_eq!(output.result_rows, 1, "{}", output.description());
+}
+
 // ----------------------
 // Begin test fixture
 // ----------------------
 
+/// What data to use
+enum Scenario {
+    Timestamps,
+    Dates,
+}
+
 /// Test fixture that has an execution context that has an external
 /// table "t" registered, pointing at a parquet file made with
 /// `make_test_file`
 struct ContextWithParquet {
     file: NamedTempFile,
+    provider: Arc<dyn TableProvider>,
     ctx: ExecutionContext,
 }
 
@@ -156,24 +204,54 @@ impl TestOutput {
 
 /// Creates an execution context that has an external table "t"
 /// registered pointing at a parquet file made with `make_test_file`
+/// and the appropriate scenario
 impl ContextWithParquet {
-    async fn new() -> Self {
-        let file = make_test_file().await;
+    async fn new(scenario: Scenario) -> Self {
+        let file = make_test_file(scenario).await;
 
         // now, set up the file as a data source and run a query against it
         let mut ctx = ExecutionContext::new();
         let parquet_path = file.path().to_string_lossy();
-        ctx.register_parquet("t", &parquet_path)
-            .expect("registering");
 
-        Self { file, ctx }
+        let table = ParquetTable::try_new(parquet_path, 4).unwrap();
+
+        let provider = Arc::new(table);
+        ctx.register_table("t", provider.clone()).unwrap();
+
+        Self {
+            file,
+            provider,
+            ctx,
+        }
+    }
+
+    /// runs a query like "SELECT * from t WHERE <expr>" and returns
+    /// the number of output rows and normalized execution metrics
+    async fn query_with_expr(&mut self, expr: Expr) -> TestOutput {
+        let sql = format!("EXPR only: {:?}", expr);
+        let logical_plan = LogicalPlanBuilder::scan("t", self.provider.clone(), None)
+            .unwrap()
+            .filter(expr)
+            .unwrap()
+            .build()
+            .unwrap();
+        self.run_test(logical_plan, sql).await
     }
 
     /// Runs the specified SQL query and returns the number of output
     /// rows and normalized execution metrics
    async fn query(&mut self, sql: &str) -> TestOutput {
         println!("Planning sql {}", sql);
+        let logical_plan = self.ctx.sql(sql).expect("planning").to_logical_plan();
+        self.run_test(logical_plan, sql).await
+    }
 
+    /// runs the logical plan
+    async fn run_test(
+        &mut self,
+        logical_plan: LogicalPlan,
+        sql: impl Into<String>,
+    ) -> TestOutput {
         let input = self
             .ctx
             .sql("SELECT * from t")
@@ -183,8 +261,6 @@ impl ContextWithParquet {
             .expect("getting input");
         let pretty_input = pretty_format_batches(&input).unwrap();
 
-        let logical_plan = self.ctx.sql(sql).expect("planning").to_logical_plan();
-
         let logical_plan = self.ctx.optimize(&logical_plan).expect("optimizing plan");
         let execution_plan = self
             .ctx
@@ -210,7 +286,7 @@ impl ContextWithParquet {
 
         let pretty_results = pretty_format_batches(&results).unwrap();
 
-        let sql = sql.to_string();
+        let sql = sql.into();
         TestOutput {
             sql,
             metrics,
@@ -222,7 +298,7 @@ impl ContextWithParquet {
 }
 
 /// Create a test parquet file with various data types
-async fn make_test_file() -> NamedTempFile {
+async fn make_test_file(scenario: Scenario) -> NamedTempFile {
     let output_file = tempfile::Builder::new()
         .prefix("parquet_pruning")
         .suffix(".parquet")
@@ -233,12 +309,25 @@ async fn make_test_file() -> NamedTempFile {
         .set_max_row_group_size(5)
         .build();
 
-    let batches = vec![
-        make_batch(Duration::seconds(0)),
-        make_batch(Duration::seconds(10)),
-        make_batch(Duration::minutes(10)),
-        make_batch(Duration::days(10)),
-    ];
+    let batches = match scenario {
+        Scenario::Timestamps => {
+            vec![
+                make_timestamp_batch(Duration::seconds(0)),
+                make_timestamp_batch(Duration::seconds(10)),
+                make_timestamp_batch(Duration::minutes(10)),
+                make_timestamp_batch(Duration::days(10)),
+            ]
+        }
+        Scenario::Dates => {
+            vec![
+                make_date_batch(Duration::days(0)),
+                make_date_batch(Duration::days(10)),
+                make_date_batch(Duration::days(300)),
+                make_date_batch(Duration::days(3600)),
+            ]
+        }
+    };
+
     let schema = batches[0].schema();
 
     let mut writer = ArrowWriter::try_new(
@@ -268,7 +357,7 @@ async fn make_test_file() -> NamedTempFile {
 /// "millis" --> TimestampMillisecondArray
 /// "seconds" --> TimestampSecondArray
 /// "names" --> StringArray
-pub fn make_batch(offset: Duration) -> RecordBatch {
+fn make_timestamp_batch(offset: Duration) -> RecordBatch {
     let ts_strings = vec![
         Some("2020-01-01T01:01:01.0000000000001"),
         Some("2020-01-01T01:02:01.0000000000001"),
@@ -341,3 +430,78 @@ pub fn make_batch(offset: Duration) -> RecordBatch {
     )
     .unwrap()
 }
+
+/// Return record batch with a few rows of data for all of the supported date
+/// types with the specified offset (in days)
+///
+/// Columns are named:
+/// "date32" --> Date32Array
+/// "date64" --> Date64Array
+/// "names" --> StringArray
+fn make_date_batch(offset: Duration) -> RecordBatch {
+    let date_strings = vec![
+        Some("2020-01-01"),
+        Some("2020-01-02"),
+        Some("2020-01-03"),
+        None,
+        Some("2020-01-04"),
+    ];
+
+    let names = date_strings
+        .iter()
+        .enumerate()
+        .map(|(i, val)| format!("Row {} + {}: {:?}", i, offset, val))
+        .collect::<Vec<_>>();
+
+    // Copied from `cast.rs` cast kernel due to lack of temporal kernels
+    // https://github.com/apache/arrow-rs/issues/527
+    const EPOCH_DAYS_FROM_CE: i32 = 719_163;
+
+    let date_seconds = date_strings
+        .iter()
+        .map(|t| {
+            t.map(|t| {
+                let t = t.parse::<chrono::NaiveDate>().unwrap();
+                let t = t + offset;
+                t.num_days_from_ce() - EPOCH_DAYS_FROM_CE
+            })
+        })
+        .collect::<Vec<_>>();
+
+    let date_millis = date_strings
+        .into_iter()
+        .map(|t| {
+            t.map(|t| {
+                let t = t
+                    .parse::<chrono::NaiveDate>()
+                    .unwrap()
+                    .and_time(chrono::NaiveTime::from_hms(0, 0, 0));
+                let t = t + offset;
+                t.timestamp_millis()
+            })
+        })
+        .collect::<Vec<_>>();
+
+    let arr_date32 = Date32Array::from(date_seconds);
+    let arr_date64 = Date64Array::from(date_millis);
+
+    let names = names.iter().map(|s| s.as_str()).collect::<Vec<_>>();
+    let arr_names = StringArray::from(names);
+
+    let schema = Schema::new(vec![
+        Field::new("date32", arr_date32.data_type().clone(), true),
+        Field::new("date64", arr_date64.data_type().clone(), true),
+        Field::new("name", arr_names.data_type().clone(), true),
+    ]);
+    let schema = Arc::new(schema);
+
+    RecordBatch::try_new(
+        schema,
+        vec![
+            Arc::new(arr_date32),
+            Arc::new(arr_date64),
+            Arc::new(arr_names),
+        ],
+    )
+    .unwrap()
+}
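
For reference, the fixture above encodes the same calendar date two ways: Date32 as days since the UNIX epoch and Date64 as milliseconds since the UNIX epoch. Below is a small standalone sketch of the conversions used by `make_date_batch` and `prune_date64`; the constant and conversion pattern come from the test code above, while the asserted values and the `main` wrapper are computed here purely for illustration.

```rust
use chrono::{Datelike, NaiveDate, NaiveTime};

fn main() {
    let d = "2020-01-02".parse::<NaiveDate>().unwrap();

    // Date32 stores days since the UNIX epoch; chrono counts days from the
    // Common Era, hence the same offset constant used in make_date_batch.
    const EPOCH_DAYS_FROM_CE: i32 = 719_163;
    let date32_value = d.num_days_from_ce() - EPOCH_DAYS_FROM_CE;
    assert_eq!(date32_value, 18_263);

    // Date64 stores milliseconds since the UNIX epoch; this is how the
    // prune_date64 test builds its ScalarValue::Date64 literal.
    let date64_value = d.and_time(NaiveTime::from_hms(0, 0, 0)).timestamp_millis();
    assert_eq!(date64_value, 1_577_923_200_000);
}
```
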
