3 changes: 3 additions & 0 deletions datafusion/src/physical_plan/expressions/binary.rs
@@ -269,6 +269,9 @@ macro_rules! binary_array_op_scalar {
DataType::Date32 => {
compute_op_scalar!($LEFT, $RIGHT, $OP, Date32Array)
}
DataType::Date64 => {
compute_op_scalar!($LEFT, $RIGHT, $OP, Date64Array)
}
other => Err(DataFusionError::Internal(format!(
"Data type {:?} not supported for scalar operation on dyn array",
other
3 changes: 3 additions & 0 deletions datafusion/src/scalar.rs
@@ -900,6 +900,7 @@ impl TryFrom<ScalarValue> for i64 {
fn try_from(value: ScalarValue) -> Result<Self> {
match value {
ScalarValue::Int64(Some(inner_value))
| ScalarValue::Date64(Some(inner_value))
Author's review comment: This mirrors the clause above for Date32 in TryFrom<ScalarValue> for i32.
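For context, a minimal sketch of the Date32 clause being mirrored, as it presumably appears in `TryFrom<ScalarValue> for i32`; the exact arm layout and error message here are assumptions, not copied from scalar.rs:

```rust
// Sketch only (inside datafusion/src/scalar.rs, where ScalarValue,
// DataFusionError and the crate's Result alias are in scope).
// Date32 stores days since the Unix epoch as an i32, so it can be
// unwrapped alongside Int32, mirroring how the Date64 arm above
// reuses the i64 path (milliseconds since the epoch).
impl TryFrom<ScalarValue> for i32 {
    type Error = DataFusionError;

    fn try_from(value: ScalarValue) -> Result<Self> {
        match value {
            ScalarValue::Int32(Some(inner_value))
            | ScalarValue::Date32(Some(inner_value)) => Ok(inner_value),
            // hypothetical fallback; the real error text may differ
            other => Err(DataFusionError::Internal(format!(
                "Cannot convert {:?} to i32",
                other
            ))),
        }
    }
}
```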

| ScalarValue::TimestampNanosecond(Some(inner_value))
| ScalarValue::TimestampMicrosecond(Some(inner_value))
| ScalarValue::TimestampMillisecond(Some(inner_value))
@@ -939,6 +940,8 @@ impl TryFrom<&DataType> for ScalarValue {
DataType::UInt64 => ScalarValue::UInt64(None),
DataType::Utf8 => ScalarValue::Utf8(None),
DataType::LargeUtf8 => ScalarValue::LargeUtf8(None),
DataType::Date32 => ScalarValue::Date32(None),
DataType::Date64 => ScalarValue::Date64(None),
DataType::Timestamp(TimeUnit::Second, _) => {
ScalarValue::TimestampSecond(None)
}
210 changes: 187 additions & 23 deletions datafusion/tests/parquet_pruning.rs
@@ -21,25 +21,28 @@ use std::sync::Arc;

use arrow::{
array::{
Array, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray,
Array, Date32Array, Date64Array, StringArray, TimestampMicrosecondArray,
TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
},
datatypes::{Field, Schema},
record_batch::RecordBatch,
util::pretty::pretty_format_batches,
};
use chrono::Duration;
use chrono::{Datelike, Duration};
use datafusion::{
datasource::{parquet::ParquetTable, TableProvider},
logical_plan::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder},
physical_plan::{plan_metrics, SQLMetric},
prelude::ExecutionContext,
scalar::ScalarValue,
};
use hashbrown::HashMap;
use parquet::{arrow::ArrowWriter, file::properties::WriterProperties};
use tempfile::NamedTempFile;

#[tokio::test]
async fn prune_timestamps_nanos() {
let output = ContextWithParquet::new()
let output = ContextWithParquet::new(Scenario::Timestamps)
.await
.query("SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')")
.await;
@@ -52,7 +55,7 @@ async fn prune_timestamps_nanos() {

#[tokio::test]
async fn prune_timestamps_micros() {
let output = ContextWithParquet::new()
let output = ContextWithParquet::new(Scenario::Timestamps)
.await
.query(
"SELECT * FROM t where micros < to_timestamp_micros('2020-01-02 01:01:11Z')",
@@ -67,7 +70,7 @@ async fn prune_timestamps_micros() {

#[tokio::test]
async fn prune_timestamps_millis() {
let output = ContextWithParquet::new()
let output = ContextWithParquet::new(Scenario::Timestamps)
.await
.query(
"SELECT * FROM t where millis < to_timestamp_millis('2020-01-02 01:01:11Z')",
@@ -82,7 +85,7 @@ async fn prune_timestamps_millis() {

#[tokio::test]
async fn prune_timestamps_seconds() {
let output = ContextWithParquet::new()
let output = ContextWithParquet::new(Scenario::Timestamps)
.await
.query(
"SELECT * FROM t where seconds < to_timestamp_seconds('2020-01-02 01:01:11Z')",
@@ -95,15 +98,60 @@ async fn prune_timestamps_seconds() {
assert_eq!(output.result_rows, 10, "{}", output.description());
}

#[tokio::test]
async fn prune_date32() {
let output = ContextWithParquet::new(Scenario::Dates)
.await
.query("SELECT * FROM t where date32 < cast('2020-01-02' as date)")
.await;
println!("{}", output.description());
// This should prune out groups without error
assert_eq!(output.predicate_evaluation_errors(), Some(0));
assert_eq!(output.row_groups_pruned(), Some(3));
assert_eq!(output.result_rows, 1, "{}", output.description());
}

#[tokio::test]
async fn prune_date64() {
// workaround for not being able to cast Date32 to Date64 automatically
let date = "2020-01-02"
.parse::<chrono::NaiveDate>()
.unwrap()
.and_time(chrono::NaiveTime::from_hms(0, 0, 0));
let date = ScalarValue::Date64(Some(date.timestamp_millis()));

let output = ContextWithParquet::new(Scenario::Dates)
.await
.query_with_expr(col("date64").lt(lit(date)))
// .query(
// "SELECT * FROM t where date64 < caste('2020-01-02' as date)",
// query results in Plan("'Date64 < Date32' can't be evaluated because there isn't a common type to coerce the types to")
// )
.await;

println!("{}", output.description());
// This should prune out groups without error
assert_eq!(output.predicate_evaluation_errors(), Some(0));
assert_eq!(output.row_groups_pruned(), Some(3));
assert_eq!(output.result_rows, 1, "{}", output.description());
}

// ----------------------
// Begin test fixture
// ----------------------

/// What data to use
enum Scenario {
Timestamps,
Dates,
}

/// Test fixture that has an execution context that has an external
/// table "t" registered, pointing at a parquet file made with
/// `make_test_file`
struct ContextWithParquet {
file: NamedTempFile,
provider: Arc<dyn TableProvider>,
ctx: ExecutionContext,
}

@@ -156,24 +204,54 @@ impl TestOutput {

/// Creates an execution context that has an external table "t"
/// registered pointing at a parquet file made with `make_test_file`
/// and the appropriate scenario
impl ContextWithParquet {
async fn new() -> Self {
let file = make_test_file().await;
async fn new(scenario: Scenario) -> Self {
let file = make_test_file(scenario).await;

// now, set up the file as a data source and run a query against it
let mut ctx = ExecutionContext::new();
let parquet_path = file.path().to_string_lossy();
ctx.register_parquet("t", &parquet_path)
.expect("registering");

Self { file, ctx }
let table = ParquetTable::try_new(parquet_path, 4).unwrap();

let provider = Arc::new(table);
ctx.register_table("t", provider.clone()).unwrap();

Self {
file,
provider,
ctx,
}
}

/// runs a query like "SELECT * from t WHERE <expr>" and returns
/// the number of output rows and normalized execution metrics
async fn query_with_expr(&mut self, expr: Expr) -> TestOutput {
let sql = format!("EXPR only: {:?}", expr);
let logical_plan = LogicalPlanBuilder::scan("t", self.provider.clone(), None)
.unwrap()
.filter(expr)
.unwrap()
.build()
.unwrap();
self.run_test(logical_plan, sql).await
}

/// Runs the specified SQL query and returns the number of output
/// rows and normalized execution metrics
async fn query(&mut self, sql: &str) -> TestOutput {
println!("Planning sql {}", sql);
let logical_plan = self.ctx.sql(sql).expect("planning").to_logical_plan();
self.run_test(logical_plan, sql).await
}

/// runs the logical plan
async fn run_test(
&mut self,
logical_plan: LogicalPlan,
sql: impl Into<String>,
) -> TestOutput {
let input = self
.ctx
.sql("SELECT * from t")
@@ -183,8 +261,6 @@ impl ContextWithParquet {
.expect("getting input");
let pretty_input = pretty_format_batches(&input).unwrap();

let logical_plan = self.ctx.sql(sql).expect("planning").to_logical_plan();

let logical_plan = self.ctx.optimize(&logical_plan).expect("optimizing plan");
let execution_plan = self
.ctx
@@ -210,7 +286,7 @@ impl ContextWithParquet {

let pretty_results = pretty_format_batches(&results).unwrap();

let sql = sql.to_string();
let sql = sql.into();
TestOutput {
sql,
metrics,
@@ -222,7 +298,7 @@ impl ContextWithParquet {
}

/// Create a test parquet file with various data types
async fn make_test_file() -> NamedTempFile {
async fn make_test_file(scenario: Scenario) -> NamedTempFile {
let output_file = tempfile::Builder::new()
.prefix("parquet_pruning")
.suffix(".parquet")
@@ -233,12 +309,25 @@ async fn make_test_file() -> NamedTempFile {
.set_max_row_group_size(5)
.build();

let batches = vec![
make_batch(Duration::seconds(0)),
make_batch(Duration::seconds(10)),
make_batch(Duration::minutes(10)),
make_batch(Duration::days(10)),
];
let batches = match scenario {
Scenario::Timestamps => {
vec![
make_timestamp_batch(Duration::seconds(0)),
make_timestamp_batch(Duration::seconds(10)),
make_timestamp_batch(Duration::minutes(10)),
make_timestamp_batch(Duration::days(10)),
]
}
Scenario::Dates => {
vec![
make_date_batch(Duration::days(0)),
make_date_batch(Duration::days(10)),
make_date_batch(Duration::days(300)),
make_date_batch(Duration::days(3600)),
]
}
};

let schema = batches[0].schema();

let mut writer = ArrowWriter::try_new(
@@ -268,7 +357,7 @@ async fn make_test_file() -> NamedTempFile {
/// "millis" --> TimestampMillisecondArray
/// "seconds" --> TimestampSecondArray
/// "names" --> StringArray
pub fn make_batch(offset: Duration) -> RecordBatch {
fn make_timestamp_batch(offset: Duration) -> RecordBatch {
let ts_strings = vec![
Some("2020-01-01T01:01:01.0000000000001"),
Some("2020-01-01T01:02:01.0000000000001"),
@@ -341,3 +430,78 @@ pub fn make_batch(offset: Duration) -> RecordBatch {
)
.unwrap()
}

/// Return record batch with a few rows of data for all of the supported date
/// types with the specified offset (in days)
///
/// Columns are named:
/// "date32" --> Date32Array
/// "date64" --> Date64Array
/// "names" --> StringArray
fn make_date_batch(offset: Duration) -> RecordBatch {
let date_strings = vec![
Some("2020-01-01"),
Some("2020-01-02"),
Some("2020-01-03"),
None,
Some("2020-01-04"),
];

let names = date_strings
.iter()
.enumerate()
.map(|(i, val)| format!("Row {} + {}: {:?}", i, offset, val))
.collect::<Vec<_>>();

// Copied from `cast.rs` cast kernel due to lack of temporal kernels
// https://github.com/apache/arrow-rs/issues/527
const EPOCH_DAYS_FROM_CE: i32 = 719_163;
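// For example, 2020-01-01 is 18_262 days after 1970-01-01, so
// num_days_from_ce() yields 719_163 + 18_262 = 737_425 for it, and the
// subtraction below recovers the Date32 value 18_262.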

let date_seconds = date_strings
.iter()
.map(|t| {
t.map(|t| {
let t = t.parse::<chrono::NaiveDate>().unwrap();
let t = t + offset;
t.num_days_from_ce() - EPOCH_DAYS_FROM_CE
})
})
.collect::<Vec<_>>();

let date_millis = date_strings
.into_iter()
.map(|t| {
t.map(|t| {
let t = t
.parse::<chrono::NaiveDate>()
.unwrap()
.and_time(chrono::NaiveTime::from_hms(0, 0, 0));
let t = t + offset;
t.timestamp_millis()
})
})
.collect::<Vec<_>>();

let arr_date32 = Date32Array::from(date_seconds);
let arr_date64 = Date64Array::from(date_millis);

let names = names.iter().map(|s| s.as_str()).collect::<Vec<_>>();
let arr_names = StringArray::from(names);

let schema = Schema::new(vec![
Field::new("date32", arr_date32.data_type().clone(), true),
Field::new("date64", arr_date64.data_type().clone(), true),
Field::new("name", arr_names.data_type().clone(), true),
]);
let schema = Arc::new(schema);

RecordBatch::try_new(
schema,
vec![
Arc::new(arr_date32),
Arc::new(arr_date64),
Arc::new(arr_names),
],
)
.unwrap()
}