@@ -21,25 +21,28 @@ use std::sync::Arc;
 
 use arrow::{
     array::{
-        Array, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray,
-        TimestampNanosecondArray, TimestampSecondArray,
+        Array, Date32Array, Date64Array, StringArray, TimestampMicrosecondArray,
+        TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
     },
     datatypes::{Field, Schema},
     record_batch::RecordBatch,
     util::pretty::pretty_format_batches,
 };
-use chrono::Duration;
+use chrono::{Datelike, Duration};
 use datafusion::{
+    datasource::{parquet::ParquetTable, TableProvider},
+    logical_plan::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder},
     physical_plan::{plan_metrics, SQLMetric},
     prelude::ExecutionContext,
+    scalar::ScalarValue,
 };
 use hashbrown::HashMap;
 use parquet::{arrow::ArrowWriter, file::properties::WriterProperties};
 use tempfile::NamedTempFile;
 
 #[tokio::test]
 async fn prune_timestamps_nanos() {
-    let output = ContextWithParquet::new()
+    let output = ContextWithParquet::new(Scenario::Timestamps)
         .await
         .query("SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')")
         .await;
@@ -52,7 +55,7 @@ async fn prune_timestamps_nanos() {
 
 #[tokio::test]
 async fn prune_timestamps_micros() {
-    let output = ContextWithParquet::new()
+    let output = ContextWithParquet::new(Scenario::Timestamps)
         .await
         .query(
             "SELECT * FROM t where micros < to_timestamp_micros('2020-01-02 01:01:11Z')",
@@ -67,7 +70,7 @@ async fn prune_timestamps_micros() {
 
 #[tokio::test]
 async fn prune_timestamps_millis() {
-    let output = ContextWithParquet::new()
+    let output = ContextWithParquet::new(Scenario::Timestamps)
        .await
        .query(
            "SELECT * FROM t where millis < to_timestamp_millis('2020-01-02 01:01:11Z')",
@@ -82,7 +85,7 @@ async fn prune_timestamps_millis() {
 
 #[tokio::test]
 async fn prune_timestamps_seconds() {
-    let output = ContextWithParquet::new()
+    let output = ContextWithParquet::new(Scenario::Timestamps)
         .await
         .query(
             "SELECT * FROM t where seconds < to_timestamp_seconds('2020-01-02 01:01:11Z')",
@@ -95,15 +98,60 @@ async fn prune_timestamps_seconds() {
     assert_eq!(output.result_rows, 10, "{}", output.description());
 }
 
+#[tokio::test]
+async fn prune_date32() {
+    let output = ContextWithParquet::new(Scenario::Dates)
+        .await
+        .query("SELECT * FROM t where date32 < cast('2020-01-02' as date)")
+        .await;
+    println!("{}", output.description());
+    // This should prune out groups without error
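+    // (only the batch built with `Duration::days(0)` contains dates before
+    // 2020-01-02, so 3 of the 4 row groups are pruned and a single row remains)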
+    assert_eq!(output.predicate_evaluation_errors(), Some(0));
+    assert_eq!(output.row_groups_pruned(), Some(3));
+    assert_eq!(output.result_rows, 1, "{}", output.description());
+}
+
+#[tokio::test]
+async fn prune_date64() {
+    // workaround for not being able to cast Date32 to Date64 automatically
+    let date = "2020-01-02"
+        .parse::<chrono::NaiveDate>()
+        .unwrap()
+        .and_time(chrono::NaiveTime::from_hms(0, 0, 0));
+    let date = ScalarValue::Date64(Some(date.timestamp_millis()));
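+    // Date64 stores milliseconds since the UNIX epoch, so the scalar is
+    // built from `timestamp_millis()` of midnight on the target date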
+
+    let output = ContextWithParquet::new(Scenario::Dates)
+        .await
+        .query_with_expr(col("date64").lt(lit(date)))
+        // .query(
+        //     "SELECT * FROM t where date64 < cast('2020-01-02' as date)",
+        // )
+        // the query above results in Plan("'Date64 < Date32' can't be evaluated because there isn't a common type to coerce the types to")
+        .await;
+
+    println!("{}", output.description());
+    // This should prune out groups without error
+    assert_eq!(output.predicate_evaluation_errors(), Some(0));
+    assert_eq!(output.row_groups_pruned(), Some(3));
+    assert_eq!(output.result_rows, 1, "{}", output.description());
+}
+
 // ----------------------
 // Begin test fixture
 // ----------------------
 
+/// What data to use
+enum Scenario {
+    Timestamps,
+    Dates,
+}
+
 /// Test fixture that has an execution context that has an external
 /// table "t" registered, pointing at a parquet file made with
 /// `make_test_file`
 struct ContextWithParquet {
     file: NamedTempFile,
+    provider: Arc<dyn TableProvider>,
     ctx: ExecutionContext,
 }
 
@@ -156,24 +204,54 @@ impl TestOutput {
 
 /// Creates an execution context that has an external table "t"
 /// registered pointing at a parquet file made with `make_test_file`
+/// and the appropriate scenario
 impl ContextWithParquet {
-    async fn new() -> Self {
-        let file = make_test_file().await;
+    async fn new(scenario: Scenario) -> Self {
+        let file = make_test_file(scenario).await;
 
         // now, set up the file as a data source and run a query against it
         let mut ctx = ExecutionContext::new();
         let parquet_path = file.path().to_string_lossy();
-        ctx.register_parquet("t", &parquet_path)
-            .expect("registering");
 
-        Self { file, ctx }
+        let table = ParquetTable::try_new(parquet_path, 4).unwrap();
+
+        let provider = Arc::new(table);
+        ctx.register_table("t", provider.clone()).unwrap();
+
+        Self {
+            file,
+            provider,
+            ctx,
+        }
+    }
+
+    /// runs a query like "SELECT * from t WHERE <expr>" and returns
+    /// the number of output rows and normalized execution metrics
+    async fn query_with_expr(&mut self, expr: Expr) -> TestOutput {
+        let sql = format!("EXPR only: {:?}", expr);
+        let logical_plan = LogicalPlanBuilder::scan("t", self.provider.clone(), None)
+            .unwrap()
+            .filter(expr)
+            .unwrap()
+            .build()
+            .unwrap();
+        self.run_test(logical_plan, sql).await
     }
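+
+    // NB: `query_with_expr` skips SQL planning entirely, which is what lets
+    // `prune_date64` above test a predicate the SQL frontend cannot yet coerce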
 
     /// Runs the specified SQL query and returns the number of output
     /// rows and normalized execution metrics
     async fn query(&mut self, sql: &str) -> TestOutput {
         println!("Planning sql {}", sql);
+        let logical_plan = self.ctx.sql(sql).expect("planning").to_logical_plan();
+        self.run_test(logical_plan, sql).await
+    }
 
+    /// runs the logical plan
+    async fn run_test(
+        &mut self,
+        logical_plan: LogicalPlan,
+        sql: impl Into<String>,
+    ) -> TestOutput {
         let input = self
             .ctx
             .sql("SELECT * from t")
@@ -183,8 +261,6 @@ impl ContextWithParquet {
             .expect("getting input");
         let pretty_input = pretty_format_batches(&input).unwrap();
 
-        let logical_plan = self.ctx.sql(sql).expect("planning").to_logical_plan();
-
         let logical_plan = self.ctx.optimize(&logical_plan).expect("optimizing plan");
         let execution_plan = self
             .ctx
@@ -210,7 +286,7 @@ impl ContextWithParquet {
 
         let pretty_results = pretty_format_batches(&results).unwrap();
 
-        let sql = sql.to_string();
+        let sql = sql.into();
         TestOutput {
             sql,
             metrics,
@@ -222,7 +298,7 @@ impl ContextWithParquet {
 }
 
 /// Create a test parquet file with various data types
-async fn make_test_file() -> NamedTempFile {
+async fn make_test_file(scenario: Scenario) -> NamedTempFile {
     let output_file = tempfile::Builder::new()
         .prefix("parquet_pruning")
         .suffix(".parquet")
@@ -233,12 +309,25 @@ async fn make_test_file() -> NamedTempFile {
         .set_max_row_group_size(5)
         .build();
 
-    let batches = vec![
-        make_batch(Duration::seconds(0)),
-        make_batch(Duration::seconds(10)),
-        make_batch(Duration::minutes(10)),
-        make_batch(Duration::days(10)),
-    ];
+    let batches = match scenario {
+        Scenario::Timestamps => {
+            vec![
+                make_timestamp_batch(Duration::seconds(0)),
+                make_timestamp_batch(Duration::seconds(10)),
+                make_timestamp_batch(Duration::minutes(10)),
+                make_timestamp_batch(Duration::days(10)),
+            ]
+        }
+        Scenario::Dates => {
+            vec![
+                make_date_batch(Duration::days(0)),
+                make_date_batch(Duration::days(10)),
+                make_date_batch(Duration::days(300)),
+                make_date_batch(Duration::days(3600)),
+            ]
+        }
+    };
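+
+    // with `max_row_group_size` set to 5 and 5 rows per batch, each batch
+    // above lands in its own row group, giving the pruning tests 4 row
+    // groups to work with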
+
     let schema = batches[0].schema();
 
     let mut writer = ArrowWriter::try_new(
@@ -268,7 +357,7 @@ async fn make_test_file() -> NamedTempFile {
 /// "millis" --> TimestampMillisecondArray
 /// "seconds" --> TimestampSecondArray
 /// "names" --> StringArray
-pub fn make_batch(offset: Duration) -> RecordBatch {
+fn make_timestamp_batch(offset: Duration) -> RecordBatch {
     let ts_strings = vec![
         Some("2020-01-01T01:01:01.0000000000001"),
         Some("2020-01-01T01:02:01.0000000000001"),
@@ -341,3 +430,78 @@ pub fn make_batch(offset: Duration) -> RecordBatch {
     )
     .unwrap()
 }
+
+/// Return record batch with a few rows of data for all of the supported date
+/// types with the specified offset (in days)
+///
+/// Columns are named:
+/// "date32" --> Date32Array
+/// "date64" --> Date64Array
+/// "name" --> StringArray
+fn make_date_batch(offset: Duration) -> RecordBatch {
+    let date_strings = vec![
+        Some("2020-01-01"),
+        Some("2020-01-02"),
+        Some("2020-01-03"),
+        None,
+        Some("2020-01-04"),
+    ];
+
+    let names = date_strings
+        .iter()
+        .enumerate()
+        .map(|(i, val)| format!("Row {} + {}: {:?}", i, offset, val))
+        .collect::<Vec<_>>();
+
+    // Copied from the `cast.rs` cast kernel due to lack of temporal kernels
+    // https://github.com/apache/arrow-rs/issues/527
+    const EPOCH_DAYS_FROM_CE: i32 = 719_163;
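+    // Date32 counts days since the UNIX epoch (1970-01-01), while chrono's
+    // `num_days_from_ce` counts days since 0001-01-01; 1970-01-01 is day
+    // 719_163 CE, hence the subtraction below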
+
+    let date_seconds = date_strings
+        .iter()
+        .map(|t| {
+            t.map(|t| {
+                let t = t.parse::<chrono::NaiveDate>().unwrap();
+                let t = t + offset;
+                t.num_days_from_ce() - EPOCH_DAYS_FROM_CE
+            })
+        })
+        .collect::<Vec<_>>();
+
+    let date_millis = date_strings
+        .into_iter()
+        .map(|t| {
+            t.map(|t| {
+                let t = t
+                    .parse::<chrono::NaiveDate>()
+                    .unwrap()
+                    .and_time(chrono::NaiveTime::from_hms(0, 0, 0));
+                let t = t + offset;
+                t.timestamp_millis()
+            })
+        })
+        .collect::<Vec<_>>();
+
+    let arr_date32 = Date32Array::from(date_seconds);
+    let arr_date64 = Date64Array::from(date_millis);
+
+    let names = names.iter().map(|s| s.as_str()).collect::<Vec<_>>();
+    let arr_names = StringArray::from(names);
+
+    let schema = Schema::new(vec![
+        Field::new("date32", arr_date32.data_type().clone(), true),
+        Field::new("date64", arr_date64.data_type().clone(), true),
+        Field::new("name", arr_names.data_type().clone(), true),
+    ]);
+    let schema = Arc::new(schema);
+
+    RecordBatch::try_new(
+        schema,
+        vec![
+            Arc::new(arr_date32),
+            Arc::new(arr_date64),
+            Arc::new(arr_names),
+        ],
+    )
+    .unwrap()
+}