@@ -139,6 +139,72 @@ impl DecimalCast for i256 {
139139 }
140140}
141141
142+ /// Build a rescale function from (input_precision, input_scale) to (output_precision, output_scale)
143+ /// returning a closure `Fn(I::Native) -> Option<O::Native>` that performs the conversion.
144+ pub ( crate ) fn rescale_decimal < I , O > (
145+ _input_precision : u8 ,
146+ input_scale : i8 ,
147+ _output_precision : u8 ,
148+ output_scale : i8 ,
149+ ) -> impl Fn ( I :: Native ) -> Option < O :: Native >
150+ where
151+ I : DecimalType ,
152+ O : DecimalType ,
153+ I :: Native : DecimalCast + ArrowNativeTypeOp ,
154+ O :: Native : DecimalCast + ArrowNativeTypeOp ,
155+ {
156+ let delta_scale = output_scale - input_scale;
157+
158+ // Precompute parameters and capture them in a single closure type
159+ let mul_opt = if delta_scale > 0 {
160+ O :: Native :: from_decimal ( 10_i128 )
161+ . and_then ( |t| t. pow_checked ( delta_scale as u32 ) . ok ( ) )
162+ } else {
163+ None
164+ } ;
165+
166+ let ( div_opt, half_opt, half_neg_opt) = if delta_scale < 0 {
167+ let div = I :: Native :: from_decimal ( 10_i128 )
168+ . and_then ( |t| t. pow_checked ( delta_scale. unsigned_abs ( ) as u32 ) . ok ( ) ) ;
169+ if let Some ( div) = div {
170+ let half = div. div_wrapping ( I :: Native :: from_usize ( 2 ) . unwrap ( ) ) ;
171+ let half_neg = half. neg_wrapping ( ) ;
172+ ( Some ( div) , Some ( half) , Some ( half_neg) )
173+ } else {
174+ ( None , None , None )
175+ }
176+ } else {
177+ ( None , None , None )
178+ } ;
179+
180+ move |x : I :: Native | {
181+ if delta_scale == 0 {
182+ return O :: Native :: from_decimal ( x) ;
183+ }
184+
185+ if let Some ( mul) = mul_opt {
186+ return O :: Native :: from_decimal ( x) . and_then ( |x| x. mul_checked ( mul) . ok ( ) ) ;
187+ }
188+
189+ // Decrease scale path
190+ let div = div_opt. unwrap ( ) ;
191+ let half = half_opt. unwrap ( ) ;
192+ let half_neg = half_neg_opt. unwrap ( ) ;
193+
194+ // div is >= 10 and so this cannot overflow
195+ let d = x. div_wrapping ( div) ;
196+ let r = x. mod_wrapping ( div) ;
197+
198+ // Round result
199+ let adjusted = match x >= I :: Native :: ZERO {
200+ true if r >= half => d. add_wrapping ( I :: Native :: ONE ) ,
201+ false if r <= half_neg => d. sub_wrapping ( I :: Native :: ONE ) ,
202+ _ => d,
203+ } ;
204+ O :: Native :: from_decimal ( adjusted)
205+ }
206+ }
207+
142208pub ( crate ) fn cast_decimal_to_decimal_error < I , O > (
143209 output_precision : u8 ,
144210 output_scale : i8 ,
@@ -188,26 +254,12 @@ where
188254 // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible
189255 let is_infallible_cast = ( input_precision as i8 ) - delta_scale < ( output_precision as i8 ) ;
190256
191- let div = I :: Native :: from_decimal ( 10_i128 )
192- . unwrap ( )
193- . pow_checked ( delta_scale as u32 ) ?;
194-
195- let half = div. div_wrapping ( I :: Native :: from_usize ( 2 ) . unwrap ( ) ) ;
196- let half_neg = half. neg_wrapping ( ) ;
197-
198- let f = |x : I :: Native | {
199- // div is >= 10 and so this cannot overflow
200- let d = x. div_wrapping ( div) ;
201- let r = x. mod_wrapping ( div) ;
202-
203- // Round result
204- let adjusted = match x >= I :: Native :: ZERO {
205- true if r >= half => d. add_wrapping ( I :: Native :: ONE ) ,
206- false if r <= half_neg => d. sub_wrapping ( I :: Native :: ONE ) ,
207- _ => d,
208- } ;
209- O :: Native :: from_decimal ( adjusted)
210- } ;
257+ let f = rescale_decimal :: < I , O > (
258+ input_precision,
259+ input_scale,
260+ output_precision,
261+ output_scale,
262+ ) ;
211263
212264 Ok ( if is_infallible_cast {
213265 // make sure we don't perform calculations that don't make sense w/o validation
@@ -242,9 +294,6 @@ where
242294{
243295 let error = cast_decimal_to_decimal_error :: < I , O > ( output_precision, output_scale) ;
244296 let delta_scale = output_scale - input_scale;
245- let mul = O :: Native :: from_decimal ( 10_i128 )
246- . unwrap ( )
247- . pow_checked ( delta_scale as u32 ) ?;
248297
249298 // if the gain in precision (digits) is greater than the multiplication due to scaling
250299 // every number will fit into the output type
@@ -253,13 +302,18 @@ where
253302 // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type
254303 // needs to provide at least 8 digits precision
255304 let is_infallible_cast = ( input_precision as i8 ) + delta_scale <= ( output_precision as i8 ) ;
256- let f = |x| O :: Native :: from_decimal ( x) . and_then ( |x| x. mul_checked ( mul) . ok ( ) ) ;
305+ let f = rescale_decimal :: < I , O > (
306+ input_precision,
307+ input_scale,
308+ output_precision,
309+ output_scale,
310+ ) ;
257311
258312 Ok ( if is_infallible_cast {
259313 // make sure we don't perform calculations that don't make sense w/o validation
260314 validate_decimal_precision_and_scale :: < O > ( output_precision, output_scale) ?;
261315 // unwrapping is safe since the result is guaranteed to fit into the target type
262- let f = |x| O :: Native :: from_decimal ( x) . unwrap ( ) . mul_wrapping ( mul ) ;
316+ let f = |x : I :: Native | f ( x) . unwrap ( ) ;
263317 array. unary ( f)
264318 } else if cast_options. safe {
265319 array. unary_opt ( |x| f ( x) . filter ( |v| O :: is_valid_decimal_precision ( * v, output_precision) ) )
0 commit comments