diff --git a/crates/polars-utils/src/idx_vec.rs b/crates/polars-utils/src/idx_vec.rs index c0f7098a207d..13ecfbb89448 100644 --- a/crates/polars-utils/src/idx_vec.rs +++ b/crates/polars-utils/src/idx_vec.rs @@ -99,8 +99,10 @@ impl UnitVec { } } + /// # Panics + /// Panics if `new_cap <= 1` or `new_cap < self.len` fn realloc(&mut self, new_cap: usize) { - assert!(new_cap >= self.len); + assert!(new_cap > 1 && new_cap >= self.len); unsafe { let mut me = std::mem::ManuallyDrop::new(Vec::with_capacity(new_cap)); let buffer = me.as_mut_ptr(); @@ -121,9 +123,17 @@ impl UnitVec { } pub fn with_capacity(capacity: usize) -> Self { - let mut new = Self::new(); - new.reserve(capacity); - new + if capacity <= 1 { + Self::new() + } else { + let mut me = std::mem::ManuallyDrop::new(Vec::with_capacity(capacity)); + let data = me.as_mut_ptr(); + Self { + len: 0, + capacity: NonZeroUsize::new(capacity).unwrap(), + data, + } + } } #[inline] @@ -178,13 +188,13 @@ impl Drop for UnitVec { impl Clone for UnitVec { fn clone(&self) -> Self { unsafe { - let mut me = std::mem::ManuallyDrop::new(Vec::with_capacity(self.len)); - let buffer = me.as_mut_ptr(); - std::ptr::copy(self.data_ptr(), buffer, self.len); - UnitVec { - data: buffer, - len: self.len, - capacity: NonZeroUsize::new(std::cmp::max(self.len, 1)).unwrap(), + if self.capacity.get() == 1 { + Self { ..*self } + } else { + let mut copy = Self::with_capacity(self.len); + std::ptr::copy(self.data_ptr(), copy.data_ptr_mut(), self.len); + copy.len = self.len; + copy } } } @@ -295,11 +305,57 @@ macro_rules! unitvec { ); ($elem:expr) => ( {let mut new = $crate::idx_vec::UnitVec::new(); + let v = $elem; // SAFETY: first element always fits. - unsafe { new.push_unchecked($elem) }; + unsafe { new.push_unchecked(v) }; new} ); ($($x:expr),+ $(,)?) => ( vec![$($x),+].into() ); } + +mod tests { + + #[test] + #[should_panic] + fn test_unitvec_realloc_zero() { + super::UnitVec::::new().realloc(0); + } + + #[test] + #[should_panic] + fn test_unitvec_realloc_one() { + super::UnitVec::::new().realloc(1); + } + + #[test] + #[should_panic] + fn test_untivec_realloc_lt_len() { + super::UnitVec::::from(&[1, 2][..]).realloc(1) + } + + #[test] + fn test_unitvec_clone() { + { + let v = unitvec![1usize]; + assert_eq!(v, v.clone()); + } + + for n in [ + 26903816120209729usize, + 42566276440897687, + 44435161834424652, + 49390731489933083, + 51201454727649242, + 83861672190814841, + 92169290527847622, + 92476373900398436, + 95488551309275459, + 97499984126814549, + ] { + let v = unitvec![n]; + assert_eq!(v, v.clone()); + } + } +}