Skip to content

Commit 274a028

Browse files
committed
Refactor common logic
1 parent c163a91 commit 274a028

File tree

1 file changed

+42
-26
lines changed

1 file changed

+42
-26
lines changed

arrow-cast/src/cast/decimal.rs

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,14 @@ pub(crate) fn rescale_decimal<I, O>(
150150
where
151151
I: DecimalType,
152152
O: DecimalType,
153-
I::Native: DecimalCast + ArrowNativeTypeOp,
154-
O::Native: DecimalCast + ArrowNativeTypeOp,
153+
I::Native: DecimalCast,
154+
O::Native: DecimalCast,
155155
{
156156
let delta_scale = output_scale - input_scale;
157-
let input_precision_i8 = input_precision as i8;
158-
let output_precision_i8 = output_precision as i8;
159157

160158
// Determine if the cast is infallible based on precision/scale math
161-
let is_infallible_cast = input_precision_i8 + delta_scale < output_precision_i8;
159+
let is_infallible_cast =
160+
is_infallible_decimal_cast(input_precision, input_scale, output_precision, output_scale);
162161

163162
// Build a single mode once and use a thin closure that calls into it
164163
enum RescaleMode<I, O> {
@@ -177,7 +176,6 @@ where
177176
None => RescaleMode::Invalid,
178177
}
179178
} else {
180-
// delta_scale < 0
181179
match I::Native::from_decimal(10_i128)
182180
.and_then(|t| t.pow_checked(delta_scale.unsigned_abs() as u32).ok())
183181
{
@@ -234,6 +232,40 @@ where
234232
}
235233
}
236234

235+
/// Returns true if casting from (input_precision, input_scale) to
236+
/// (output_precision, output_scale) is infallible based on precision/scale math.
237+
fn is_infallible_decimal_cast(
238+
input_precision: u8,
239+
input_scale: i8,
240+
output_precision: u8,
241+
output_scale: i8,
242+
) -> bool {
243+
let delta_scale = output_scale - input_scale;
244+
let input_precision_i8 = input_precision as i8;
245+
let output_precision_i8 = output_precision as i8;
246+
if delta_scale >= 0 {
247+
// if the gain in precision (digits) is greater than the multiplication due to scaling
248+
// every number will fit into the output type
249+
// Example: If we are starting with any number of precision 5 [xxxxx],
250+
// then an increase of scale by 3 will have the following effect on the representation:
251+
// [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type
252+
// needs to provide at least 8 digits precision
253+
input_precision_i8 + delta_scale <= output_precision_i8
254+
} else {
255+
// if the reduction of the input number through scaling (dividing) is greater
256+
// than a possible precision loss (plus potential increase via rounding)
257+
// every input number will fit into the output type
258+
// Example: If we are starting with any number of precision 5 [xxxxx],
259+
// then and decrease the scale by 3 will have the following effect on the representation:
260+
// [xxxxx] -> [xx] (+ 1 possibly, due to rounding).
261+
// The rounding may add an additional digit, so the cast to be infallible,
262+
// the output type needs to have at least 3 digits of precision.
263+
// e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100:
264+
// [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible
265+
input_precision_i8 + delta_scale < output_precision_i8
266+
}
267+
}
268+
237269
pub(crate) fn cast_decimal_to_decimal_error<I, O>(
238270
output_precision: u8,
239271
output_scale: i8,
@@ -271,18 +303,8 @@ where
271303
{
272304
// make sure we don't perform calculations that don't make sense w/o validation
273305
validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
274-
let delta_scale = input_scale - output_scale;
275-
// if the reduction of the input number through scaling (dividing) is greater
276-
// than a possible precision loss (plus potential increase via rounding)
277-
// every input number will fit into the output type
278-
// Example: If we are starting with any number of precision 5 [xxxxx],
279-
// then and decrease the scale by 3 will have the following effect on the representation:
280-
// [xxxxx] -> [xx] (+ 1 possibly, due to rounding).
281-
// The rounding may add an additional digit, so the cast to be infallible,
282-
// the output type needs to have at least 3 digits of precision.
283-
// e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100:
284-
// [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible
285-
let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8);
306+
let is_infallible_cast =
307+
is_infallible_decimal_cast(input_precision, input_scale, output_precision, output_scale);
286308

287309
let f = rescale_decimal::<I, O>(input_precision, input_scale, output_precision, output_scale);
288310

@@ -312,15 +334,9 @@ where
312334
{
313335
// make sure we don't perform calculations that don't make sense w/o validation
314336
validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
315-
let delta_scale = output_scale - input_scale;
316337

317-
// if the gain in precision (digits) is greater than the multiplication due to scaling
318-
// every number will fit into the output type
319-
// Example: If we are starting with any number of precision 5 [xxxxx],
320-
// then an increase of scale by 3 will have the following effect on the representation:
321-
// [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type
322-
// needs to provide at least 8 digits precision
323-
let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8);
338+
let is_infallible_cast =
339+
is_infallible_decimal_cast(input_precision, input_scale, output_precision, output_scale);
324340
let f = rescale_decimal::<I, O>(input_precision, input_scale, output_precision, output_scale);
325341

326342
Ok(if is_infallible_cast {

0 commit comments

Comments
 (0)