Skip to content

Commit

Permalink
feat(python): Support Decimal inputs for lit (pola-rs#16950)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored and Wouittone committed Jun 22, 2024
1 parent f4d585c commit 09d1f37
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 1 deletion.
4 changes: 4 additions & 0 deletions crates/polars-expr/src/expressions/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ impl PhysicalExpr for LiteralExpr {
UInt64(v) => UInt64Chunked::full(LITERAL_NAME, *v, 1).into_series(),
Float32(v) => Float32Chunked::full(LITERAL_NAME, *v, 1).into_series(),
Float64(v) => Float64Chunked::full(LITERAL_NAME, *v, 1).into_series(),
#[cfg(feature = "dtype-decimal")]
Decimal(v, scale) => Int128Chunked::full(LITERAL_NAME, *v, 1)
.into_decimal_unchecked(None, *scale)
.into_series(),
Boolean(v) => BooleanChunked::full(LITERAL_NAME, *v, 1).into_series(),
Null => polars_core::prelude::Series::new_null(LITERAL_NAME, 1),
Range {
Expand Down
9 changes: 9 additions & 0 deletions crates/polars-plan/src/logical_plan/lit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ pub enum LiteralValue {
Float32(f32),
/// A 64-bit floating point number.
Float64(f64),
/// A 128-bit decimal number with a maximum scale of 38.
#[cfg(feature = "dtype-decimal")]
Decimal(i128, usize),
Range {
low: i64,
high: i64,
Expand Down Expand Up @@ -121,6 +124,8 @@ impl LiteralValue {
Int64(v) => AnyValue::Int64(*v),
Float32(v) => AnyValue::Float32(*v),
Float64(v) => AnyValue::Float64(*v),
#[cfg(feature = "dtype-decimal")]
Decimal(v, scale) => AnyValue::Decimal(*v, *scale),
String(v) => AnyValue::String(v),
#[cfg(feature = "dtype-duration")]
Duration(v, tu) => AnyValue::Duration(*v, *tu),
Expand Down Expand Up @@ -192,6 +197,8 @@ impl LiteralValue {
LiteralValue::Int64(_) => DataType::Int64,
LiteralValue::Float32(_) => DataType::Float32,
LiteralValue::Float64(_) => DataType::Float64,
#[cfg(feature = "dtype-decimal")]
LiteralValue::Decimal(_, scale) => DataType::Decimal(None, Some(*scale)),
LiteralValue::String(_) => DataType::String,
LiteralValue::Binary(_) => DataType::Binary,
LiteralValue::Range { data_type, .. } => data_type.clone(),
Expand Down Expand Up @@ -276,6 +283,8 @@ impl TryFrom<AnyValue<'_>> for LiteralValue {
AnyValue::Int64(i) => Ok(Self::Int64(i)),
AnyValue::Float32(f) => Ok(Self::Float32(f)),
AnyValue::Float64(f) => Ok(Self::Float64(f)),
#[cfg(feature = "dtype-decimal")]
AnyValue::Decimal(v, scale) => Ok(Self::Decimal(v, scale)),
#[cfg(feature = "dtype-date")]
AnyValue::Date(v) => Ok(LiteralValue::Date(v)),
#[cfg(feature = "dtype-datetime")]
Expand Down
4 changes: 4 additions & 0 deletions py-polars/src/functions/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyBytes, PyFloat, PyInt, PyString};

use crate::conversion::any_value::py_object_to_any_value;
use crate::conversion::{get_lf, Wrap};
use crate::expr::ToExprs;
use crate::map::lazy::binary_lambda;
Expand Down Expand Up @@ -428,6 +429,9 @@ pub fn lit(value: &Bound<'_, PyAny>, allow_object: bool) -> PyResult<PyExpr> {
Ok(dsl::lit(Null {}).into())
} else if let Ok(value) = value.downcast::<PyBytes>() {
Ok(dsl::lit(value.as_bytes()).into())
} else if value.get_type().qualname().unwrap() == "Decimal" {
let av = py_object_to_any_value(value, true)?;
Ok(Expr::Literal(LiteralValue::try_from(av).unwrap()).into())
} else if allow_object {
let s = Python::with_gil(|py| {
PySeries::new_object(py, "", vec![ObjectValue::from(value.into_py(py))], false).series
Expand Down
2 changes: 1 addition & 1 deletion py-polars/src/lazyframe/visitor/expr_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
},
Binary(_) => return Err(PyNotImplementedError::new_err("binary literal")),
Range { .. } => return Err(PyNotImplementedError::new_err("range literal")),
Date(..) | DateTime(..) => Literal {
Date(..) | DateTime(..) | Decimal(..) => Literal {
value: Wrap(lit.to_any_value().unwrap()).to_object(py),
dtype,
},
Expand Down
26 changes: 26 additions & 0 deletions py-polars/tests/unit/functions/test_lit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import enum
from datetime import datetime, timedelta
from decimal import Decimal
from typing import Any

import numpy as np
Expand All @@ -10,6 +11,7 @@

import polars as pl
from polars.testing import assert_frame_equal
from polars.testing.parametric.strategies import series
from polars.testing.parametric.strategies.data import datetimes


Expand Down Expand Up @@ -155,3 +157,27 @@ def test_datetime_ms(value: datetime) -> None:
result = pl.select(pl.lit(value, dtype=pl.Datetime("ms")))["literal"][0]
expected_microsecond = value.microsecond // 1000 * 1000
assert result == value.replace(microsecond=expected_microsecond)


def test_lit_decimal() -> None:
value = Decimal("0.1")

expr = pl.lit(value)
df = pl.select(expr)
result = df.item()

assert df.dtypes[0] == pl.Decimal(None, 1)
assert result == value


@given(s=series(min_size=1, max_size=1, allow_null=False, allowed_dtypes=pl.Decimal))
def test_lit_decimal_parametric(s: pl.Series) -> None:
scale = s.dtype.scale # type: ignore[attr-defined]
value = s.item()

expr = pl.lit(value)
df = pl.select(expr)
result = df.item()

assert df.dtypes[0] == pl.Decimal(None, scale)
assert result == value

0 comments on commit 09d1f37

Please sign in to comment.