Skip to content

Commit 7fd6a49

Browse files
author
Jiayu Liu
committed
use i64
1 parent 46d892a commit 7fd6a49

File tree

2 files changed

+30
-13
lines changed

2 files changed

+30
-13
lines changed

datafusion/src/physical_plan/expressions/nth_value.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ impl BuiltInWindowFunctionExpr for FirstValue {
7272
}
7373

7474
// sql values start with 1, so we can use 0 to indicate the special last value behavior
75-
const SPECIAL_USIZE_VALUE_FOR_LAST: usize = 0;
75+
const SPECIAL_SIZE_VALUE_FOR_LAST: u32 = 0;
7676

7777
/// last_value expression
7878
#[derive(Debug)]
@@ -114,7 +114,7 @@ impl BuiltInWindowFunctionExpr for LastValue {
114114

115115
fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
116116
Ok(Box::new(NthValueAccumulator::try_new(
117-
SPECIAL_USIZE_VALUE_FOR_LAST,
117+
SPECIAL_SIZE_VALUE_FOR_LAST,
118118
self.data_type.clone(),
119119
)?))
120120
}
@@ -124,7 +124,7 @@ impl BuiltInWindowFunctionExpr for LastValue {
124124
#[derive(Debug)]
125125
pub struct NthValue {
126126
name: String,
127-
n: usize,
127+
n: u32,
128128
data_type: DataType,
129129
expr: Arc<dyn PhysicalExpr>,
130130
}
@@ -134,10 +134,10 @@ impl NthValue {
134134
pub fn try_new(
135135
expr: Arc<dyn PhysicalExpr>,
136136
name: String,
137-
n: usize,
137+
n: u32,
138138
data_type: DataType,
139139
) -> Result<Self> {
140-
if n == SPECIAL_USIZE_VALUE_FOR_LAST {
140+
if n == SPECIAL_SIZE_VALUE_FOR_LAST {
141141
Err(DataFusionError::Execution(
142142
"nth_value expect n to be > 0".to_owned(),
143143
))
@@ -184,14 +184,14 @@ struct NthValueAccumulator {
184184
// n the target nth_value, however we'll reuse it for last_value acc, so when n == 0 it specifically
185185
// means last; also note that it is totally valid for n to be larger than the number of rows input
186186
// in which case all the values shall be null
187-
n: usize,
188-
offset: usize,
187+
n: u32,
188+
offset: u32,
189189
value: ScalarValue,
190190
}
191191

192192
impl NthValueAccumulator {
193193
/// new count accumulator
194-
pub fn try_new(n: usize, data_type: DataType) -> Result<Self> {
194+
pub fn try_new(n: u32, data_type: DataType) -> Result<Self> {
195195
Ok(Self {
196196
n,
197197
offset: 0,
@@ -203,7 +203,7 @@ impl NthValueAccumulator {
203203

204204
impl WindowAccumulator for NthValueAccumulator {
205205
fn scan(&mut self, values: &[ScalarValue]) -> Result<Option<ScalarValue>> {
206-
if self.n == SPECIAL_USIZE_VALUE_FOR_LAST {
206+
if self.n == SPECIAL_SIZE_VALUE_FOR_LAST {
207207
// for last_value function
208208
self.value = values[0].clone();
209209
} else if self.offset < self.n {
@@ -212,10 +212,10 @@ impl WindowAccumulator for NthValueAccumulator {
212212
self.value = values[0].clone();
213213
}
214214
}
215-
Ok(Some(self.value.clone()))
215+
Ok(None)
216216
}
217217

218218
fn evaluate(&self) -> Result<Option<ScalarValue>> {
219-
Ok(None)
219+
Ok(Some(self.value.clone()))
220220
}
221221
}

datafusion/src/physical_plan/windows.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
//! Execution plan for window functions
1919
2020
use crate::error::{DataFusionError, Result};
21-
use crate::physical_plan::type_coercion::coerce;
2221
use crate::physical_plan::{
2322
aggregates,
24-
expressions::{FirstValue, LastValue, RowNumber},
23+
expressions::{FirstValue, LastValue, Literal, NthValue, RowNumber},
24+
type_coercion::coerce,
2525
window_functions::{signature_for_built_in, BuiltInWindowFunction, WindowFunction},
2626
Accumulator, AggregateExpr, BuiltInWindowFunctionExpr, Distribution, ExecutionPlan,
2727
Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream,
@@ -40,6 +40,7 @@ use futures::stream::{Stream, StreamExt};
4040
use futures::Future;
4141
use pin_project_lite::pin_project;
4242
use std::any::Any;
43+
use std::convert::TryInto;
4344
use std::iter;
4445
use std::pin::Pin;
4546
use std::sync::Arc;
@@ -89,6 +90,22 @@ fn create_built_in_window_expr(
8990
) -> Result<Arc<dyn BuiltInWindowFunctionExpr>> {
9091
match fun {
9192
BuiltInWindowFunction::RowNumber => Ok(Arc::new(RowNumber::new(name))),
93+
BuiltInWindowFunction::NthValue => {
94+
let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
95+
let arg = coerced_args[0].clone();
96+
let n = coerced_args[1]
97+
.as_any()
98+
.downcast_ref::<Literal>()
99+
.unwrap()
100+
.value();
101+
let n: i64 = n
102+
.clone()
103+
.try_into()
104+
.map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
105+
let n: u32 = n as u32;
106+
let data_type = args[0].data_type(input_schema)?;
107+
Ok(Arc::new(NthValue::try_new(arg, name, n, data_type)?))
108+
}
92109
BuiltInWindowFunction::FirstValue => {
93110
let arg =
94111
coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone();

0 commit comments

Comments
 (0)