Skip to content

Commit

Permalink
fix: Ensure same fmt in Series/AnyValue to string cast (#18982)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Sep 27, 2024
1 parent 6abc2f1 commit fa7ec47
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 8 deletions.
2 changes: 1 addition & 1 deletion crates/polars-arrow/src/compute/cast/primitive_to.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::offset::{Offset, Offsets};
use crate::temporal_conversions::*;
use crate::types::{days_ms, f16, months_days_ns, NativeType};

- pub(super) trait SerPrimitive {
+ pub trait SerPrimitive {
fn write(f: &mut Vec<u8>, val: Self) -> usize
where
Self: Sized;
Expand Down
17 changes: 10 additions & 7 deletions crates/polars-core/src/datatypes/any_value.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::borrow::Cow;

+ use arrow::compute::cast::SerPrimitive;
use arrow::types::PrimitiveType;
- use polars_utils::format_pl_smallstr;
#[cfg(feature = "dtype-categorical")]
use polars_utils::sync::SyncPtr;
use polars_utils::total_ord::ToTotalOrd;
Expand Down Expand Up @@ -563,19 +563,22 @@ impl<'a> AnyValue<'a> {
(AnyValue::Float64(v), DataType::Boolean) => AnyValue::Boolean(*v != f64::default()),

// to string
- (AnyValue::String(v), DataType::String) => {
- AnyValue::StringOwned(PlSmallStr::from_str(v))
- },
+ (AnyValue::String(v), DataType::String) => AnyValue::String(v),
(AnyValue::StringOwned(v), DataType::String) => AnyValue::StringOwned(v.clone()),

(av, DataType::String) => {
+ let mut tmp = vec![];
if av.is_unsigned_integer() {
- AnyValue::StringOwned(format_pl_smallstr!("{}", av.extract::<u64>()?))
+ let val = av.extract::<u64>()?;
+ SerPrimitive::write(&mut tmp, val);
} else if av.is_float() {
- AnyValue::StringOwned(format_pl_smallstr!("{}", av.extract::<f64>()?))
+ let val = av.extract::<f64>()?;
+ SerPrimitive::write(&mut tmp, val);
} else {
- AnyValue::StringOwned(format_pl_smallstr!("{}", av.extract::<i64>()?))
+ let val = av.extract::<i64>()?;
+ SerPrimitive::write(&mut tmp, val);
}
+ AnyValue::StringOwned(PlSmallStr::from_str(std::str::from_utf8(&tmp).unwrap()))
},

// to binary
Expand Down
6 changes: 6 additions & 0 deletions py-polars/tests/unit/operations/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,3 +686,9 @@ def test_bool_numeric_supertype(dtype: PolarsDataType) -> None:
df = pl.DataFrame({"v": [1, 2, 3, 4, 5, 6]})
result = df.select((pl.col("v") < 3).sum().cast(dtype) / pl.len())
assert result.item() - 0.3333333 <= 0.00001


def test_cast_consistency() -> None:
    """Check that a float column and a float literal cast to String identically.

    Column ``b`` casts the existing column ``a``; column ``c`` casts a fresh
    ``0.0`` literal. Both are expected to render as ``"0.0"``.
    """
    base = pl.DataFrame().with_columns(a=pl.lit(0.0))
    casted = base.with_columns(
        b=pl.col("a").cast(pl.String),
        c=pl.lit(0.0).cast(pl.String),
    )
    expected = {"a": [0.0], "b": ["0.0"], "c": ["0.0"]}
    assert casted.to_dict(as_series=False) == expected

0 comments on commit fa7ec47

Please sign in to comment.