From a0007b4e98acf683069e360400742ee6934968cc Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Fri, 30 Oct 2020 11:27:49 +0100 Subject: [PATCH] Fixing BigDecimal conversion for PostgreSQL Now working properly with numbers, such as `0.01` and `0.012`. --- sqlx-core/src/postgres/types/bigdecimal.rs | 77 +++++++++++----------- tests/postgres/types.rs | 7 ++ 2 files changed, 45 insertions(+), 39 deletions(-) diff --git a/sqlx-core/src/postgres/types/bigdecimal.rs b/sqlx-core/src/postgres/types/bigdecimal.rs index 28b2603c0a..617c19c5fe 100644 --- a/sqlx-core/src/postgres/types/bigdecimal.rs +++ b/sqlx-core/src/postgres/types/bigdecimal.rs @@ -1,7 +1,7 @@ use std::cmp; use std::convert::{TryFrom, TryInto}; -use bigdecimal::BigDecimal; +use bigdecimal::{BigDecimal, ToPrimitive, Zero}; use num_bigint::{BigInt, Sign}; use crate::decode::Decode; @@ -77,65 +77,64 @@ impl TryFrom<&'_ BigDecimal> for PgNumeric { type Error = BoxDynError; fn try_from(decimal: &BigDecimal) -> Result { - let base_10_to_10000 = |chunk: &[u8]| chunk.iter().fold(0i16, |a, &d| a * 10 + d as i16); + if decimal.is_zero() { + return Ok(PgNumeric::Number { + sign: PgNumericSign::Positive, + scale: 0, + weight: 0, + digits: vec![], + }); + } // NOTE: this unfortunately copies the BigInt internally let (integer, exp) = decimal.as_bigint_and_exponent(); - // this routine is specifically optimized for base-10 - // FIXME: is there a way to iterate over the digits to avoid the Vec allocation - let (sign, base_10) = integer.to_radix_be(10); - - // weight is positive power of 10000 - // exp is the negative power of 10 - let weight_10 = base_10.len() as i64 - exp; - // scale is only nonzero when we have fractional digits // since `exp` is the _negative_ decimal exponent, it tells us // exactly what our scale should be let scale: i16 = cmp::max(0, exp).try_into()?; - // there's an implicit +1 offset in the interpretation - let weight: i16 = if weight_10 <= 0 { - weight_10 / 4 - 1 - } else { - // the `-1` is a fix for an off by 1 error (4 digits should still be 0 weight) - (weight_10 - 1) / 4 - } - .try_into()?; + let (sign, uint) = integer.into_parts(); + let mut mantissa = uint.to_u128().unwrap(); - let digits_len = if base_10.len() % 4 != 0 { - base_10.len() / 4 + 1 - } else { - base_10.len() / 4 - }; + // If our scale is not a multiple of 4, we need to go to the next + // multiple. + let groups_diff = scale % 4; + if groups_diff > 0 { + let remainder = 4 - groups_diff as u32; + let power = 10u32.pow(remainder as u32) as u128; - let offset = weight_10.rem_euclid(4) as usize; + mantissa = mantissa * power; + } - let mut digits = Vec::with_capacity(digits_len); + // Array to store max mantissa of Decimal in Postgres decimal format. + let mut digits = Vec::with_capacity(8); - if let Some(first) = base_10.get(..offset) { - if offset != 0 { - digits.push(base_10_to_10000(first)); - } + // Convert to base-10000. + while mantissa != 0 { + digits.push((mantissa % 10_000) as i16); + mantissa /= 10_000; } - if let Some(rest) = base_10.get(offset..) { - digits.extend( - rest.chunks(4) - .map(|chunk| base_10_to_10000(chunk) * 10i16.pow(4 - chunk.len() as u32)), - ); - } + // Change the endianness. + digits.reverse(); + + // Weight is number of digits on the left side of the decimal. + let digits_after_decimal = (scale + 3) as u16 / 4; + let weight = digits.len() as i16 - digits_after_decimal as i16 - 1; + // Remove non-significant zeroes. while let Some(&0) = digits.last() { digits.pop(); } + let sign = match sign { + Sign::Plus | Sign::NoSign => PgNumericSign::Positive, + Sign::Minus => PgNumericSign::Negative, + }; + Ok(PgNumeric::Number { - sign: match sign { - Sign::Plus | Sign::NoSign => PgNumericSign::Positive, - Sign::Minus => PgNumericSign::Negative, - }, + sign, scale, weight, digits, diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index 93ac39761b..a0aa64eb69 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -396,7 +396,14 @@ test_type!(bigdecimal(Postgres, "10000::numeric" == "10000".parse::().unwrap(), "0.1::numeric" == "0.1".parse::().unwrap(), "0.01::numeric" == "0.01".parse::().unwrap(), + "0.012::numeric" == "0.012".parse::().unwrap(), + "0.0123::numeric" == "0.0123".parse::().unwrap(), "0.01234::numeric" == "0.01234".parse::().unwrap(), + "0.012345::numeric" == "0.012345".parse::().unwrap(), + "0.0123456::numeric" == "0.0123456".parse::().unwrap(), + "0.01234567::numeric" == "0.01234567".parse::().unwrap(), + "0.012345678::numeric" == "0.012345678".parse::().unwrap(), + "0.0123456789::numeric" == "0.0123456789".parse::().unwrap(), "12.34::numeric" == "12.34".parse::().unwrap(), "12345.6789::numeric" == "12345.6789".parse::().unwrap(), ));