@@ -20,6 +20,7 @@ package org.apache.spark.sql.types
2020import java .math .{MathContext , RoundingMode , BigDecimal => JavaBigDecimal }
2121
2222import org .apache .spark .annotation .DeveloperApi
23+ import org .apache .spark .unsafe .PlatformDependent
2324
2425/**
2526 * A mutable implementation of BigDecimal that can hold a Long if values are small enough.
@@ -32,74 +33,38 @@ import org.apache.spark.annotation.DeveloperApi
3233final class Decimal extends Ordered [Decimal ] with Serializable {
3334 import org .apache .spark .sql .types .Decimal ._
3435
35- private var decimalVal : JavaBigDecimal = null
36- private var longVal : Long = 0L
36+ private var decimalVal : JavaBigDecimal = BIG_DEC_ZERO
3737 private var _precision : Int = 1
38- private var _scale : Int = 0
3938
4039 def precision : Int = _precision
41- def scale : Int = _scale
40+ def scale : Int = decimalVal.scale()
4241
4342 /**
4443 * Set this Decimal to the given Long. Will have precision 20 and scale 0.
4544 */
4645 def set (longVal : Long ): Decimal = {
47- if (longVal <= - POW_10 (MAX_LONG_DIGITS ) || longVal >= POW_10 (MAX_LONG_DIGITS )) {
48- // We can't represent this compactly as a long without risking overflow
49- this .decimalVal = new JavaBigDecimal (longVal)
50- this .longVal = 0L
51- } else {
52- this .decimalVal = null
53- this .longVal = longVal
54- }
55- this ._precision = 20
56- this ._scale = 0
46+ decimalVal = JavaBigDecimal .valueOf(longVal)
47+ _precision = 20
5748 this
5849 }
5950
6051 /**
6152 * Set this Decimal to the given Int. Will have precision 10 and scale 0.
6253 */
6354 def set (intVal : Int ): Decimal = {
64- this .decimalVal = null
65- this .longVal = intVal
66- this ._precision = 10
67- this ._scale = 0
55+ decimalVal = JavaBigDecimal .valueOf(intVal)
56+ _precision = 10
6857 this
6958 }
7059
7160 /**
7261 * Set this Decimal to the given unscaled Long, with a given precision and scale.
62+ *
63+ * Note: this is used in serialization, caller will make sure that it will not overflow
7364 */
7465 def set (unscaled : Long , precision : Int , scale : Int ): Decimal = {
75- if (setOrNull(unscaled, precision, scale) == null ) {
76- throw new IllegalArgumentException (" Unscaled value too large for precision" )
77- }
78- this
79- }
80-
81- /**
82- * Set this Decimal to the given unscaled Long, with a given precision and scale,
83- * and return it, or return null if it cannot be set due to overflow.
84- */
85- def setOrNull (unscaled : Long , precision : Int , scale : Int ): Decimal = {
86- if (unscaled <= - POW_10 (MAX_LONG_DIGITS ) || unscaled >= POW_10 (MAX_LONG_DIGITS )) {
87- // We can't represent this compactly as a long without risking overflow
88- if (precision < 19 ) {
89- return null // Requested precision is too low to represent this value
90- }
91- this .decimalVal = new JavaBigDecimal (unscaled)
92- this .longVal = 0L
93- } else {
94- val p = POW_10 (math.min(precision, MAX_LONG_DIGITS ))
95- if (unscaled <= - p || unscaled >= p) {
96- return null // Requested precision is too low to represent this value
97- }
98- this .decimalVal = null
99- this .longVal = unscaled
100- }
101- this ._precision = precision
102- this ._scale = scale
66+ decimalVal = JavaBigDecimal .valueOf(unscaled, scale)
67+ _precision = precision
10368 this
10469 }
10570
@@ -121,11 +86,9 @@ final class Decimal extends Ordered[Decimal] with Serializable {
12186 * Set this Decimal to the given java.math.BigDecimal value, with a given precision and scale.
12287 */
12388 private [sql] def set (decimal : JavaBigDecimal , precision : Int , scale : Int ): Decimal = {
124- this . decimalVal = decimal.setScale(scale, ROUNDING_MODE )
89+ decimalVal = decimal.setScale(scale, ROUNDING_MODE )
12590 require(decimalVal.precision <= precision, " Overflowed precision" )
126- this .longVal = 0L
127- this ._precision = precision
128- this ._scale = scale
91+ _precision = precision
12992 this
13093 }
13194
@@ -134,9 +97,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
13497 */
13598 private [sql] def set (decimal : JavaBigDecimal ): Decimal = {
13699 this .decimalVal = decimal
137- this .longVal = 0L
138100 this ._precision = decimal.precision
139- this ._scale = decimal.scale
140101 this
141102 }
142103
@@ -145,52 +106,36 @@ final class Decimal extends Ordered[Decimal] with Serializable {
145106 */
146107 def set (decimal : Decimal ): Decimal = {
147108 this .decimalVal = decimal.decimalVal
148- this .longVal = decimal.longVal
149109 this ._precision = decimal._precision
150- this ._scale = decimal._scale
151110 this
152111 }
153112
154113 def toBigDecimal : BigDecimal = BigDecimal (toJavaBigDecimal)
155114
156- private [sql] def toJavaBigDecimal : JavaBigDecimal = {
157- if (decimalVal.ne(null )) {
158- decimalVal
159- } else {
160- JavaBigDecimal .valueOf(longVal, _scale)
161- }
162- }
115+ private [sql] def toJavaBigDecimal : JavaBigDecimal = decimalVal
163116
164117 def toUnscaledLong : Long = {
165- if (decimalVal.ne(null )) {
166- decimalVal.unscaledValue().longValue()
118+ val unscaled = PlatformDependent .UNSAFE .getLong(decimalVal,
119+ PlatformDependent .BIG_DECIMAL_INTCOMPACT_OFFSET )
120+ if (unscaled != Long .MinValue ) {
121+ unscaled
167122 } else {
168- longVal
123+ decimalVal.unscaledValue().longValue()
169124 }
170125 }
171126
172- override def toString : String = toJavaBigDecimal .toString()
127+ override def toString : String = decimalVal .toString()
173128
174129 @ DeveloperApi
175130 def toDebugString : String = {
176- if (decimalVal.ne(null )) {
177- s " Decimal(expanded, $decimalVal, $precision, $scale}) "
178- } else {
179- s " Decimal(compact, $longVal, $precision, $scale}) "
180- }
131+ s " Decimal( $decimalVal, ${_precision}) "
181132 }
182133
183134 def toDouble : Double = toJavaBigDecimal.doubleValue()
184135
185136 def toFloat : Float = toJavaBigDecimal.floatValue()
186137
187- def toLong : Long = {
188- if (decimalVal.eq(null )) {
189- longVal / POW_10 (_scale)
190- } else {
191- decimalVal.longValue()
192- }
193- }
138+ def toLong : Long = decimalVal.longValue()
194139
195140 def toInt : Int = toLong.toInt
196141
@@ -205,65 +150,23 @@ final class Decimal extends Ordered[Decimal] with Serializable {
205150 */
206151 def changePrecision (precision : Int , scale : Int ): Boolean = {
207152 // fast path for UnsafeProjection
208- if (precision == this .precision && scale == this .scale) {
153+ if (precision == _precision && scale == decimalVal .scale() ) {
209154 return true
210155 }
211- // First, update our longVal if we can, or transfer over to using a BigDecimal
212- if (decimalVal.eq(null )) {
213- if (scale < _scale) {
214- // Easier case: we just need to divide our scale down
215- val diff = _scale - scale
216- val droppedDigits = longVal % POW_10 (diff)
217- longVal /= POW_10 (diff)
218- if (math.abs(droppedDigits) * 2 >= POW_10 (diff)) {
219- longVal += (if (longVal < 0 ) - 1L else 1L )
220- }
221- } else if (scale > _scale) {
222- // We might be able to multiply longVal by a power of 10 and not overflow, but if not,
223- // switch to using a BigDecimal
224- val diff = scale - _scale
225- val p = POW_10 (math.max(MAX_LONG_DIGITS - diff, 0 ))
226- if (diff <= MAX_LONG_DIGITS && longVal > - p && longVal < p) {
227- // Multiplying longVal by POW_10(diff) will still keep it below MAX_LONG_DIGITS
228- longVal *= POW_10 (diff)
229- } else {
230- // Give up on using Longs; switch to BigDecimal, which we'll modify below
231- decimalVal = JavaBigDecimal .valueOf(longVal, _scale)
232- }
233- }
234- // In both cases, we will check whether our precision is okay below
235- }
236156
237- if (decimalVal.ne(null )) {
238- // We get here if either we started with a BigDecimal, or we switched to one because we would
239- // have overflowed our Long; in either case we must rescale decimalVal to the new scale.
240- val newVal = decimalVal.setScale(scale, ROUNDING_MODE )
241- if (newVal.precision > precision) {
242- return false
243- }
244- decimalVal = newVal
245- } else {
246- // We're still using Longs, but we should check whether we match the new precision
247- val p = POW_10 (math.min(precision, MAX_LONG_DIGITS ))
248- if (longVal <= - p || longVal >= p) {
249- // Note that we shouldn't have been able to fix this by switching to BigDecimal
250- return false
251- }
157+ val newVal = decimalVal.setScale(scale, ROUNDING_MODE )
158+ if (newVal.precision > precision) {
159+ return false
252160 }
253-
161+ decimalVal = newVal
254162 _precision = precision
255- _scale = scale
256163 true
257164 }
258165
259166 override def clone (): Decimal = new Decimal ().set(this )
260167
261168 override def compare (other : Decimal ): Int = {
262- if (decimalVal.eq(null ) && other.decimalVal.eq(null ) && _scale == other._scale) {
263- if (longVal < other.longVal) - 1 else if (longVal == other.longVal) 0 else 1
264- } else {
265- toJavaBigDecimal.compareTo(other.toJavaBigDecimal)
266- }
169+ toJavaBigDecimal.compareTo(other.toJavaBigDecimal)
267170 }
268171
269172 override def equals (other : Any ): Boolean = other match {
@@ -276,24 +179,15 @@ final class Decimal extends Ordered[Decimal] with Serializable {
276179 override def hashCode (): Int = toBigDecimal.hashCode()
277180
278181 def isZero : Boolean = {
279- if (decimalVal.ne(null )) decimalVal.compareTo(BIG_DEC_ZERO ) == 0
280- else longVal == 0
182+ decimalVal.compareTo(BIG_DEC_ZERO ) == 0
281183 }
282184
283185 def + (that : Decimal ): Decimal = {
284- if (decimalVal.eq(null ) && that.decimalVal.eq(null ) && scale == that.scale) {
285- Decimal (longVal + that.longVal, Math .max(precision, that.precision), scale)
286- } else {
287- Decimal (toJavaBigDecimal.add(that.toJavaBigDecimal, MATH_CONTEXT ), precision, scale)
288- }
186+ Decimal (toJavaBigDecimal.add(that.toJavaBigDecimal, MATH_CONTEXT ), precision, scale)
289187 }
290188
291189 def - (that : Decimal ): Decimal = {
292- if (decimalVal.eq(null ) && that.decimalVal.eq(null ) && scale == that.scale) {
293- Decimal (longVal - that.longVal, Math .max(precision, that.precision), scale)
294- } else {
295- Decimal (toJavaBigDecimal.subtract(that.toJavaBigDecimal, MATH_CONTEXT ), precision, scale)
296- }
190+ Decimal (toJavaBigDecimal.subtract(that.toJavaBigDecimal, MATH_CONTEXT ), precision, scale)
297191 }
298192
299193 // HiveTypeCoercion will take care of the precision, scale of result
@@ -313,11 +207,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
313207 def remainder (that : Decimal ): Decimal = this % that
314208
315209 def unary_- : Decimal = {
316- if (decimalVal.ne(null )) {
317- Decimal (decimalVal.negate(), precision, scale)
318- } else {
319- Decimal (- longVal, precision, scale)
320- }
210+ Decimal (decimalVal.negate(), precision, scale)
321211 }
322212
323213 def abs : Decimal = if (this .compare(Decimal .ZERO ) < 0 ) this .unary_- else this
0 commit comments