Skip to content

Commit

Permalink
Push allow_object into AnyValue
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Sep 29, 2024
1 parent 901b243 commit a61f379
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 82 deletions.
74 changes: 44 additions & 30 deletions crates/polars-python/src/conversion/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use polars::prelude::{AnyValue, PlSmallStr, Series, TimeZone};
use polars_core::export::chrono::{NaiveDate, NaiveDateTime, NaiveTime, TimeDelta, Timelike};
use polars_core::utils::any_values_to_supertype_and_n_dtypes;
use polars_core::utils::arrow::temporal_conversions::date32_to_date;
use pyo3::exceptions::{PyOverflowError, PyTypeError};
use pyo3::exceptions::{PyOverflowError, PyTypeError, PyValueError};
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyBytes, PyDict, PyFloat, PyInt, PyList, PySequence, PyString, PyTuple};
Expand All @@ -36,7 +36,7 @@ impl ToPyObject for Wrap<AnyValue<'_>> {

impl<'py> FromPyObject<'py> for Wrap<AnyValue<'py>> {
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
py_object_to_any_value(ob, true).map(Wrap)
py_object_to_any_value(ob, true, true).map(Wrap)
}
}

Expand Down Expand Up @@ -161,6 +161,7 @@ pub(crate) static LUT: crate::gil_once_cell::GILOnceCell<PlHashMap<TypeObjectPtr
pub(crate) fn py_object_to_any_value<'py>(
ob: &Bound<'py, PyAny>,
strict: bool,
allow_object: bool,
) -> PyResult<AnyValue<'py>> {
// Conversion functions.
fn get_null(_ob: &Bound<'_, PyAny>, _strict: bool) -> PyResult<AnyValue<'static>> {
Expand Down Expand Up @@ -328,7 +329,7 @@ pub(crate) fn py_object_to_any_value<'py>(
let mut items = Vec::with_capacity(INFER_SCHEMA_LENGTH);
for item in (&mut iter).take(INFER_SCHEMA_LENGTH) {
items.push(item?);
let av = py_object_to_any_value(items.last().unwrap(), strict)?;
let av = py_object_to_any_value(items.last().unwrap(), strict, true)?;
avs.push(av)
}
let (dtype, n_dtypes) = any_values_to_supertype_and_n_dtypes(&avs)
Expand All @@ -344,7 +345,7 @@ pub(crate) fn py_object_to_any_value<'py>(
let mut rest = Vec::with_capacity(length);
for item in iter {
rest.push(item?);
let av = py_object_to_any_value(rest.last().unwrap(), strict)?;
let av = py_object_to_any_value(rest.last().unwrap(), strict, true)?;
avs.push(av)
}

Expand Down Expand Up @@ -374,7 +375,7 @@ pub(crate) fn py_object_to_any_value<'py>(
let mut vals = Vec::with_capacity(len);
for (k, v) in dict.into_iter() {
let key = k.extract::<Cow<str>>()?;
let val = py_object_to_any_value(&v, strict)?;
let val = py_object_to_any_value(&v, strict, true)?;
let dtype = val.dtype();
keys.push(Field::new(key.as_ref().into(), dtype));
vals.push(val)
Expand All @@ -399,48 +400,51 @@ pub(crate) fn py_object_to_any_value<'py>(
///
/// Note: This function is only ran if the object's type is not already in the
/// lookup table.
fn get_conversion_function(ob: &Bound<'_, PyAny>, py: Python<'_>) -> InitFn {
fn get_conversion_function(
ob: &Bound<'_, PyAny>,
py: Python<'_>,
allow_object: bool,
) -> PyResult<InitFn> {
if ob.is_none() {
get_null
Ok(get_null)
}
// bool must be checked before int because Python bool is an instance of int.
else if ob.is_instance_of::<PyBool>() {
get_bool
Ok(get_bool)
} else if ob.is_instance_of::<PyInt>() {
get_int
Ok(get_int)
} else if ob.is_instance_of::<PyFloat>() {
get_float
Ok(get_float)
} else if ob.is_instance_of::<PyString>() {
get_str
Ok(get_str)
} else if ob.is_instance_of::<PyBytes>() {
get_bytes
Ok(get_bytes)
} else if ob.is_instance_of::<PyList>() || ob.is_instance_of::<PyTuple>() {
get_list
Ok(get_list)
} else if ob.is_instance_of::<PyDict>() {
get_struct
} else if ob.hasattr(intern!(py, "_s")).unwrap() {
get_list_from_series
Ok(get_struct)
} else {
let type_name = ob.get_type().qualname().unwrap();
let ob_type = ob.get_type();
let type_name = ob_type.qualname().unwrap();
match &*type_name {
// Can't use pyo3::types::PyDateTime with abi3-py37 feature,
// so need this workaround instead of `isinstance(ob, datetime)`.
"date" => get_date as InitFn,
"time" => get_time as InitFn,
"datetime" => get_datetime as InitFn,
"timedelta" => get_timedelta as InitFn,
"Decimal" => get_decimal as InitFn,
"range" => get_list as InitFn,
"date" => Ok(get_date as InitFn),
"time" => Ok(get_time as InitFn),
"datetime" => Ok(get_datetime as InitFn),
"timedelta" => Ok(get_timedelta as InitFn),
"Decimal" => Ok(get_decimal as InitFn),
"range" => Ok(get_list as InitFn),
_ => {
// Support NumPy scalars.
if ob.extract::<i64>().is_ok() || ob.extract::<u64>().is_ok() {
return get_int as InitFn;
return Ok(get_int as InitFn);
} else if ob.extract::<f64>().is_ok() {
return get_float as InitFn;
return Ok(get_float as InitFn);
}

// Support custom subclasses of datetime/date.
let ancestors = ob.get_type().getattr(intern!(py, "__mro__")).unwrap();
let ancestors = ob_type.getattr(intern!(py, "__mro__")).unwrap();
let ancestors_str_iter = ancestors
.iter()
.unwrap()
Expand All @@ -449,13 +453,21 @@ pub(crate) fn py_object_to_any_value<'py>(
match &*c {
// datetime must be checked before date because
// Python datetime is an instance of date.
"<class 'datetime.datetime'>" => return get_datetime as InitFn,
"<class 'datetime.date'>" => return get_date as InitFn,
"<class 'datetime.datetime'>" => {
return Ok(get_datetime as InitFn);
},
"<class 'datetime.date'>" => return Ok(get_date as InitFn),
"<class 'datetime.timedelta'>" => return Ok(get_timedelta as InitFn),
"<class 'datetime.time'>" => return Ok(get_time as InitFn),
_ => (),
}
}

get_object as InitFn
if allow_object {
Ok(get_object as InitFn)
} else {
Err(PyValueError::new_err(format!("Cannot convert {ob}")))
}
},
}
}
Expand All @@ -464,10 +476,12 @@ pub(crate) fn py_object_to_any_value<'py>(
let type_object_ptr = ob.get_type().as_type_ptr() as usize;

Python::with_gil(|py| {
let conversion_function = get_conversion_function(ob, py, allow_object)?;

LUT.with_gil(py, |lut| {
let convert_fn = lut
.entry(type_object_ptr)
.or_insert_with(|| get_conversion_function(ob, py));
.or_insert_with(|| conversion_function);
convert_fn(ob, strict)
})
})
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-python/src/dataframe/construction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ fn dicts_to_rows<'a>(
for k in names.iter() {
let val = match d.get_item(k)? {
None => AnyValue::Null,
Some(val) => py_object_to_any_value(&val.as_borrowed(), strict)?,
Some(val) => py_object_to_any_value(&val.as_borrowed(), strict, true)?,
};
row.push(val)
}
Expand Down
39 changes: 14 additions & 25 deletions crates/polars-python/src/functions/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,37 +460,26 @@ pub fn lit(value: &Bound<'_, PyAny>, allow_object: bool, is_scalar: bool) -> PyR
Ok(dsl::lit(Null {}).into())
} else if let Ok(value) = value.downcast::<PyBytes>() {
Ok(dsl::lit(value.as_bytes()).into())
} else if matches!(
value.get_type().qualname().unwrap().as_str(),
"date" | "datetime" | "time" | "timedelta" | "Decimal"
) {
let av = py_object_to_any_value(value, true)?;
Ok(Expr::Literal(LiteralValue::try_from(av).unwrap()).into())
} else {
Python::with_gil(|py| {
// One final attempt before erroring. Do we have a date/datetime subclass?
// E.g. pd.Timestamp, or Freezegun.
let datetime_module = PyModule::import_bound(py, "datetime")?;
let datetime_class = datetime_module.getattr("datetime")?;
let date_class = datetime_module.getattr("date")?;
if value.is_instance(&datetime_class)? || value.is_instance(&date_class)? {
let av = py_object_to_any_value(value, true)?;
Ok(Expr::Literal(LiteralValue::try_from(av).unwrap()).into())
} else if allow_object {
let av = py_object_to_any_value(value, true, allow_object).map_err(|_| {
PyTypeError::new_err(
format!(
"cannot create expression literal for value of type {}.\
\n\nHint: Pass `allow_object=True` to accept any value and create a literal of type Object.",
value.get_type().qualname().unwrap_or("unknown".to_owned()),
)
)
})?;
match av {
AnyValue::ObjectOwned(_) => {
let s = Python::with_gil(|py| {
PySeries::new_object(py, "", vec![ObjectValue::from(value.into_py(py))], false)
.series
});
Ok(dsl::lit(s).into())
} else {
Err(PyTypeError::new_err(format!(
"cannot create expression literal for value of type {}: {}\
\n\nHint: Pass `allow_object=True` to accept any value and create a literal of type Object.",
value.get_type().qualname()?,
value.repr()?
)))
}
})
},
_ => Ok(Expr::Literal(LiteralValue::try_from(av).unwrap()).into()),
}
}
}

Expand Down
17 changes: 12 additions & 5 deletions crates/polars-python/src/series/construction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,19 +168,26 @@ init_method_opt!(new_opt_i64, Int64Type, i64);
init_method_opt!(new_opt_f32, Float32Type, f32);
init_method_opt!(new_opt_f64, Float64Type, f64);

fn convert_to_avs<'a>(values: &'a Bound<'a, PyAny>, strict: bool) -> PyResult<Vec<AnyValue<'a>>> {
fn convert_to_avs<'a>(
values: &'a Bound<'a, PyAny>,
strict: bool,
allow_object: bool,
) -> PyResult<Vec<AnyValue<'a>>> {
values
.iter()?
.map(|v| py_object_to_any_value(&(v?).as_borrowed(), strict))
.map(|v| py_object_to_any_value(&(v?).as_borrowed(), strict, allow_object))
.collect()
}

#[pymethods]
impl PySeries {
#[staticmethod]
fn new_from_any_values(name: &str, values: &Bound<PyAny>, strict: bool) -> PyResult<Self> {
let avs = convert_to_avs(values, strict);
let result = avs.and_then(|avs| {
let any_values_result = values
.iter()?
.map(|v| py_object_to_any_value(&(v?).as_borrowed(), strict, true))
.collect::<PyResult<Vec<AnyValue>>>();
let result = any_values_result.and_then(|avs| {
let s = Series::from_any_values(name.into(), avs.as_slice(), strict).map_err(|e| {
PyTypeError::new_err(format!(
"{e}\n\nHint: Try setting `strict=False` to allow passing data with mixed types."
Expand Down Expand Up @@ -215,7 +222,7 @@ impl PySeries {
dtype: Wrap<DataType>,
strict: bool,
) -> PyResult<Self> {
let avs = convert_to_avs(values, strict)?;
let avs = convert_to_avs(values, strict, false)?;
let s = Series::from_any_values_and_dtype(name.into(), avs.as_slice(), &dtype.0, strict)
.map_err(|e| {
PyTypeError::new_err(format!(
Expand Down
38 changes: 17 additions & 21 deletions py-polars/tests/unit/functions/test_lit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import enum
from datetime import date, datetime, timedelta
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -100,7 +100,7 @@ def test_lit_int_return_type(input: int, dtype: PolarsDataType) -> None:
def test_lit_unsupported_type() -> None:
with pytest.raises(
TypeError,
match="cannot create expression literal for value of type LazyFrame: ",
match="cannot create expression literal for value of type LazyFrame",
):
pl.lit(pl.LazyFrame({"a": [1, 2, 3]}))

Expand Down Expand Up @@ -197,25 +197,21 @@ def test_lit_decimal_parametric(s: pl.Series) -> None:
assert result == value


def test_lit_datetime_subclass_w_allow_object() -> None:
class MyAmazingDate(date):
pass

class MyAmazingDatetime(datetime):
@pytest.mark.parametrize(
("dt_class", "input"),
[
(date, (2024, 1, 1)),
(datetime, (2024, 1, 1)),
(timedelta, (1,)),
(time, (1,)),
],
)
def test_lit_temporal_subclass_w_allow_object(
dt_class: type, input: tuple[int]
) -> None:
class MyClass(dt_class): # type: ignore[misc]
pass

result = pl.select(
a=pl.lit(MyAmazingDatetime(2020, 1, 1)),
b=pl.lit(MyAmazingDate(2020, 1, 1)),
c=pl.lit(MyAmazingDatetime(2020, 1, 1), allow_object=True),
d=pl.lit(MyAmazingDate(2020, 1, 1), allow_object=True),
)
expected = pl.DataFrame(
{
"a": [datetime(2020, 1, 1)],
"b": [date(2020, 1, 1)],
"c": [datetime(2020, 1, 1)],
"d": [date(2020, 1, 1)],
}
)
result = pl.select(a=pl.lit(MyClass(*input)))
expected = pl.DataFrame({"a": [dt_class(*input)]})
assert_frame_equal(result, expected)

0 comments on commit a61f379

Please sign in to comment.