-
Notifications
You must be signed in to change notification settings - Fork 1k
[Variant] Support variant to Decimal32/64/128/256
#8552
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f25b499
7a32191
02d29de
f498db5
964e45a
43d579d
6f39a2a
8f0f53c
d88fd7f
522b26a
9b6d0e1
54237fe
e0b18da
1f19580
c163a91
274a028
a7cdd33
94d60c0
338defe
e1febf6
51648fd
a48bbf4
cb2576c
5ffab93
ef62474
21a83ed
25e4aa9
539d73f
dfe9960
cfc8580
9ed0d7a
0567cb6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -17,8 +17,13 @@ | |||||||||
|
|
||||||||||
| //! Module for transforming a typed arrow `Array` to `VariantArray`. | ||||||||||
|
|
||||||||||
| use arrow::datatypes::{self, ArrowPrimitiveType, ArrowTimestampType, Date32Type}; | ||||||||||
| use parquet_variant::Variant; | ||||||||||
| use arrow::array::ArrowNativeTypeOp; | ||||||||||
| use arrow::compute::DecimalCast; | ||||||||||
| use arrow::datatypes::{ | ||||||||||
| self, ArrowPrimitiveType, ArrowTimestampType, Decimal32Type, Decimal64Type, Decimal128Type, | ||||||||||
| DecimalType, | ||||||||||
| }; | ||||||||||
| use parquet_variant::{Variant, VariantDecimal4, VariantDecimal8, VariantDecimal16}; | ||||||||||
|
|
||||||||||
| /// Options for controlling the behavior of `cast_to_variant_with_options`. | ||||||||||
| #[derive(Debug, Clone, PartialEq, Eq)] | ||||||||||
|
|
@@ -82,7 +87,7 @@ impl_primitive_from_variant!(datatypes::Float64Type, as_f64); | |||||||||
| impl_primitive_from_variant!( | ||||||||||
| datatypes::Date32Type, | ||||||||||
| as_naive_date, | ||||||||||
| Date32Type::from_naive_date | ||||||||||
| datatypes::Date32Type::from_naive_date | ||||||||||
| ); | ||||||||||
| impl_timestamp_from_variant!( | ||||||||||
| datatypes::TimestampMicrosecondType, | ||||||||||
|
|
@@ -109,6 +114,171 @@ impl_timestamp_from_variant!( | |||||||||
| |timestamp| Self::make_value(timestamp.naive_utc()) | ||||||||||
| ); | ||||||||||
|
|
||||||||||
| /// Returns the unscaled integer representation for Arrow decimal type `O` | ||||||||||
| /// from a `Variant`. | ||||||||||
| /// | ||||||||||
| /// - `precision` and `scale` specify the target Arrow decimal parameters | ||||||||||
| /// - Integer variants (`Int8/16/32/64`) are treated as decimals with scale 0 | ||||||||||
| /// - Decimal variants (`Decimal4/8/16`) use their embedded precision and scale | ||||||||||
| /// | ||||||||||
| /// The value is rescaled to (`precision`, `scale`) using `rescale_decimal` and | ||||||||||
| /// returns `None` if it cannot fit the requested precision. | ||||||||||
| pub(crate) fn variant_to_unscaled_decimal<O>( | ||||||||||
| variant: &Variant<'_, '_>, | ||||||||||
| precision: u8, | ||||||||||
| scale: i8, | ||||||||||
| ) -> Option<O::Native> | ||||||||||
| where | ||||||||||
| O: DecimalType, | ||||||||||
| O::Native: DecimalCast, | ||||||||||
| { | ||||||||||
| match variant { | ||||||||||
| Variant::Int8(i) => rescale_decimal::<Decimal32Type, O>( | ||||||||||
| *i as i32, | ||||||||||
| VariantDecimal4::MAX_PRECISION, | ||||||||||
| 0, | ||||||||||
| precision, | ||||||||||
| scale, | ||||||||||
| ), | ||||||||||
| Variant::Int16(i) => rescale_decimal::<Decimal32Type, O>( | ||||||||||
| *i as i32, | ||||||||||
| VariantDecimal4::MAX_PRECISION, | ||||||||||
| 0, | ||||||||||
| precision, | ||||||||||
| scale, | ||||||||||
| ), | ||||||||||
| Variant::Int32(i) => rescale_decimal::<Decimal32Type, O>( | ||||||||||
| *i, | ||||||||||
| VariantDecimal4::MAX_PRECISION, | ||||||||||
| 0, | ||||||||||
| precision, | ||||||||||
| scale, | ||||||||||
| ), | ||||||||||
| Variant::Int64(i) => rescale_decimal::<Decimal64Type, O>( | ||||||||||
| *i, | ||||||||||
| VariantDecimal8::MAX_PRECISION, | ||||||||||
| 0, | ||||||||||
| precision, | ||||||||||
| scale, | ||||||||||
| ), | ||||||||||
| Variant::Decimal4(d) => rescale_decimal::<Decimal32Type, O>( | ||||||||||
| d.integer(), | ||||||||||
| VariantDecimal4::MAX_PRECISION, | ||||||||||
| d.scale() as i8, | ||||||||||
| precision, | ||||||||||
| scale, | ||||||||||
| ), | ||||||||||
| Variant::Decimal8(d) => rescale_decimal::<Decimal64Type, O>( | ||||||||||
| d.integer(), | ||||||||||
| VariantDecimal8::MAX_PRECISION, | ||||||||||
| d.scale() as i8, | ||||||||||
| precision, | ||||||||||
| scale, | ||||||||||
| ), | ||||||||||
| Variant::Decimal16(d) => rescale_decimal::<Decimal128Type, O>( | ||||||||||
| d.integer(), | ||||||||||
| VariantDecimal16::MAX_PRECISION, | ||||||||||
| d.scale() as i8, | ||||||||||
| precision, | ||||||||||
| scale, | ||||||||||
| ), | ||||||||||
| _ => None, | ||||||||||
| } | ||||||||||
| } | ||||||||||
|
|
||||||||||
| /// Rescale a decimal from (input_precision, input_scale) to (output_precision, output_scale) | ||||||||||
| /// and return the scaled value if it fits the output precision. Similar to the implementation in | ||||||||||
| /// decimal.rs in arrow-cast. | ||||||||||
| pub(crate) fn rescale_decimal<I, O>( | ||||||||||
| value: I::Native, | ||||||||||
| input_precision: u8, | ||||||||||
| input_scale: i8, | ||||||||||
| output_precision: u8, | ||||||||||
| output_scale: i8, | ||||||||||
| ) -> Option<O::Native> | ||||||||||
| where | ||||||||||
| I: DecimalType, | ||||||||||
| O: DecimalType, | ||||||||||
| I::Native: DecimalCast, | ||||||||||
| O::Native: DecimalCast, | ||||||||||
| { | ||||||||||
| let delta_scale = output_scale - input_scale; | ||||||||||
|
|
||||||||||
| // Determine if the cast is infallible based on precision/scale math | ||||||||||
| let is_infallible_cast = | ||||||||||
| is_infallible_decimal_cast(input_precision, input_scale, output_precision, output_scale); | ||||||||||
|
|
||||||||||
|
Comment on lines
+206
to
+210
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: move this whole block down to where we actually use it -- declare near first (only) use |
||||||||||
| let scaled = if delta_scale == 0 { | ||||||||||
| O::Native::from_decimal(value) | ||||||||||
| } else if delta_scale > 0 { | ||||||||||
| let mul = O::Native::from_decimal(10_i128) | ||||||||||
| .and_then(|t| t.pow_checked(delta_scale as u32).ok())?; | ||||||||||
|
Comment on lines
+214
to
+215
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could use the same performance optimization as the negative scale case below:
Suggested change
(it didn't matter much in the columnar decimal cast code, but it probably does matter in row-wise variant cast code)
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also -- we should benchmark, but it might be faster to multiply by one than to execute the branch that distinguishes between zero and positive delta scale. If so, we would want code like this: let (scaled, is_infallible_cast) = if delta_scale < 0 {
// ... big comment about why ...
let is_infallible = input_precision + delta_scale < output_precision;
// ... comment about dividing out too many digits ...
let delta_scale = delta_scale.unsigned_abs() as usize;
let Some(max) = ... else { ... return zero ... };
...
(O::Native::from_decimal(adjusted)?, is_infallible_cast)
} else {
// ... big comment explaining why ...
let is_infallible_cast = input_precision + delta_scale <= output_precision;
let max = O::MAX_FOR_EACH_PRECISION.get(delta_scale)?;
let mul = max.add_wrapping(O::Native::ONE);
let x = O::Native::from_decimal(value)?;
(x.mul_checked(mul).ok()?, is_infallible_cast)
}
(is_infallible_cast || O::is_valid_decimal_precision(scaled, output_precision)).then(scaled) |
||||||||||
| O::Native::from_decimal(value).and_then(|x| x.mul_checked(mul).ok()) | ||||||||||
| } else { | ||||||||||
| // delta_scale is guaranteed to be > 0, but may also be larger than I::MAX_PRECISION. If so, the | ||||||||||
| // scale change divides out more digits than the input has precision and the result of the cast | ||||||||||
| // is always zero. For example, if we try to apply delta_scale=10 a decimal32 value, the largest | ||||||||||
| // possible result is 999999999/10000000000 = 0.0999999999, which rounds to zero. Smaller values | ||||||||||
| // (e.g. 1/10000000000) or larger delta_scale (e.g. 999999999/10000000000000) produce even | ||||||||||
| // smaller results, which also round to zero. In that case, just return an array of zeros. | ||||||||||
| let delta_scale = delta_scale.unsigned_abs() as usize; | ||||||||||
| let Some(max) = I::MAX_FOR_EACH_PRECISION.get(delta_scale) else { | ||||||||||
| return Some(O::Native::ZERO); | ||||||||||
| }; | ||||||||||
| let div = max.add_wrapping(I::Native::ONE); | ||||||||||
| let half = div.div_wrapping(I::Native::ONE.add_wrapping(I::Native::ONE)); | ||||||||||
| let half_neg = half.neg_wrapping(); | ||||||||||
|
|
||||||||||
| // div is >= 10 and so this cannot overflow | ||||||||||
| let d = value.div_wrapping(div); | ||||||||||
| let r = value.mod_wrapping(div); | ||||||||||
|
|
||||||||||
| // Round result | ||||||||||
| let adjusted = match value >= I::Native::ZERO { | ||||||||||
| true if r >= half => d.add_wrapping(I::Native::ONE), | ||||||||||
| false if r <= half_neg => d.sub_wrapping(I::Native::ONE), | ||||||||||
| _ => d, | ||||||||||
| }; | ||||||||||
| O::Native::from_decimal(adjusted) | ||||||||||
| }; | ||||||||||
|
|
||||||||||
| scaled.filter(|v| is_infallible_cast || O::is_valid_decimal_precision(*v, output_precision)) | ||||||||||
| } | ||||||||||
|
|
||||||||||
| /// Returns true if casting from (input_precision, input_scale) to | ||||||||||
| /// (output_precision, output_scale) is infallible based on precision/scale math. | ||||||||||
| fn is_infallible_decimal_cast( | ||||||||||
| input_precision: u8, | ||||||||||
| input_scale: i8, | ||||||||||
| output_precision: u8, | ||||||||||
| output_scale: i8, | ||||||||||
| ) -> bool { | ||||||||||
| let delta_scale = output_scale - input_scale; | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: we could have passed this in, our caller already computed it. but I guess this is more regular? |
||||||||||
| let input_precision = input_precision as i8; | ||||||||||
| let output_precision = output_precision as i8; | ||||||||||
| if delta_scale >= 0 { | ||||||||||
| // if the gain in precision (digits) is greater than the multiplication due to scaling | ||||||||||
| // every number will fit into the output type | ||||||||||
| // Example: If we are starting with any number of precision 5 [xxxxx], | ||||||||||
| // then an increase of scale by 3 will have the following effect on the representation: | ||||||||||
| // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type | ||||||||||
| // needs to provide at least 8 digits precision | ||||||||||
| input_precision + delta_scale <= output_precision | ||||||||||
| } else { | ||||||||||
| // if the reduction of the input number through scaling (dividing) is greater | ||||||||||
| // than a possible precision loss (plus potential increase via rounding) | ||||||||||
| // every input number will fit into the output type | ||||||||||
| // Example: If we are starting with any number of precision 5 [xxxxx], | ||||||||||
| // then and decrease the scale by 3 will have the following effect on the representation: | ||||||||||
| // [xxxxx] -> [xx] (+ 1 possibly, due to rounding). | ||||||||||
| // The rounding may add an additional digit, so for the cast to be infallible, | ||||||||||
| // the output type needs to have at least 3 digits of precision. | ||||||||||
| // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100: | ||||||||||
| // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible | ||||||||||
| input_precision + delta_scale < output_precision | ||||||||||
| } | ||||||||||
| } | ||||||||||
|
|
||||||||||
| /// Convert the value at a specific index in the given array into a `Variant`. | ||||||||||
| macro_rules! non_generic_conversion_single_value { | ||||||||||
| ($array:expr, $cast_fn:expr, $index:expr) => {{ | ||||||||||
|
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tiny nit to consider (saves space)