Skip to content

Commit 5cdb7a2

Browse files
Include full DataType in TopKAggregateStream results
1 parent 677f723 commit 5cdb7a2

File tree

1 file changed

+30
-4
lines changed

1 file changed

+30
-4
lines changed

datafusion/physical-plan/src/aggregates/topk/hash_table.rs

Lines changed: 30 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -99,6 +99,7 @@ where
9999
owned: ArrayRef,
100100
map: TopKHashTable<Option<VAL::Native>>,
101101
rnd: RandomState,
102+
kt: DataType,
102103
}
103104

104105
impl StringHashTable {
@@ -216,12 +217,17 @@ where
216217
Option<<VAL as ArrowPrimitiveType>::Native>: Comparable,
217218
Option<<VAL as ArrowPrimitiveType>::Native>: HashValue,
218219
{
219-
pub fn new(limit: usize) -> Self {
220-
let owned = Arc::new(PrimitiveArray::<VAL>::builder(0).finish());
220+
pub fn new(limit: usize, kt: DataType) -> Self {
221+
let owned = Arc::new(
222+
PrimitiveArray::<VAL>::builder(0)
223+
.with_data_type(kt.clone())
224+
.finish(),
225+
);
221226
Self {
222227
owned,
223228
map: TopKHashTable::new(limit, limit * 10),
224229
rnd: RandomState::default(),
230+
kt,
225231
}
226232
}
227233
}
@@ -249,7 +255,8 @@ where
249255

250256
unsafe fn take_all(&mut self, indexes: Vec<usize>) -> ArrayRef {
251257
let ids = self.map.take_all(indexes);
252-
let mut builder: PrimitiveBuilder<VAL> = PrimitiveArray::builder(ids.len());
258+
let mut builder: PrimitiveBuilder<VAL> =
259+
PrimitiveArray::builder(ids.len()).with_data_type(self.kt.clone());
253260
for id in ids.into_iter() {
254261
match id {
255262
None => builder.append_null(),
@@ -413,7 +420,7 @@ pub fn new_hash_table(
413420
) -> Result<Box<dyn ArrowHashTable + Send>> {
414421
macro_rules! downcast_helper {
415422
($kt:ty, $d:ident) => {
416-
return Ok(Box::new(PrimitiveHashTable::<$kt>::new(limit)))
423+
return Ok(Box::new(PrimitiveHashTable::<$kt>::new(limit, kt)))
417424
};
418425
}
419426

@@ -433,8 +440,27 @@ pub fn new_hash_table(
433440
#[cfg(test)]
434441
mod tests {
435442
use super::*;
443+
use arrow::array::TimestampMillisecondArray;
444+
use arrow_schema::TimeUnit;
436445
use std::collections::BTreeMap;
437446

447+
#[test]
448+
fn should_emit_correct_type() -> Result<()> {
449+
let ids =
450+
TimestampMillisecondArray::from(vec![1000]).with_timezone("UTC".to_string());
451+
let dt = DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into()));
452+
let mut ht = new_hash_table(1, dt.clone())?;
453+
ht.set_batch(Arc::new(ids));
454+
let mut mapper = vec![];
455+
let ids = unsafe {
456+
ht.find_or_insert(0, 0, &mut mapper);
457+
ht.take_all(vec![0])
458+
};
459+
assert_eq!(ids.data_type(), &dt);
460+
461+
Ok(())
462+
}
463+
438464
#[test]
439465
fn should_resize_properly() -> Result<()> {
440466
let mut heap_to_map = BTreeMap::<usize, usize>::new();

0 commit comments

Comments
 (0)