Skip to content

Commit 169701e

Browse files
authored
Optimize date_bin (2x faster) (#10215)
* add date_bin benchmark * optimize date_bin As mentioned in the docs for `PrimaryArray::unary` it is faster to apply an infallible operation across both valid and invalid values, rather than branching at every value. 1) Make stride function infallible 2) Use `unary` method This gives this speedup on my machine: Before: 22.345 µs After: 10.558 µs So around 2x faster
1 parent b9f17b0 commit 169701e

File tree

3 files changed

+71
-10
lines changed

3 files changed

+71
-10
lines changed

datafusion/functions/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,11 @@ harness = false
112112
name = "make_date"
113113
required-features = ["datetime_expressions"]
114114

115+
[[bench]]
116+
harness = false
117+
name = "date_bin"
118+
required-features = ["datetime_expressions"]
119+
115120
[[bench]]
116121
harness = false
117122
name = "to_char"
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
extern crate criterion;
19+
20+
use std::sync::Arc;
21+
22+
use arrow::array::{ArrayRef, TimestampSecondArray};
23+
use criterion::{black_box, criterion_group, criterion_main, Criterion};
24+
use datafusion_common::ScalarValue;
25+
use rand::rngs::ThreadRng;
26+
use rand::Rng;
27+
28+
use datafusion_expr::ColumnarValue;
29+
use datafusion_functions::datetime::date_bin;
30+
31+
fn timestamps(rng: &mut ThreadRng) -> TimestampSecondArray {
32+
let mut seconds = vec![];
33+
for _ in 0..1000 {
34+
seconds.push(rng.gen_range(0..1_000_000));
35+
}
36+
37+
TimestampSecondArray::from(seconds)
38+
}
39+
40+
fn criterion_benchmark(c: &mut Criterion) {
41+
c.bench_function("date_bin_1000", |b| {
42+
let mut rng = rand::thread_rng();
43+
let interval = ColumnarValue::Scalar(ScalarValue::new_interval_dt(0, 1_000_000));
44+
let timestamps = ColumnarValue::Array(Arc::new(timestamps(&mut rng)) as ArrayRef);
45+
let udf = date_bin();
46+
47+
b.iter(|| {
48+
black_box(
49+
udf.invoke(&[interval.clone(), timestamps.clone()])
50+
.expect("date_bin should work on valid values"),
51+
)
52+
})
53+
});
54+
}
55+
56+
criterion_group!(benches, criterion_benchmark);
57+
criterion_main!(benches);

datafusion/functions/src/datetime/date_bin.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -320,46 +320,46 @@ fn date_bin_impl(
320320
origin: i64,
321321
stride: i64,
322322
stride_fn: fn(i64, i64, i64) -> i64,
323-
) -> impl Fn(Option<i64>) -> Option<i64> {
323+
) -> impl Fn(i64) -> i64 {
324324
let scale = match T::UNIT {
325325
Nanosecond => 1,
326326
Microsecond => NANOSECONDS / 1_000_000,
327327
Millisecond => NANOSECONDS / 1_000,
328328
Second => NANOSECONDS,
329329
};
330-
move |x: Option<i64>| x.map(|x| stride_fn(stride, x * scale, origin) / scale)
330+
move |x: i64| stride_fn(stride, x * scale, origin) / scale
331331
}
332332

333333
Ok(match array {
334334
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => {
335335
let apply_stride_fn =
336336
stride_map_fn::<TimestampNanosecondType>(origin, stride, stride_fn);
337337
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(
338-
apply_stride_fn(*v),
338+
v.map(apply_stride_fn),
339339
tz_opt.clone(),
340340
))
341341
}
342342
ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => {
343343
let apply_stride_fn =
344344
stride_map_fn::<TimestampMicrosecondType>(origin, stride, stride_fn);
345345
ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(
346-
apply_stride_fn(*v),
346+
v.map(apply_stride_fn),
347347
tz_opt.clone(),
348348
))
349349
}
350350
ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => {
351351
let apply_stride_fn =
352352
stride_map_fn::<TimestampMillisecondType>(origin, stride, stride_fn);
353353
ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(
354-
apply_stride_fn(*v),
354+
v.map(apply_stride_fn),
355355
tz_opt.clone(),
356356
))
357357
}
358358
ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => {
359359
let apply_stride_fn =
360360
stride_map_fn::<TimestampSecondType>(origin, stride, stride_fn);
361361
ColumnarValue::Scalar(ScalarValue::TimestampSecond(
362-
apply_stride_fn(*v),
362+
v.map(apply_stride_fn),
363363
tz_opt.clone(),
364364
))
365365
}
@@ -377,14 +377,13 @@ fn date_bin_impl(
377377
{
378378
let array = as_primitive_array::<T>(array)?;
379379
let apply_stride_fn = stride_map_fn::<T>(origin, stride, stride_fn);
380-
let array = array
381-
.iter()
382-
.map(apply_stride_fn)
383-
.collect::<PrimitiveArray<T>>()
380+
let array: PrimitiveArray<T> = array
381+
.unary(apply_stride_fn)
384382
.with_timezone_opt(tz_opt.clone());
385383

386384
Ok(ColumnarValue::Array(Arc::new(array)))
387385
}
386+
388387
match array.data_type() {
389388
Timestamp(Nanosecond, tz_opt) => {
390389
transform_array_with_stride::<TimestampNanosecondType>(

0 commit comments

Comments
 (0)