Skip to content

Commit 56190ef

Browse files
author
Davies Liu
committed
simplify Decimal
1 parent 571a8a3 commit 56190ef

File tree

4 files changed

+57
-183
lines changed

4 files changed

+57
-183
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -562,16 +562,16 @@ case class Cast(child: Expression, dataType: DataType)
562562
java.math.BigDecimal.valueOf(${timestampToDoubleCode(c)}));
563563
${changePrecision("tmpDecimal", target, evPrim, evNull)}
564564
"""
565-
case DecimalType() =>
565+
case dt: DecimalType =>
566566
(c, evPrim, evNull) =>
567567
s"""
568568
Decimal tmpDecimal = $c.clone();
569569
${changePrecision("tmpDecimal", target, evPrim, evNull)}
570570
"""
571-
case LongType =>
571+
case ByteType | ShortType | IntegerType | LongType =>
572572
(c, evPrim, evNull) =>
573573
s"""
574-
Decimal tmpDecimal = Decimal.apply($c);
574+
Decimal tmpDecimal = Decimal.apply((long) $c);
575575
${changePrecision("tmpDecimal", target, evPrim, evNull)}
576576
"""
577577
case x: NumericType =>

sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala

Lines changed: 32 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.types
2020
import java.math.{MathContext, RoundingMode, BigDecimal => JavaBigDecimal}
2121

2222
import org.apache.spark.annotation.DeveloperApi
23+
import org.apache.spark.unsafe.PlatformDependent
2324

2425
/**
2526
* A mutable implementation of BigDecimal that can hold a Long if values are small enough.
@@ -32,74 +33,38 @@ import org.apache.spark.annotation.DeveloperApi
3233
final class Decimal extends Ordered[Decimal] with Serializable {
3334
import org.apache.spark.sql.types.Decimal._
3435

35-
private var decimalVal: JavaBigDecimal = null
36-
private var longVal: Long = 0L
36+
private var decimalVal: JavaBigDecimal = BIG_DEC_ZERO
3737
private var _precision: Int = 1
38-
private var _scale: Int = 0
3938

4039
def precision: Int = _precision
41-
def scale: Int = _scale
40+
def scale: Int = decimalVal.scale()
4241

4342
/**
4443
* Set this Decimal to the given Long. Will have precision 20 and scale 0.
4544
*/
4645
def set(longVal: Long): Decimal = {
47-
if (longVal <= -POW_10(MAX_LONG_DIGITS) || longVal >= POW_10(MAX_LONG_DIGITS)) {
48-
// We can't represent this compactly as a long without risking overflow
49-
this.decimalVal = new JavaBigDecimal(longVal)
50-
this.longVal = 0L
51-
} else {
52-
this.decimalVal = null
53-
this.longVal = longVal
54-
}
55-
this._precision = 20
56-
this._scale = 0
46+
decimalVal = JavaBigDecimal.valueOf(longVal)
47+
_precision = 20
5748
this
5849
}
5950

6051
/**
6152
* Set this Decimal to the given Int. Will have precision 10 and scale 0.
6253
*/
6354
def set(intVal: Int): Decimal = {
64-
this.decimalVal = null
65-
this.longVal = intVal
66-
this._precision = 10
67-
this._scale = 0
55+
decimalVal = JavaBigDecimal.valueOf(intVal)
56+
_precision = 10
6857
this
6958
}
7059

7160
/**
7261
* Set this Decimal to the given unscaled Long, with a given precision and scale.
62+
*
63+
* Note: this is used in serialization, caller will make sure that it will not overflow
7364
*/
7465
def set(unscaled: Long, precision: Int, scale: Int): Decimal = {
75-
if (setOrNull(unscaled, precision, scale) == null) {
76-
throw new IllegalArgumentException("Unscaled value too large for precision")
77-
}
78-
this
79-
}
80-
81-
/**
82-
* Set this Decimal to the given unscaled Long, with a given precision and scale,
83-
* and return it, or return null if it cannot be set due to overflow.
84-
*/
85-
def setOrNull(unscaled: Long, precision: Int, scale: Int): Decimal = {
86-
if (unscaled <= -POW_10(MAX_LONG_DIGITS) || unscaled >= POW_10(MAX_LONG_DIGITS)) {
87-
// We can't represent this compactly as a long without risking overflow
88-
if (precision < 19) {
89-
return null // Requested precision is too low to represent this value
90-
}
91-
this.decimalVal = new JavaBigDecimal(unscaled)
92-
this.longVal = 0L
93-
} else {
94-
val p = POW_10(math.min(precision, MAX_LONG_DIGITS))
95-
if (unscaled <= -p || unscaled >= p) {
96-
return null // Requested precision is too low to represent this value
97-
}
98-
this.decimalVal = null
99-
this.longVal = unscaled
100-
}
101-
this._precision = precision
102-
this._scale = scale
66+
decimalVal = JavaBigDecimal.valueOf(unscaled, scale)
67+
_precision = precision
10368
this
10469
}
10570

@@ -121,11 +86,9 @@ final class Decimal extends Ordered[Decimal] with Serializable {
12186
* Set this Decimal to the given java.math.BigDecimal value, with a given precision and scale.
12287
*/
12388
private[sql] def set(decimal: JavaBigDecimal, precision: Int, scale: Int): Decimal = {
124-
this.decimalVal = decimal.setScale(scale, ROUNDING_MODE)
89+
decimalVal = decimal.setScale(scale, ROUNDING_MODE)
12590
require(decimalVal.precision <= precision, "Overflowed precision")
126-
this.longVal = 0L
127-
this._precision = precision
128-
this._scale = scale
91+
_precision = precision
12992
this
13093
}
13194

@@ -134,9 +97,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
13497
*/
13598
private[sql] def set(decimal: JavaBigDecimal): Decimal = {
13699
this.decimalVal = decimal
137-
this.longVal = 0L
138100
this._precision = decimal.precision
139-
this._scale = decimal.scale
140101
this
141102
}
142103

@@ -145,52 +106,36 @@ final class Decimal extends Ordered[Decimal] with Serializable {
145106
*/
146107
def set(decimal: Decimal): Decimal = {
147108
this.decimalVal = decimal.decimalVal
148-
this.longVal = decimal.longVal
149109
this._precision = decimal._precision
150-
this._scale = decimal._scale
151110
this
152111
}
153112

154113
def toBigDecimal: BigDecimal = BigDecimal(toJavaBigDecimal)
155114

156-
private[sql] def toJavaBigDecimal: JavaBigDecimal = {
157-
if (decimalVal.ne(null)) {
158-
decimalVal
159-
} else {
160-
JavaBigDecimal.valueOf(longVal, _scale)
161-
}
162-
}
115+
private[sql] def toJavaBigDecimal: JavaBigDecimal = decimalVal
163116

164117
def toUnscaledLong: Long = {
165-
if (decimalVal.ne(null)) {
166-
decimalVal.unscaledValue().longValue()
118+
val unscaled = PlatformDependent.UNSAFE.getLong(decimalVal,
119+
PlatformDependent.BIG_DECIMAL_INTCOMPACT_OFFSET)
120+
if (unscaled != Long.MinValue) {
121+
unscaled
167122
} else {
168-
longVal
123+
decimalVal.unscaledValue().longValue()
169124
}
170125
}
171126

172-
override def toString: String = toJavaBigDecimal.toString()
127+
override def toString: String = decimalVal.toString()
173128

174129
@DeveloperApi
175130
def toDebugString: String = {
176-
if (decimalVal.ne(null)) {
177-
s"Decimal(expanded,$decimalVal,$precision,$scale})"
178-
} else {
179-
s"Decimal(compact,$longVal,$precision,$scale})"
180-
}
131+
s"Decimal($decimalVal,${_precision})"
181132
}
182133

183134
def toDouble: Double = toJavaBigDecimal.doubleValue()
184135

185136
def toFloat: Float = toJavaBigDecimal.floatValue()
186137

187-
def toLong: Long = {
188-
if (decimalVal.eq(null)) {
189-
longVal / POW_10(_scale)
190-
} else {
191-
decimalVal.longValue()
192-
}
193-
}
138+
def toLong: Long = decimalVal.longValue()
194139

195140
def toInt: Int = toLong.toInt
196141

@@ -205,65 +150,23 @@ final class Decimal extends Ordered[Decimal] with Serializable {
205150
*/
206151
def changePrecision(precision: Int, scale: Int): Boolean = {
207152
// fast path for UnsafeProjection
208-
if (precision == this.precision && scale == this.scale) {
153+
if (precision == _precision && scale == decimalVal.scale()) {
209154
return true
210155
}
211-
// First, update our longVal if we can, or transfer over to using a BigDecimal
212-
if (decimalVal.eq(null)) {
213-
if (scale < _scale) {
214-
// Easier case: we just need to divide our scale down
215-
val diff = _scale - scale
216-
val droppedDigits = longVal % POW_10(diff)
217-
longVal /= POW_10(diff)
218-
if (math.abs(droppedDigits) * 2 >= POW_10(diff)) {
219-
longVal += (if (longVal < 0) -1L else 1L)
220-
}
221-
} else if (scale > _scale) {
222-
// We might be able to multiply longVal by a power of 10 and not overflow, but if not,
223-
// switch to using a BigDecimal
224-
val diff = scale - _scale
225-
val p = POW_10(math.max(MAX_LONG_DIGITS - diff, 0))
226-
if (diff <= MAX_LONG_DIGITS && longVal > -p && longVal < p) {
227-
// Multiplying longVal by POW_10(diff) will still keep it below MAX_LONG_DIGITS
228-
longVal *= POW_10(diff)
229-
} else {
230-
// Give up on using Longs; switch to BigDecimal, which we'll modify below
231-
decimalVal = JavaBigDecimal.valueOf(longVal, _scale)
232-
}
233-
}
234-
// In both cases, we will check whether our precision is okay below
235-
}
236156

237-
if (decimalVal.ne(null)) {
238-
// We get here if either we started with a BigDecimal, or we switched to one because we would
239-
// have overflowed our Long; in either case we must rescale decimalVal to the new scale.
240-
val newVal = decimalVal.setScale(scale, ROUNDING_MODE)
241-
if (newVal.precision > precision) {
242-
return false
243-
}
244-
decimalVal = newVal
245-
} else {
246-
// We're still using Longs, but we should check whether we match the new precision
247-
val p = POW_10(math.min(precision, MAX_LONG_DIGITS))
248-
if (longVal <= -p || longVal >= p) {
249-
// Note that we shouldn't have been able to fix this by switching to BigDecimal
250-
return false
251-
}
157+
val newVal = decimalVal.setScale(scale, ROUNDING_MODE)
158+
if (newVal.precision > precision) {
159+
return false
252160
}
253-
161+
decimalVal = newVal
254162
_precision = precision
255-
_scale = scale
256163
true
257164
}
258165

259166
override def clone(): Decimal = new Decimal().set(this)
260167

261168
override def compare(other: Decimal): Int = {
262-
if (decimalVal.eq(null) && other.decimalVal.eq(null) && _scale == other._scale) {
263-
if (longVal < other.longVal) -1 else if (longVal == other.longVal) 0 else 1
264-
} else {
265-
toJavaBigDecimal.compareTo(other.toJavaBigDecimal)
266-
}
169+
toJavaBigDecimal.compareTo(other.toJavaBigDecimal)
267170
}
268171

269172
override def equals(other: Any): Boolean = other match {
@@ -276,24 +179,15 @@ final class Decimal extends Ordered[Decimal] with Serializable {
276179
override def hashCode(): Int = toBigDecimal.hashCode()
277180

278181
def isZero: Boolean = {
279-
if (decimalVal.ne(null)) decimalVal.compareTo(BIG_DEC_ZERO) == 0
280-
else longVal == 0
182+
decimalVal.compareTo(BIG_DEC_ZERO) == 0
281183
}
282184

283185
def + (that: Decimal): Decimal = {
284-
if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
285-
Decimal(longVal + that.longVal, Math.max(precision, that.precision), scale)
286-
} else {
287-
Decimal(toJavaBigDecimal.add(that.toJavaBigDecimal, MATH_CONTEXT), precision, scale)
288-
}
186+
Decimal(toJavaBigDecimal.add(that.toJavaBigDecimal, MATH_CONTEXT), precision, scale)
289187
}
290188

291189
def - (that: Decimal): Decimal = {
292-
if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
293-
Decimal(longVal - that.longVal, Math.max(precision, that.precision), scale)
294-
} else {
295-
Decimal(toJavaBigDecimal.subtract(that.toJavaBigDecimal, MATH_CONTEXT), precision, scale)
296-
}
190+
Decimal(toJavaBigDecimal.subtract(that.toJavaBigDecimal, MATH_CONTEXT), precision, scale)
297191
}
298192

299193
// HiveTypeCoercion will take care of the precision, scale of result
@@ -313,11 +207,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
313207
def remainder(that: Decimal): Decimal = this % that
314208

315209
def unary_- : Decimal = {
316-
if (decimalVal.ne(null)) {
317-
Decimal(decimalVal.negate(), precision, scale)
318-
} else {
319-
Decimal(-longVal, precision, scale)
320-
}
210+
Decimal(decimalVal.negate(), precision, scale)
321211
}
322212

323213
def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this

sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala

Lines changed: 14 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,8 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
4646
checkDecimal(Decimal(1e17.toLong, 18, 0), 1e17.toLong.toString, 18, 0)
4747
checkDecimal(Decimal(Long.MaxValue), Long.MaxValue.toString, 20, 0)
4848
checkDecimal(Decimal(Long.MinValue), Long.MinValue.toString, 20, 0)
49-
intercept[IllegalArgumentException](Decimal(170L, 2, 1))
50-
intercept[IllegalArgumentException](Decimal(170L, 2, 0))
5149
intercept[IllegalArgumentException](Decimal(BigDecimal("10.030"), 2, 1))
5250
intercept[IllegalArgumentException](Decimal(BigDecimal("-9.95"), 2, 1))
53-
intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0))
5451
}
5552

5653
test("creating decimals with negative scale") {
@@ -88,36 +85,19 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
8885
checkValues(Decimal(Double.MinValue), Double.MinValue, 0L)
8986
}
9087

91-
// Accessor for the BigDecimal value of a Decimal, which will be null if it's using Longs
92-
private val decimalVal = PrivateMethod[BigDecimal]('decimalVal)
93-
94-
/** Check whether a decimal is represented compactly (passing whether we expect it to be) */
95-
private def checkCompact(d: Decimal, expected: Boolean): Unit = {
96-
val isCompact = d.invokePrivate(decimalVal()).eq(null)
97-
assert(isCompact == expected, s"$d ${if (expected) "was not" else "was"} compact")
98-
}
99-
100-
test("small decimals represented as unscaled long") {
101-
checkCompact(new Decimal(), true)
102-
checkCompact(Decimal(BigDecimal(10.03)), false)
103-
checkCompact(Decimal(BigDecimal(1e20)), false)
104-
checkCompact(Decimal(17L), true)
105-
checkCompact(Decimal(17), true)
106-
checkCompact(Decimal(17L, 2, 1), true)
107-
checkCompact(Decimal(170L, 4, 2), true)
108-
checkCompact(Decimal(17L, 24, 1), true)
109-
checkCompact(Decimal(1e16.toLong), true)
110-
checkCompact(Decimal(1e17.toLong), true)
111-
checkCompact(Decimal(1e18.toLong - 1), true)
112-
checkCompact(Decimal(- 1e18.toLong + 1), true)
113-
checkCompact(Decimal(1e18.toLong - 1, 30, 10), true)
114-
checkCompact(Decimal(- 1e18.toLong + 1, 30, 10), true)
115-
checkCompact(Decimal(1e18.toLong), false)
116-
checkCompact(Decimal(-1e18.toLong), false)
117-
checkCompact(Decimal(1e18.toLong, 30, 10), false)
118-
checkCompact(Decimal(-1e18.toLong, 30, 10), false)
119-
checkCompact(Decimal(Long.MaxValue), false)
120-
checkCompact(Decimal(Long.MinValue), false)
88+
test("change precision and scale") {
89+
assert(true === Decimal(5).changePrecision(1, 0))
90+
assert(false === Decimal(15).changePrecision(1, 0))
91+
assert(true === Decimal(5).changePrecision(2, 1))
92+
assert(false === Decimal(5).changePrecision(2, 2))
93+
assert(true === Decimal(0).changePrecision(1, 0))
94+
assert(true === Decimal(BigDecimal("10.5")).changePrecision(3, 0))
95+
assert(true === Decimal(BigDecimal("10.5")).changePrecision(3, 1))
96+
assert(false === Decimal(BigDecimal("10.5")).changePrecision(3, 2))
97+
assert(true === Decimal(BigDecimal("10.5")).changePrecision(4, 0))
98+
assert(true === Decimal(BigDecimal("10.5")).changePrecision(4, 1))
99+
assert(true === Decimal(BigDecimal("10.5")).changePrecision(4, 2))
100+
assert(false === Decimal(BigDecimal("10.5")).changePrecision(4, 3))
121101
}
122102

123103
test("hash code") {
@@ -132,10 +112,6 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
132112
}
133113

134114
test("equals") {
135-
// The decimals on the left are stored compactly, while the ones on the right aren't
136-
checkCompact(Decimal(123), true)
137-
checkCompact(Decimal(BigDecimal(123)), false)
138-
checkCompact(Decimal("123"), false)
139115
assert(Decimal(123) === Decimal(BigDecimal(123)))
140116
assert(Decimal(123) === Decimal(BigDecimal("123.00")))
141117
assert(Decimal(-123) === Decimal(BigDecimal(-123)))
@@ -187,7 +163,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
187163
assert(b.toDouble === 0.125)
188164
}
189165

190-
test("set/setOrNull") {
166+
test("set") {
191167
assert(new Decimal().set(10L, 10, 0).toUnscaledLong === 10L)
192168
assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L)
193169
assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue)

0 commit comments

Comments
 (0)