9999 owned : ArrayRef ,
100100 map : TopKHashTable < Option < VAL :: Native > > ,
101101 rnd : RandomState ,
102+ kt : DataType ,
102103}
103104
104105impl StringHashTable {
@@ -216,12 +217,17 @@ where
216217 Option < <VAL as ArrowPrimitiveType >:: Native > : Comparable ,
217218 Option < <VAL as ArrowPrimitiveType >:: Native > : HashValue ,
218219{
219- pub fn new ( limit : usize ) -> Self {
220- let owned = Arc :: new ( PrimitiveArray :: < VAL > :: builder ( 0 ) . finish ( ) ) ;
220+ pub fn new ( limit : usize , kt : DataType ) -> Self {
221+ let owned = Arc :: new (
222+ PrimitiveArray :: < VAL > :: builder ( 0 )
223+ . with_data_type ( kt. clone ( ) )
224+ . finish ( ) ,
225+ ) ;
221226 Self {
222227 owned,
223228 map : TopKHashTable :: new ( limit, limit * 10 ) ,
224229 rnd : RandomState :: default ( ) ,
230+ kt,
225231 }
226232 }
227233}
@@ -249,7 +255,8 @@ where
249255
250256 unsafe fn take_all ( & mut self , indexes : Vec < usize > ) -> ArrayRef {
251257 let ids = self . map . take_all ( indexes) ;
252- let mut builder: PrimitiveBuilder < VAL > = PrimitiveArray :: builder ( ids. len ( ) ) ;
258+ let mut builder: PrimitiveBuilder < VAL > =
259+ PrimitiveArray :: builder ( ids. len ( ) ) . with_data_type ( self . kt . clone ( ) ) ;
253260 for id in ids. into_iter ( ) {
254261 match id {
255262 None => builder. append_null ( ) ,
@@ -413,7 +420,7 @@ pub fn new_hash_table(
413420) -> Result < Box < dyn ArrowHashTable + Send > > {
414421 macro_rules! downcast_helper {
415422 ( $kt: ty, $d: ident) => {
416- return Ok ( Box :: new( PrimitiveHashTable :: <$kt>:: new( limit) ) )
423+ return Ok ( Box :: new( PrimitiveHashTable :: <$kt>:: new( limit, kt ) ) )
417424 } ;
418425 }
419426
@@ -433,8 +440,27 @@ pub fn new_hash_table(
433440#[ cfg( test) ]
434441mod tests {
435442 use super :: * ;
443+ use arrow:: array:: TimestampMillisecondArray ;
444+ use arrow_schema:: TimeUnit ;
436445 use std:: collections:: BTreeMap ;
437446
447+ #[ test]
448+ fn should_emit_correct_type ( ) -> Result < ( ) > {
449+ let ids =
450+ TimestampMillisecondArray :: from ( vec ! [ 1000 ] ) . with_timezone ( "UTC" . to_string ( ) ) ;
451+ let dt = DataType :: Timestamp ( TimeUnit :: Millisecond , Some ( "UTC" . into ( ) ) ) ;
452+ let mut ht = new_hash_table ( 1 , dt. clone ( ) ) ?;
453+ ht. set_batch ( Arc :: new ( ids) ) ;
454+ let mut mapper = vec ! [ ] ;
455+ let ids = unsafe {
456+ ht. find_or_insert ( 0 , 0 , & mut mapper) ;
457+ ht. take_all ( vec ! [ 0 ] )
458+ } ;
459+ assert_eq ! ( ids. data_type( ) , & dt) ;
460+
461+ Ok ( ( ) )
462+ }
463+
438464 #[ test]
439465 fn should_resize_properly ( ) -> Result < ( ) > {
440466 let mut heap_to_map = BTreeMap :: < usize , usize > :: new ( ) ;
0 commit comments