Skip to content

Commit 002ca5d

Browse files
authored
Lead/lag window function with offset and default value arguments (#687)
1 parent fd50dd8 commit 002ca5d

File tree

6 files changed

+221
-18
lines changed

6 files changed

+221
-18
lines changed

datafusion/src/physical_plan/expressions/lead_lag.rs

Lines changed: 90 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@
2121
use crate::error::{DataFusionError, Result};
2222
use crate::physical_plan::window_functions::PartitionEvaluator;
2323
use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
24+
use crate::scalar::ScalarValue;
2425
use arrow::array::ArrayRef;
25-
use arrow::compute::kernels::window::shift;
26+
use arrow::compute::cast;
2627
use arrow::datatypes::{DataType, Field};
2728
use arrow::record_batch::RecordBatch;
2829
use std::any::Any;
30+
use std::ops::Neg;
2931
use std::ops::Range;
3032
use std::sync::Arc;
3133

@@ -36,19 +38,23 @@ pub struct WindowShift {
3638
data_type: DataType,
3739
shift_offset: i64,
3840
expr: Arc<dyn PhysicalExpr>,
41+
default_value: Option<ScalarValue>,
3942
}
4043

4144
/// lead() window function
4245
pub fn lead(
4346
name: String,
4447
data_type: DataType,
4548
expr: Arc<dyn PhysicalExpr>,
49+
shift_offset: Option<i64>,
50+
default_value: Option<ScalarValue>,
4651
) -> WindowShift {
4752
WindowShift {
4853
name,
4954
data_type,
50-
shift_offset: -1,
55+
shift_offset: shift_offset.map(|v| v.neg()).unwrap_or(-1),
5156
expr,
57+
default_value,
5258
}
5359
}
5460

@@ -57,12 +63,15 @@ pub fn lag(
5763
name: String,
5864
data_type: DataType,
5965
expr: Arc<dyn PhysicalExpr>,
66+
shift_offset: Option<i64>,
67+
default_value: Option<ScalarValue>,
6068
) -> WindowShift {
6169
WindowShift {
6270
name,
6371
data_type,
64-
shift_offset: 1,
72+
shift_offset: shift_offset.unwrap_or(1),
6573
expr,
74+
default_value,
6675
}
6776
}
6877

@@ -98,20 +107,71 @@ impl BuiltInWindowFunctionExpr for WindowShift {
98107
Ok(Box::new(WindowShiftEvaluator {
99108
shift_offset: self.shift_offset,
100109
values,
110+
default_value: self.default_value.clone(),
101111
}))
102112
}
103113
}
104114

105115
pub(crate) struct WindowShiftEvaluator {
106116
shift_offset: i64,
107117
values: Vec<ArrayRef>,
118+
default_value: Option<ScalarValue>,
119+
}
120+
121+
fn create_empty_array(
122+
value: &Option<ScalarValue>,
123+
data_type: &DataType,
124+
size: usize,
125+
) -> Result<ArrayRef> {
126+
use arrow::array::new_null_array;
127+
let array = value
128+
.as_ref()
129+
.map(|scalar| scalar.to_array_of_size(size))
130+
.unwrap_or_else(|| new_null_array(data_type, size));
131+
if array.data_type() != data_type {
132+
cast(&array, data_type).map_err(DataFusionError::ArrowError)
133+
} else {
134+
Ok(array)
135+
}
136+
}
137+
138+
// TODO: change the original arrow::compute::kernels::window::shift impl to support an optional default value
139+
fn shift_with_default_value(
140+
array: &ArrayRef,
141+
offset: i64,
142+
value: &Option<ScalarValue>,
143+
) -> Result<ArrayRef> {
144+
use arrow::compute::concat;
145+
146+
let value_len = array.len() as i64;
147+
if offset == 0 {
148+
Ok(arrow::array::make_array(array.data_ref().clone()))
149+
} else if offset == i64::MIN || offset.abs() >= value_len {
150+
create_empty_array(value, array.data_type(), array.len())
151+
} else {
152+
let slice_offset = (-offset).clamp(0, value_len) as usize;
153+
let length = array.len() - offset.abs() as usize;
154+
let slice = array.slice(slice_offset, length);
155+
156+
// Generate array with remaining `null` items
157+
let nulls = offset.abs() as usize;
158+
let default_values = create_empty_array(value, slice.data_type(), nulls)?;
159+
// Concatenate both arrays, add nulls after if shift > 0 else before
160+
if offset > 0 {
161+
concat(&[default_values.as_ref(), slice.as_ref()])
162+
.map_err(DataFusionError::ArrowError)
163+
} else {
164+
concat(&[slice.as_ref(), default_values.as_ref()])
165+
.map_err(DataFusionError::ArrowError)
166+
}
167+
}
108168
}
109169

