Skip to content

Commit 1f19580

Browse files
committed
Refactor common logic
1 parent e0b18da commit 1f19580

File tree

1 file changed

+79
-25
lines changed

1 file changed

+79
-25
lines changed

arrow-cast/src/cast/decimal.rs

Lines changed: 79 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,72 @@ impl DecimalCast for i256 {
139139
}
140140
}
141141

142+
/// Build a rescale function from (input_precision, input_scale) to (output_precision, output_scale)
143+
/// returning a closure `Fn(I::Native) -> Option<O::Native>` that performs the conversion.
144+
pub(crate) fn rescale_decimal<I, O>(
145+
_input_precision: u8,
146+
input_scale: i8,
147+
_output_precision: u8,
148+
output_scale: i8,
149+
) -> impl Fn(I::Native) -> Option<O::Native>
150+
where
151+
I: DecimalType,
152+
O: DecimalType,
153+
I::Native: DecimalCast + ArrowNativeTypeOp,
154+
O::Native: DecimalCast + ArrowNativeTypeOp,
155+
{
156+
let delta_scale = output_scale - input_scale;
157+
158+
// Precompute parameters and capture them in a single closure type
159+
let mul_opt = if delta_scale > 0 {
160+
O::Native::from_decimal(10_i128)
161+
.and_then(|t| t.pow_checked(delta_scale as u32).ok())
162+
} else {
163+
None
164+
};
165+
166+
let (div_opt, half_opt, half_neg_opt) = if delta_scale < 0 {
167+
let div = I::Native::from_decimal(10_i128)
168+
.and_then(|t| t.pow_checked(delta_scale.unsigned_abs() as u32).ok());
169+
if let Some(div) = div {
170+
let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
171+
let half_neg = half.neg_wrapping();
172+
(Some(div), Some(half), Some(half_neg))
173+
} else {
174+
(None, None, None)
175+
}
176+
} else {
177+
(None, None, None)
178+
};
179+
180+
move |x: I::Native| {
181+
if delta_scale == 0 {
182+
return O::Native::from_decimal(x);
183+
}
184+
185+
if let Some(mul) = mul_opt {
186+
return O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());
187+
}
188+
189+
// Decrease scale path
190+
let div = div_opt.unwrap();
191+
let half = half_opt.unwrap();
192+
let half_neg = half_neg_opt.unwrap();
193+
194+
// div is >= 10 and so this cannot overflow
195+
let d = x.div_wrapping(div);
196+
let r = x.mod_wrapping(div);
197+
198+
// Round result
199+
let adjusted = match x >= I::Native::ZERO {
200+
true if r >= half => d.add_wrapping(I::Native::ONE),
201+
false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
202+
_ => d,
203+
};
204+
O::Native::from_decimal(adjusted)
205+
}
206+
}
207+
142208
pub(crate) fn cast_decimal_to_decimal_error<I, O>(
143209
output_precision: u8,
144210
output_scale: i8,
@@ -188,26 +254,12 @@ where
188254
// [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible
189255
let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8);
190256

191-
let div = I::Native::from_decimal(10_i128)
192-
.unwrap()
193-
.pow_checked(delta_scale as u32)?;
194-
195-
let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
196-
let half_neg = half.neg_wrapping();
197-
198-
let f = |x: I::Native| {
199-
// div is >= 10 and so this cannot overflow
200-
let d = x.div_wrapping(div);
201-
let r = x.mod_wrapping(div);
202-
203-
// Round result
204-
let adjusted = match x >= I::Native::ZERO {
205-
true if r >= half => d.add_wrapping(I::Native::ONE),
206-
false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
207-
_ => d,
208-
};
209-
O::Native::from_decimal(adjusted)
210-
};
257+
let f = rescale_decimal::<I, O>(
258+
input_precision,
259+
input_scale,
260+
output_precision,
261+
output_scale,
262+
);
211263

212264
Ok(if is_infallible_cast {
213265
// make sure we don't perform calculations that don't make sense w/o validation
@@ -242,9 +294,6 @@ where
242294
{
243295
let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
244296
let delta_scale = output_scale - input_scale;
245-
let mul = O::Native::from_decimal(10_i128)
246-
.unwrap()
247-
.pow_checked(delta_scale as u32)?;
248297

249298
// if the gain in precision (digits) is greater than the multiplication due to scaling
250299
// every number will fit into the output type
@@ -253,13 +302,18 @@ where
253302
// [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type
254303
// needs to provide at least 8 digits precision
255304
let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8);
256-
let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());
305+
let f = rescale_decimal::<I, O>(
306+
input_precision,
307+
input_scale,
308+
output_precision,
309+
output_scale,
310+
);
257311

258312
Ok(if is_infallible_cast {
259313
// make sure we don't perform calculations that don't make sense w/o validation
260314
validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
261315
// unwrapping is safe since the result is guaranteed to fit into the target type
262-
let f = |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul);
316+
let f = |x: I::Native| f(x).unwrap();
263317
array.unary(f)
264318
} else if cast_options.safe {
265319
array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))

0 commit comments

Comments
 (0)