@@ -74,12 +74,13 @@ use arrow::compute::kernels::numeric::{
7474 add, add_wrapping, div, mul, mul_wrapping, rem, sub, sub_wrapping,
7575} ;
7676use arrow:: datatypes:: {
77- i256, ArrowDictionaryKeyType , ArrowNativeType , ArrowTimestampType , DataType ,
78- Date32Type , Field , Float32Type , Int16Type , Int32Type , Int64Type , Int8Type ,
79- IntervalDayTime , IntervalDayTimeType , IntervalMonthDayNano , IntervalMonthDayNanoType ,
80- IntervalUnit , IntervalYearMonthType , TimeUnit , TimestampMicrosecondType ,
81- TimestampMillisecondType , TimestampNanosecondType , TimestampSecondType , UInt16Type ,
82- UInt32Type , UInt64Type , UInt8Type , UnionFields , UnionMode , DECIMAL128_MAX_PRECISION ,
77+ i256, validate_decimal_precision_and_scale, ArrowDictionaryKeyType , ArrowNativeType ,
78+ ArrowTimestampType , DataType , Date32Type , Decimal128Type , Decimal256Type , Field ,
79+ Float32Type , Int16Type , Int32Type , Int64Type , Int8Type , IntervalDayTime ,
80+ IntervalDayTimeType , IntervalMonthDayNano , IntervalMonthDayNanoType , IntervalUnit ,
81+ IntervalYearMonthType , TimeUnit , TimestampMicrosecondType , TimestampMillisecondType ,
82+ TimestampNanosecondType , TimestampSecondType , UInt16Type , UInt32Type , UInt64Type ,
83+ UInt8Type , UnionFields , UnionMode , DECIMAL128_MAX_PRECISION ,
8384} ;
8485use arrow:: util:: display:: { array_value_to_string, ArrayFormatter , FormatOptions } ;
8586use cache:: { get_or_create_cached_key_array, get_or_create_cached_null_array} ;
@@ -1516,6 +1517,34 @@ impl ScalarValue {
15161517 DataType :: Float16 => ScalarValue :: Float16 ( Some ( f16:: from_f32 ( 1.0 ) ) ) ,
15171518 DataType :: Float32 => ScalarValue :: Float32 ( Some ( 1.0 ) ) ,
15181519 DataType :: Float64 => ScalarValue :: Float64 ( Some ( 1.0 ) ) ,
1520+ DataType :: Decimal128 ( precision, scale) => {
1521+ validate_decimal_precision_and_scale :: < Decimal128Type > (
1522+ * precision, * scale,
1523+ ) ?;
1524+ if * scale < 0 {
1525+ return _internal_err ! ( "Negative scale is not supported" ) ;
1526+ }
1527+ match i128:: from ( 10 ) . checked_pow ( * scale as u32 ) {
1528+ Some ( value) => {
1529+ ScalarValue :: Decimal128 ( Some ( value) , * precision, * scale)
1530+ }
1531+ None => return _internal_err ! ( "Unsupported scale {scale}" ) ,
1532+ }
1533+ }
1534+ DataType :: Decimal256 ( precision, scale) => {
1535+ validate_decimal_precision_and_scale :: < Decimal256Type > (
1536+ * precision, * scale,
1537+ ) ?;
1538+ if * scale < 0 {
1539+ return _internal_err ! ( "Negative scale is not supported" ) ;
1540+ }
1541+ match i256:: from ( 10 ) . checked_pow ( * scale as u32 ) {
1542+ Some ( value) => {
1543+ ScalarValue :: Decimal256 ( Some ( value) , * precision, * scale)
1544+ }
1545+ None => return _internal_err ! ( "Unsupported scale {scale}" ) ,
1546+ }
1547+ }
15191548 _ => {
15201549 return _not_impl_err ! (
15211550 "Can't create an one scalar from data_type \" {datatype:?}\" "
@@ -1534,6 +1563,34 @@ impl ScalarValue {
15341563 DataType :: Float16 => ScalarValue :: Float16 ( Some ( f16:: from_f32 ( -1.0 ) ) ) ,
15351564 DataType :: Float32 => ScalarValue :: Float32 ( Some ( -1.0 ) ) ,
15361565 DataType :: Float64 => ScalarValue :: Float64 ( Some ( -1.0 ) ) ,
1566+ DataType :: Decimal128 ( precision, scale) => {
1567+ validate_decimal_precision_and_scale :: < Decimal128Type > (
1568+ * precision, * scale,
1569+ ) ?;
1570+ if * scale < 0 {
1571+ return _internal_err ! ( "Negative scale is not supported" ) ;
1572+ }
1573+ match i128:: from ( 10 ) . checked_pow ( * scale as u32 ) {
1574+ Some ( value) => {
1575+ ScalarValue :: Decimal128 ( Some ( -value) , * precision, * scale)
1576+ }
1577+ None => return _internal_err ! ( "Unsupported scale {scale}" ) ,
1578+ }
1579+ }
1580+ DataType :: Decimal256 ( precision, scale) => {
1581+ validate_decimal_precision_and_scale :: < Decimal256Type > (
1582+ * precision, * scale,
1583+ ) ?;
1584+ if * scale < 0 {
1585+ return _internal_err ! ( "Negative scale is not supported" ) ;
1586+ }
1587+ match i256:: from ( 10 ) . checked_pow ( * scale as u32 ) {
1588+ Some ( value) => {
1589+ ScalarValue :: Decimal256 ( Some ( -value) , * precision, * scale)
1590+ }
1591+ None => return _internal_err ! ( "Unsupported scale {scale}" ) ,
1592+ }
1593+ }
15371594 _ => {
15381595 return _not_impl_err ! (
15391596 "Can't create a negative one scalar from data_type \" {datatype:?}\" "
@@ -1555,6 +1612,38 @@ impl ScalarValue {
15551612 DataType :: Float16 => ScalarValue :: Float16 ( Some ( f16:: from_f32 ( 10.0 ) ) ) ,
15561613 DataType :: Float32 => ScalarValue :: Float32 ( Some ( 10.0 ) ) ,
15571614 DataType :: Float64 => ScalarValue :: Float64 ( Some ( 10.0 ) ) ,
1615+ DataType :: Decimal128 ( precision, scale) => {
1616+ if let Err ( err) = validate_decimal_precision_and_scale :: < Decimal128Type > (
1617+ * precision, * scale,
1618+ ) {
1619+ return _internal_err ! ( "Invalid precision and scale {err}" ) ;
1620+ }
1621+ if * scale <= 0 {
1622+ return _internal_err ! ( "Negative scale is not supported" ) ;
1623+ }
1624+ match i128:: from ( 10 ) . checked_pow ( ( * scale + 1 ) as u32 ) {
1625+ Some ( value) => {
1626+ ScalarValue :: Decimal128 ( Some ( value) , * precision, * scale)
1627+ }
1628+ None => return _internal_err ! ( "Unsupported scale {scale}" ) ,
1629+ }
1630+ }
1631+ DataType :: Decimal256 ( precision, scale) => {
1632+ if let Err ( err) = validate_decimal_precision_and_scale :: < Decimal256Type > (
1633+ * precision, * scale,
1634+ ) {
1635+ return _internal_err ! ( "Invalid precision and scale {err}" ) ;
1636+ }
1637+ if * scale <= 0 {
1638+ return _internal_err ! ( "Negative scale is not supported" ) ;
1639+ }
1640+ match i256:: from ( 10 ) . checked_pow ( ( * scale + 1 ) as u32 ) {
1641+ Some ( value) => {
1642+ ScalarValue :: Decimal256 ( Some ( value) , * precision, * scale)
1643+ }
1644+ None => return _internal_err ! ( "Unsupported scale {scale}" ) ,
1645+ }
1646+ }
15581647 _ => {
15591648 return _not_impl_err ! (
15601649 "Can't create a ten scalar from data_type \" {datatype:?}\" "
@@ -1924,6 +2013,26 @@ impl ScalarValue {
19242013 ( Self :: Float64 ( Some ( l) ) , Self :: Float64 ( Some ( r) ) ) => {
19252014 Some ( ( l - r) . abs ( ) . round ( ) as _ )
19262015 }
2016+ (
2017+ Self :: Decimal128 ( Some ( l) , lprecision, lscale) ,
2018+ Self :: Decimal128 ( Some ( r) , rprecision, rscale) ,
2019+ ) => {
2020+ if lprecision == rprecision && lscale == rscale {
2021+ l. checked_sub ( * r) ?. checked_abs ( ) ?. to_usize ( )
2022+ } else {
2023+ None
2024+ }
2025+ }
2026+ (
2027+ Self :: Decimal256 ( Some ( l) , lprecision, lscale) ,
2028+ Self :: Decimal256 ( Some ( r) , rprecision, rscale) ,
2029+ ) => {
2030+ if lprecision == rprecision && lscale == rscale {
2031+ l. checked_sub ( * r) ?. checked_abs ( ) ?. to_usize ( )
2032+ } else {
2033+ None
2034+ }
2035+ }
19272036 _ => None ,
19282037 }
19292038 }
@@ -4489,7 +4598,9 @@ mod tests {
44894598 } ;
44904599 use arrow:: buffer:: { Buffer , OffsetBuffer } ;
44914600 use arrow:: compute:: { is_null, kernels} ;
4492- use arrow:: datatypes:: { ArrowNumericType , Fields , Float64Type } ;
4601+ use arrow:: datatypes:: {
4602+ ArrowNumericType , Fields , Float64Type , DECIMAL256_MAX_PRECISION ,
4603+ } ;
44934604 use arrow:: error:: ArrowError ;
44944605 use arrow:: util:: pretty:: pretty_format_columns;
44954606 use chrono:: NaiveDate ;
@@ -5225,6 +5336,116 @@ mod tests {
52255336 Ok ( ( ) )
52265337 }
52275338
#[test]
fn test_new_one_decimal128() {
    // (precision, scale, expected unscaled value): the expected value is
    // 10^scale and does not change when only the precision grows.
    let accepted: [(u8, i8, i128); 4] = [(5, 0, 1), (5, 1, 10), (5, 2, 100), (7, 2, 100)];
    for (precision, scale, unscaled) in accepted {
        let data_type = DataType::Decimal128(precision, scale);
        let expected = ScalarValue::Decimal128(Some(unscaled), precision, scale);
        assert_eq!(ScalarValue::new_one(&data_type).unwrap(), expected);
    }

    // A negative scale, followed by two invalid precision/scale combinations.
    let rejected: [(u8, i8); 3] = [(5, -1), (0, 2), (5, 7)];
    for (precision, scale) in rejected {
        let data_type = DataType::Decimal128(precision, scale);
        assert!(ScalarValue::new_one(&data_type).is_err());
    }
}
5364+
#[test]
fn test_new_one_decimal256() {
    // (precision, scale, expected unscaled value): the expected value is
    // 10^scale and does not change when only the precision grows.
    let accepted: [(u8, i8, i32); 4] = [(5, 0, 1), (5, 1, 10), (5, 2, 100), (7, 2, 100)];
    for (precision, scale, unscaled) in accepted {
        let data_type = DataType::Decimal256(precision, scale);
        let expected =
            ScalarValue::Decimal256(Some(i256::from(unscaled)), precision, scale);
        assert_eq!(ScalarValue::new_one(&data_type).unwrap(), expected);
    }

    // A negative scale, followed by two invalid precision/scale combinations.
    let rejected: [(u8, i8); 3] = [(5, -1), (0, 2), (5, 7)];
    for (precision, scale) in rejected {
        let data_type = DataType::Decimal256(precision, scale);
        assert!(ScalarValue::new_one(&data_type).is_err());
    }
}
5390+
#[test]
fn test_new_ten_decimal128() {
    // (precision, scale, expected unscaled value): the expected value is
    // 10^(scale + 1) and does not change when only the precision grows.
    let accepted: [(u8, i8, i128); 3] = [(5, 1, 100), (5, 2, 1000), (7, 2, 1000)];
    for (precision, scale, unscaled) in accepted {
        let data_type = DataType::Decimal128(precision, scale);
        let expected = ScalarValue::Decimal128(Some(unscaled), precision, scale);
        assert_eq!(ScalarValue::new_ten(&data_type).unwrap(), expected);
    }

    // Zero and negative scales first, then two invalid precision/scale
    // combinations.
    let rejected: [(u8, i8); 4] = [(5, 0), (5, -1), (0, 2), (5, 7)];
    for (precision, scale) in rejected {
        let data_type = DataType::Decimal128(precision, scale);
        assert!(ScalarValue::new_ten(&data_type).is_err());
    }
}
5413+
#[test]
fn test_new_ten_decimal256() {
    // (precision, scale, expected unscaled value): the expected value is
    // 10^(scale + 1) and does not change when only the precision grows.
    let accepted: [(u8, i8, i32); 3] = [(5, 1, 100), (5, 2, 1000), (7, 2, 1000)];
    for (precision, scale, unscaled) in accepted {
        let data_type = DataType::Decimal256(precision, scale);
        let expected =
            ScalarValue::Decimal256(Some(i256::from(unscaled)), precision, scale);
        assert_eq!(ScalarValue::new_ten(&data_type).unwrap(), expected);
    }

    // Zero and negative scales first, then two invalid precision/scale
    // combinations.
    let rejected: [(u8, i8); 4] = [(5, 0), (5, -1), (0, 2), (5, 7)];
    for (precision, scale) in rejected {
        let data_type = DataType::Decimal256(precision, scale);
        assert!(ScalarValue::new_ten(&data_type).is_err());
    }
}
5436+
#[test]
fn test_new_negative_one_decimal128() {
    // (precision, scale, expected unscaled value): the expected value is
    // -(10^scale).
    let accepted: [(u8, i8, i128); 2] = [(5, 0, -1), (5, 2, -100)];
    for (precision, scale, unscaled) in accepted {
        let data_type = DataType::Decimal128(precision, scale);
        let expected = ScalarValue::Decimal128(Some(unscaled), precision, scale);
        assert_eq!(ScalarValue::new_negative_one(&data_type).unwrap(), expected);
    }
}
5448+
52285449 #[ test]
52295450 fn test_list_partial_cmp ( ) {
52305451 let a =
@@ -7275,13 +7496,51 @@ mod tests {
72757496 ScalarValue :: Float64 ( Some ( -9.9 ) ) ,
72767497 5 ,
72777498 ) ,
7499+ (
7500+ ScalarValue :: Decimal128 ( Some ( 10 ) , 1 , 0 ) ,
7501+ ScalarValue :: Decimal128 ( Some ( 5 ) , 1 , 0 ) ,
7502+ 5 ,
7503+ ) ,
7504+ (
7505+ ScalarValue :: Decimal128 ( Some ( 5 ) , 1 , 0 ) ,
7506+ ScalarValue :: Decimal128 ( Some ( 10 ) , 1 , 0 ) ,
7507+ 5 ,
7508+ ) ,
7509+ (
7510+ ScalarValue :: Decimal256 ( Some ( 10 . into ( ) ) , 1 , 0 ) ,
7511+ ScalarValue :: Decimal256 ( Some ( 5 . into ( ) ) , 1 , 0 ) ,
7512+ 5 ,
7513+ ) ,
7514+ (
7515+ ScalarValue :: Decimal256 ( Some ( 5 . into ( ) ) , 1 , 0 ) ,
7516+ ScalarValue :: Decimal256 ( Some ( 10 . into ( ) ) , 1 , 0 ) ,
7517+ 5 ,
7518+ ) ,
72787519 ] ;
72797520 for ( lhs, rhs, expected) in cases. iter ( ) {
72807521 let distance = lhs. distance ( rhs) . unwrap ( ) ;
72817522 assert_eq ! ( distance, * expected) ;
72827523 }
72837524 }
72847525
#[test]
fn test_distance_none() {
    // Both operands of each pair share precision and scale; the values are the
    // type's maximum and its negation, so their difference cannot be
    // represented and no distance is expected.
    let max_128 = ScalarValue::Decimal128(Some(i128::MAX), DECIMAL128_MAX_PRECISION, 0);
    let min_128 = ScalarValue::Decimal128(Some(-i128::MAX), DECIMAL128_MAX_PRECISION, 0);
    let max_256 = ScalarValue::Decimal256(Some(i256::MAX), DECIMAL256_MAX_PRECISION, 0);
    let min_256 = ScalarValue::Decimal256(Some(-i256::MAX), DECIMAL256_MAX_PRECISION, 0);

    let pairs = [(max_128, min_128), (max_256, min_256)];
    for (lhs, rhs) in &pairs {
        assert!(lhs.distance(rhs).is_none(), "{lhs} vs {rhs}");
    }
}
7543+
72857544 #[ test]
72867545 fn test_scalar_distance_invalid ( ) {
72877546 let cases = [
@@ -7323,7 +7582,33 @@ mod tests {
73237582 ( ScalarValue :: Date64 ( Some ( 0 ) ) , ScalarValue :: Date64 ( Some ( 1 ) ) ) ,
73247583 (
73257584 ScalarValue :: Decimal128 ( Some ( 123 ) , 5 , 5 ) ,
7326- ScalarValue :: Decimal128 ( Some ( 120 ) , 5 , 5 ) ,
7585+ ScalarValue :: Decimal128 ( Some ( 120 ) , 5 , 3 ) ,
7586+ ) ,
7587+ (
7588+ ScalarValue :: Decimal128 ( Some ( 123 ) , 5 , 5 ) ,
7589+ ScalarValue :: Decimal128 ( Some ( 120 ) , 3 , 5 ) ,
7590+ ) ,
7591+ (
7592+ ScalarValue :: Decimal256 ( Some ( 123 . into ( ) ) , 5 , 5 ) ,
7593+ ScalarValue :: Decimal256 ( Some ( 120 . into ( ) ) , 3 , 5 ) ,
7594+ ) ,
7595+ // Distance 2 * 2^50 * 2^128 (the 2^50 sits in the high 128 bits) is larger than usize
7596+ (
7597+ ScalarValue :: Decimal256 (
7598+ Some ( i256:: from_parts ( 0 , 2_i64 . pow ( 50 ) . into ( ) ) ) ,
7599+ 1 ,
7600+ 0 ,
7601+ ) ,
7602+ ScalarValue :: Decimal256 (
7603+ Some ( i256:: from_parts ( 0 , ( -( 2_i64 ) . pow ( 50 ) ) . into ( ) ) ) ,
7604+ 1 ,
7605+ 0 ,
7606+ ) ,
7607+ ) ,
7608+ // Distance overflow
7609+ (
7610+ ScalarValue :: Decimal256 ( Some ( i256:: from_parts ( 0 , i128:: MAX ) ) , 1 , 0 ) ,
7611+ ScalarValue :: Decimal256 ( Some ( i256:: from_parts ( 0 , -i128:: MAX ) ) , 1 , 0 ) ,
73277612 ) ,
73287613 ] ;
73297614 for ( lhs, rhs) in cases {
0 commit comments