110170
impl PartitionEvaluator for WindowShiftEvaluator {
111171
fn evaluate_partition(&self, partition: Range<usize>) -> Result<ArrayRef> {
112172
let value = &self.values[0];
113173
let value = value.slice(partition.start, partition.end - partition.start);
114-
shift(value.as_ref(), self.shift_offset).map_err(DataFusionError::ArrowError)
174+
shift_with_default_value(&value, self.shift_offset, &self.default_value)
115175
}
116176
}
117177

@@ -142,6 +202,8 @@ mod tests {
142202
"lead".to_owned(),
143203
DataType::Float32,
144204
Arc::new(Column::new("c3", 0)),
205+
None,
206+
None,
145207
),
146208
vec![
147209
Some(-2),
@@ -162,6 +224,8 @@ mod tests {
162224
"lead".to_owned(),
163225
DataType::Float32,
164226
Arc::new(Column::new("c3", 0)),
227+
None,
228+
None,
165229
),
166230
vec![
167231
None,
@@ -176,6 +240,28 @@ mod tests {
176240
.iter()
177241
.collect::<Int32Array>(),
178242
)?;
243+
244+
test_i32_result(
245+
lag(
246+
"lead".to_owned(),
247+
DataType::Int32,
248+
Arc::new(Column::new("c3", 0)),
249+
None,
250+
Some(ScalarValue::Int32(Some(100))),
251+
),
252+
vec![
253+
Some(100),
254+
Some(1),
255+
Some(-2),
256+
Some(3),
257+
Some(-4),
258+
Some(5),
259+
Some(-6),
260+
Some(7),
261+
]
262+
.iter()
263+
.collect::<Int32Array>(),
264+
)?;
179265
Ok(())
180266
}
181267
}

datafusion/src/physical_plan/type_coercion.rs

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,11 @@ fn get_valid_types(
128128
}
129129
vec![(0..*number).map(|i| current_types[i].clone()).collect()]
130130
}
131-
Signature::OneOf(types) => {
132-
let mut r = vec![];
133-
for s in types {
134-
r.extend(get_valid_types(s, current_types)?);
135-
}
136-
r
137-
}
131+
Signature::OneOf(types) => types
132+
.iter()
133+
.filter_map(|t| get_valid_types(t, current_types).ok())
134+
.flatten()
135+
.collect::<Vec<_>>(),
138136
};
139137

140138
Ok(valid_types)
@@ -367,4 +365,27 @@ mod tests {
367365

368366
Ok(())
369367
}
368+
369+
#[test]
370+
fn test_get_valid_types_one_of() -> Result<()> {
371+
let signature = Signature::OneOf(vec![Signature::Any(1), Signature::Any(2)]);
372+
373+
let invalid_types = get_valid_types(
374+
&signature,
375+
&[DataType::Int32, DataType::Int32, DataType::Int32],
376+
)?;
377+
assert_eq!(invalid_types.len(), 0);
378+
379+
let args = vec![DataType::Int32, DataType::Int32];
380+
let valid_types = get_valid_types(&signature, &args)?;
381+
assert_eq!(valid_types.len(), 1);
382+
assert_eq!(valid_types[0], args);
383+
384+
let args = vec![DataType::Int32];
385+
let valid_types = get_valid_types(&signature, &args)?;
386+
assert_eq!(valid_types.len(), 1);
387+
assert_eq!(valid_types[0], args);
388+
389+
Ok(())
390+
}
370391
}

datafusion/src/physical_plan/window_functions.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,16 @@ pub(super) fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature {
201201
| BuiltInWindowFunction::DenseRank
202202
| BuiltInWindowFunction::PercentRank
203203
| BuiltInWindowFunction::CumeDist => Signature::Any(0),
204-
BuiltInWindowFunction::Lag
205-
| BuiltInWindowFunction::Lead
206-
| BuiltInWindowFunction::FirstValue
207-
| BuiltInWindowFunction::LastValue => Signature::Any(1),
204+
BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => {
205+
Signature::OneOf(vec![
206+
Signature::Any(1),
207+
Signature::Any(2),
208+
Signature::Any(3),
209+
])
210+
}
211+
BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => {
212+
Signature::Any(1)
213+
}
208214
BuiltInWindowFunction::Ntile => Signature::Exact(vec![DataType::UInt64]),
209215
BuiltInWindowFunction::NthValue => Signature::Any(2),
210216
}

datafusion/src/physical_plan/windows.rs

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ use crate::physical_plan::{
3232
Accumulator, AggregateExpr, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
3333
RecordBatchStream, SendableRecordBatchStream, WindowExpr,
3434
};
35+
use crate::scalar::ScalarValue;
3536
use arrow::compute::concat;
3637
use arrow::{
3738
array::ArrayRef,
@@ -96,6 +97,19 @@ pub fn create_window_expr(
9697
})
9798
}
9899

