Skip to content

Commit

Permalink
pow int float
Browse files Browse the repository at this point in the history
  • Loading branch information
CanglongCl committed Apr 6, 2024
1 parent d01d82c commit d4f95d7
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 85 deletions.
13 changes: 13 additions & 0 deletions crates/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,19 @@ macro_rules! with_match_physical_integer_type {(
}
})}

#[macro_export]
macro_rules! with_match_physical_float_type {(
$dtype:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use $crate::datatypes::DataType::*;
match $dtype {
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
dt => panic!("not implemented for dtype {:?}", dt),
}
})}

#[macro_export]
macro_rules! with_match_physical_float_polars_type {(
$key_type:expr, | $_:tt $T:ident | $($body:tt)*
Expand Down
107 changes: 48 additions & 59 deletions crates/polars-plan/src/dsl/function_expr/pow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use arrow::legacy::kernels::pow::pow as pow_kernel;
use num::pow::Pow;
use polars_core::export::num;
use polars_core::export::num::{Float, ToPrimitive};
use polars_core::{with_match_physical_float_type, with_match_physical_integer_type, with_match_physical_float_polars_type};

use super::*;

Expand Down Expand Up @@ -128,65 +129,53 @@ where

fn pow_on_series(base: &Series, exponent: &Series) -> PolarsResult<Option<Series>> {
use DataType::*;
match (base.dtype(), exponent.dtype()) {
#[cfg(feature = "dtype-u8")]
(UInt8, UInt8 | UInt16 | UInt32 | UInt64) => {
let ca = base.u8().unwrap();
let exponent = exponent.strict_cast(&DataType::UInt32)?;
pow_to_uint_dtype(ca, exponent.u32().unwrap())
},
#[cfg(feature = "dtype-i8")]
(Int8, UInt8 | UInt16 | UInt32 | UInt64) => {
let ca = base.i8().unwrap();
let exponent = exponent.strict_cast(&DataType::UInt32)?;
pow_to_uint_dtype(ca, exponent.u32().unwrap())
},
#[cfg(feature = "dtype-u16")]
(UInt16, UInt8 | UInt16 | UInt32 | UInt64) => {
let ca = base.u16().unwrap();
let exponent = exponent.strict_cast(&DataType::UInt32)?;
pow_to_uint_dtype(ca, exponent.u32().unwrap())
},
#[cfg(feature = "dtype-i16")]
(Int16, UInt8 | UInt16 | UInt32 | UInt64) => {
let ca = base.i16().unwrap();
let exponent = exponent.strict_cast(&DataType::UInt32)?;
pow_to_uint_dtype(ca, exponent.u32().unwrap())
},
(UInt32, UInt8 | UInt16 | UInt32 | UInt64) => {
let ca = base.u32().unwrap();
let exponent = exponent.strict_cast(&DataType::UInt32)?;
pow_to_uint_dtype(ca, exponent.u32().unwrap())
},
(Int32, UInt8 | UInt16 | UInt32 | UInt64) => {
let ca = base.i32().unwrap();
let exponent = exponent.strict_cast(&DataType::UInt32)?;
pow_to_uint_dtype(ca, exponent.u32().unwrap())
},
(UInt64, UInt8 | UInt16 | UInt32 | UInt64) => {
let ca = base.u64().unwrap();
let exponent = exponent.strict_cast(&DataType::UInt32)?;
pow_to_uint_dtype(ca, exponent.u32().unwrap())
},
(Int64, UInt8 | UInt16 | UInt32 | UInt64) => {
let ca = base.i64().unwrap();
let exponent = exponent.strict_cast(&DataType::UInt32)?;
pow_to_uint_dtype(ca, exponent.u32().unwrap())
},
(Float32, _) => {
let ca = base.f32().unwrap();
let exponent = exponent.strict_cast(&DataType::Float32)?;
pow_on_floats(ca, exponent.f32().unwrap())
},
(Float64, _) => {
let ca = base.f64().unwrap();
let exponent = exponent.strict_cast(&DataType::Float64)?;
pow_on_floats(ca, exponent.f64().unwrap())
},
_ => {
let base = base.cast(&DataType::Float64)?;
pow_on_series(&base, exponent)
},

let base_dtype = base.dtype();
polars_ensure!(
base_dtype.is_numeric(),
InvalidOperation: "`pow` operation not supported for dtype `{}` as base", base_dtype
);
let expoent_dtype = exponent.dtype();
polars_ensure!(
expoent_dtype.is_numeric(),
InvalidOperation: "`pow` operation not supported for dtype `{}` as exponent", expoent_dtype
);

// if false, dtype is float
if base_dtype.is_integer() {
with_match_physical_integer_type!(base_dtype, |$native_type| {
if expoent_dtype.is_float() {
match expoent_dtype {
Float32 => {
let ca = base.cast(&DataType::Float32)?;
pow_on_floats(ca.f32().unwrap(), exponent.f32().unwrap())
},
Float64 => {
let ca = base.cast(&DataType::Float64)?;
pow_on_floats(ca.f64().unwrap(), exponent.f64().unwrap())
},
_ => unreachable!(),
}
} else {
let ca = base.$native_type().unwrap();
let exponent = exponent.strict_cast(&DataType::UInt32)?;
pow_to_uint_dtype(ca, exponent.u32().unwrap())
}
})
} else {
match base_dtype {
Float32 => {
let ca = base.f32().unwrap();
let exponent = exponent.strict_cast(&DataType::Float32)?;
pow_on_floats(ca, exponent.f32().unwrap())
},
Float64 => {
let ca = base.f64().unwrap();
let exponent = exponent.strict_cast(&DataType::Float64)?;
pow_on_floats(ca, exponent.f64().unwrap())
},
_ => unreachable!(),
}
}
}

Expand Down
18 changes: 10 additions & 8 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,14 +466,16 @@ impl<'a> FieldsMapper<'a> {
}

pub(super) fn pow_dtype(&self) -> PolarsResult<Field> {
// base, exponent
match (self.fields[0].data_type(), self.fields[1].data_type()) {
(
base_dtype,
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64,
) => Ok(Field::new(self.fields[0].name(), base_dtype.clone())),
(DataType::Float32, _) => Ok(Field::new(self.fields[0].name(), DataType::Float32)),
(_, _) => Ok(Field::new(self.fields[0].name(), DataType::Float64)),
let base_dtype = self.fields[0].data_type();
let expoent_dtype = self.fields[1].data_type();
if base_dtype.is_integer() {
if expoent_dtype.is_float() {
Ok(Field::new(self.fields[0].name(), expoent_dtype.clone()))
} else {
Ok(Field::new(self.fields[0].name(), base_dtype.clone()))
}
} else {
Ok(Field::new(self.fields[0].name(), base_dtype.clone()))
}
}

Expand Down
6 changes: 0 additions & 6 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,9 +1167,6 @@ def __pow__(self, exponent: int | float | Series) -> Series:
return self.pow(exponent)

def __rpow__(self, other: Any) -> Series:
if self.dtype.is_temporal():
msg = "first cast to integer before raising datelike dtypes to a power"
raise TypeError(msg)
return self.to_frame().select_seq(other ** F.col(self.name)).to_series()

def __matmul__(self, other: Any) -> float | Series | None:
Expand Down Expand Up @@ -1965,9 +1962,6 @@ def pow(self, exponent: int | float | Series) -> Series:
64.0
]
"""
if self.dtype.is_temporal():
msg = "first cast to integer before raising datelike dtypes to a power"
raise TypeError(msg)
if _check_for_numpy(exponent) and isinstance(exponent, np.ndarray):
exponent = Series(exponent)
return self.to_frame().select_seq(F.col(self.name).pow(exponent)).to_series()
Expand Down
40 changes: 29 additions & 11 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,9 @@ def test_arithmetic_datetime() -> None:
a * 2
with pytest.raises(TypeError):
a % 2
with pytest.raises(TypeError):
with pytest.raises(
pl.InvalidOperationError,
):
a**2
with pytest.raises(TypeError):
2 / a
Expand All @@ -457,7 +459,9 @@ def test_arithmetic_datetime() -> None:
2 * a
with pytest.raises(TypeError):
2 % a
with pytest.raises(TypeError):
with pytest.raises(
pl.InvalidOperationError,
):
2**a


Expand All @@ -476,12 +480,11 @@ def test_power() -> None:
m = pl.Series([2**33, 2**33], dtype=UInt64)

# pow
assert_series_equal(a**2, pl.Series([1.0, 4.0], dtype=Float64))
assert_series_equal(a**2, pl.Series([1, 4], dtype=Int64))
assert_series_equal(b**3, pl.Series([None, 8.0], dtype=Float64))
assert_series_equal(a**a, pl.Series([1.0, 4.0], dtype=Float64))
assert_series_equal(a**a, pl.Series([1, 4], dtype=Int64))
assert_series_equal(b**b, pl.Series([None, 4.0], dtype=Float64))
assert_series_equal(a**b, pl.Series([None, 4.0], dtype=Float64))
assert_series_equal(a**None, pl.Series([None] * len(a), dtype=Float64)) # type: ignore[operator]
assert_series_equal(d**d, pl.Series([1, 4], dtype=UInt8))
assert_series_equal(e**d, pl.Series([1, 4], dtype=Int8))
assert_series_equal(f**d, pl.Series([1, 4], dtype=UInt16))
Expand All @@ -490,8 +493,24 @@ def test_power() -> None:
assert_series_equal(i**d, pl.Series([1, 4], dtype=Int32))
assert_series_equal(j**d, pl.Series([1, 4], dtype=UInt64))
assert_series_equal(k**d, pl.Series([1, 4], dtype=Int64))
with pytest.raises(TypeError):

with pytest.raises(
pl.InvalidOperationError,
match="`pow` operation not supported for dtype `null` as exponent"
):
a**None

with pytest.raises(
pl.InvalidOperationError,
match="`pow` operation not supported for dtype `date` as base"
):
c**2
with pytest.raises(
pl.InvalidOperationError,
match="`pow` operation not supported for dtype `date` as exponent"
):
2**c

with pytest.raises(pl.ColumnNotFoundError):
a ** "hi" # type: ignore[operator]

Expand All @@ -504,13 +523,12 @@ def test_power() -> None:
# rpow
assert_series_equal(2.0**a, pl.Series("literal", [2.0, 4.0], dtype=Float64))
assert_series_equal(2**b, pl.Series("literal", [None, 4.0], dtype=Float64))
with pytest.raises(TypeError):
2**c

with pytest.raises(pl.ColumnNotFoundError):
"hi" ** a

# Series.pow() method
assert_series_equal(a.pow(2), pl.Series([1.0, 4.0], dtype=Float64))
assert_series_equal(a.pow(2), pl.Series([1, 4], dtype=Int64))


def test_add_string() -> None:
Expand Down Expand Up @@ -1986,13 +2004,13 @@ def test_cumulative_eval() -> None:
expr2 = pl.element().last() ** 2

expected1 = pl.Series("values", [1, 1, 1, 1, 1])
expected2 = pl.Series("values", [1.0, 4.0, 9.0, 16.0, 25.0])
expected2 = pl.Series("values", [1, 4, 9, 16, 25])
assert_series_equal(s.cumulative_eval(expr1), expected1)
assert_series_equal(s.cumulative_eval(expr2), expected2)

# evaluate combined expressions and validate
expr3 = expr1 - expr2
expected3 = pl.Series("values", [0.0, -3.0, -8.0, -15.0, -24.0])
expected3 = pl.Series("values", [0, -3, -8, -15, -24])
assert_series_equal(s.cumulative_eval(expr3), expected3)


Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/sql/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_sql_expr() -> None:
)
result = df.select(*sql_exprs)
expected = pl.DataFrame(
{"a": [1, 1, 1], "aa": [1.0, 4.0, 27.0], "b2": ["yz", "bc", None]}
{"a": [1, 1, 1], "aa": [1, 4, 27], "b2": ["yz", "bc", None]}
)
assert_frame_equal(result, expected)

Expand Down

0 comments on commit d4f95d7

Please sign in to comment.