100+
fn get_scalar_value_from_args(
101+
args: &[Arc<dyn PhysicalExpr>],
102+
index: usize,
103+
) -> Option<ScalarValue> {
104+
args.get(index).map(|v| {
105+
v.as_any()
106+
.downcast_ref::<Literal>()
107+
.unwrap()
108+
.value()
109+
.clone()
110+
})
111+
}
112+
99113
fn create_built_in_window_expr(
100114
fun: &BuiltInWindowFunction,
101115
args: &[Arc<dyn PhysicalExpr>],
@@ -110,13 +124,21 @@ fn create_built_in_window_expr(
110124
let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
111125
let arg = coerced_args[0].clone();
112126
let data_type = args[0].data_type(input_schema)?;
113-
Arc::new(lag(name, data_type, arg))
127+
let shift_offset = get_scalar_value_from_args(&coerced_args, 1)
128+
.map(|v| v.try_into())
129+
.and_then(|v| v.ok());
130+
let default_value = get_scalar_value_from_args(&coerced_args, 2);
131+
Arc::new(lag(name, data_type, arg, shift_offset, default_value))
114132
}
115133
BuiltInWindowFunction::Lead => {
116134
let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
117135
let arg = coerced_args[0].clone();
118136
let data_type = args[0].data_type(input_schema)?;
119-
Arc::new(lead(name, data_type, arg))
137+
let shift_offset = get_scalar_value_from_args(&coerced_args, 1)
138+
.map(|v| v.try_into())
139+
.and_then(|v| v.ok());
140+
let default_value = get_scalar_value_from_args(&coerced_args, 2);
141+
Arc::new(lead(name, data_type, arg, shift_offset, default_value))
120142
}
121143
BuiltInWindowFunction::NthValue => {
122144
let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
@@ -592,6 +614,47 @@ mod tests {
592614
Ok((input, schema))
593615
}
594616

617+
#[test]
618+
fn test_create_window_exp_lead_no_args() -> Result<()> {
619+
let (_, schema) = create_test_schema(1)?;
620+
621+
let expr = create_window_expr(
622+
&WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead),
623+
"prev".to_owned(),
624+
&[col("c2", &schema)?],
625+
&[],
626+
&[],
627+
Some(WindowFrame::default()),
628+
schema.as_ref(),
629+
)?;
630+
631+
assert_eq!(expr.name(), "prev");
632+
633+
Ok(())
634+
}
635+
636+
#[test]
637+
fn test_create_window_exp_lead_with_args() -> Result<()> {
638+
let (_, schema) = create_test_schema(1)?;
639+
640+
let expr = create_window_expr(
641+
&WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead),
642+
"prev".to_owned(),
643+
&[
644+
col("c2", &schema)?,
645+
Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
646+
],
647+
&[],
648+
&[],
649+
Some(WindowFrame::default()),
650+
schema.as_ref(),
651+
)?;
652+
653+
assert_eq!(expr.name(), "prev");
654+
655+
Ok(())
656+
}
657+
595658
#[tokio::test]
596659
async fn window_function() -> Result<()> {
597660
let (input, schema) = create_test_schema(1)?;
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
-- Licensed to the Apache Software Foundation (ASF) under one
2+
-- or more contributor license agreements. See the NOTICE file
3+
-- distributed with this work for additional information
4+
-- regarding copyright ownership. The ASF licenses this file
5+
-- to you under the Apache License, Version 2.0 (the
6+
-- "License"); you may not use this file except in compliance
7+
-- with the License. You may obtain a copy of the License at
8+
9+
-- http://www.apache.org/licenses/LICENSE-2.0
10+
11+
-- Unless required by applicable law or agreed to in writing, software
12+
-- distributed under the License is distributed on an "AS IS" BASIS,
13+
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
-- See the License for the specific language governing permissions and
15+
-- limitations under the License.
16+
17+
SELECT
18+
c8,
19+
LEAD(c8) OVER () next_c8,
20+
LEAD(c8, 10, 10) OVER() next_10_c8,
21+
LEAD(c8, 100, 10) OVER() next_out_of_bounds_c8,
22+
LAG(c8) OVER() prev_c8,
23+
LAG(c8, -2, 0) OVER() AS prev_2_c8,
24+
LAG(c8, -200, 10) OVER() AS prev_out_of_bounds_c8
25+
26+
FROM test
27+
ORDER BY c8;

integration-tests/test_psql_parity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def generate_csv_from_psql(fname: str):
7777

7878
class TestPsqlParity:
7979
def test_tests_count(self):
80-
assert len(test_files) == 14, "tests are missed"
80+
assert len(test_files) == 15, "tests are missed"
8181

8282
@pytest.mark.parametrize("fname", test_files)
8383
def test_sql_file(self, fname):

0 commit comments

Comments
 (0)