
Commit 7dc6a34

Adds more in-memory table statistics and propagates them properly
1 parent: 7e63bb4

File tree

9 files changed: +164 −119 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala

Lines changed: 4 additions & 6 deletions

@@ -23,19 +23,17 @@ package org.apache.spark.sql.catalyst.expressions
  * of the name, or the expected nullability).
  */
 object AttributeMap {
-  def apply[A](kvs: Seq[(Attribute, A)]) =
-    new AttributeMap(kvs.map(kv => (kv._1.exprId, (kv._1, kv._2))).toMap)
+  def apply[A](kvs: Seq[(Attribute, A)]) = new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
 }
 
 class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
   extends Map[Attribute, A] with Serializable {
 
   override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)
 
-  override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] =
-    (baseMap.map(_._2) + kv).toMap
+  override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = baseMap.values.toMap + kv
 
-  override def iterator: Iterator[(Attribute, A)] = baseMap.map(_._2).iterator
+  override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator
 
-  override def -(key: Attribute): Map[Attribute, A] = (baseMap.map(_._2) - key).toMap
+  override def -(key: Attribute): Map[Attribute, A] = baseMap.values.toMap - key
 }
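Editor's note: `AttributeMap` resolves keys by `exprId` rather than by attribute equality, so two references to the same logical attribute that differ only cosmetically (name casing, nullability) hit the same entry. A minimal sketch of that behavior, assuming the `AttributeReference` constructor of this era of catalyst; the explicit `exprId` argument is passed here purely for illustration:

import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.types.IntegerType

// Two references to the same logical attribute: they share an ExprId but
// differ in name casing and nullability.
val id = AttributeReference("id", IntegerType, nullable = false)()
val idToo = AttributeReference("ID", IntegerType, nullable = true)(exprId = id.exprId)

val m = AttributeMap(Seq(id -> 42))

// Both lookups succeed because get() keys on exprId, not on the reference.
assert(m.get(id) == Some(42))
assert(m.get(idToo) == Some(42))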

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 15 additions & 16 deletions

@@ -26,25 +26,24 @@ import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.types.StructType
 import org.apache.spark.sql.catalyst.trees
 
+/**
+ * Estimates of various statistics. The default estimation logic simply lazily multiplies the
+ * corresponding statistic produced by the children. To override this behavior, override
+ * `statistics` and assign it an overridden version of `Statistics`.
+ *
+ * '''NOTE''': concrete and/or overridden versions of statistics fields should pay attention to the
+ * performance of the implementations. The reason is that estimations might get triggered in
+ * performance-critical processes, such as query plan planning.
+ *
+ * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it
+ *                    defaults to the product of children's `sizeInBytes`.
+ */
+case class Statistics(sizeInBytes: BigInt)
+
 abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
   self: Product =>
 
-  /**
-   * Estimates of various statistics. The default estimation logic simply lazily multiplies the
-   * corresponding statistic produced by the children. To override this behavior, override
-   * `statistics` and assign it an overridden version of `Statistics`.
-   *
-   * '''NOTE''': concrete and/or overridden versions of statistics fields should pay attention to the
-   * performance of the implementations. The reason is that estimations might get triggered in
-   * performance-critical processes, such as query plan planning.
-   *
-   * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it
-   *                    defaults to the product of children's `sizeInBytes`.
-   */
-  case class Statistics(
-    sizeInBytes: BigInt
-  )
-  lazy val statistics: Statistics = {
+  def statistics: Statistics = {
     if (children.size == 0) {
       throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
     }
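Editor's note: per the doc comment, the default estimate for a non-leaf node is the product of its children's `sizeInBytes`, and a leaf node has nothing to derive an estimate from, so it must override `statistics` itself. Turning `lazy val statistics` into `def statistics` presumably also frees subclasses to override it with a `val`, `lazy val`, or `def` of their own. A minimal sketch of the default logic, written against the names in this diff:

import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}

// The default described in the doc comment above: non-leaf nodes multiply
// their children's sizes; a leaf must override statistics or this throws.
def defaultStatistics(plan: LogicalPlan): Statistics = {
  if (plan.children.isEmpty) {
    throw new UnsupportedOperationException(
      s"LeafNode ${plan.nodeName} must implement statistics.")
  }
  Statistics(sizeInBytes = plan.children.map(_.statistics.sizeInBytes).product)
}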

sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala

Lines changed: 55 additions & 40 deletions

@@ -24,11 +24,13 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, Attri
 import org.apache.spark.sql.catalyst.types._
 
 private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable {
-  val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = false)()
-  val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = false)()
-  val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)()
+  val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)()
+  val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = true)()
+  val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)()
+  val count = AttributeReference(a.name + ".count", IntegerType, nullable = false)()
+  val sizeInBytes = AttributeReference(a.name + ".sizeInBytes", LongType, nullable = false)()
 
-  val schema = Seq(lowerBound, upperBound, nullCount)
+  val schema = Seq(lowerBound, upperBound, nullCount, count, sizeInBytes)
 }
 
 private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable {
@@ -45,6 +47,10 @@ private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Seri
  * brings significant performance penalty.
  */
 private[sql] sealed trait ColumnStats extends Serializable {
+  protected var count = 0
+  protected var nullCount = 0
+  protected var sizeInBytes = 0L
+
   /**
    * Gathers statistics information from `row(ordinal)`.
    */
@@ -65,9 +71,8 @@ private[sql] class NoopColumnStats extends ColumnStats {
 }
 
 private[sql] class ByteColumnStats extends ColumnStats {
-  var upper = Byte.MinValue
-  var lower = Byte.MaxValue
-  var nullCount = 0
+  protected var upper = Byte.MinValue
+  protected var lower = Byte.MaxValue
 
   override def gatherStats(row: Row, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
@@ -77,15 +82,16 @@ private[sql] class ByteColumnStats extends ColumnStats {
     } else {
       nullCount += 1
     }
+    count += 1
+    sizeInBytes += BYTE.defaultSize
   }
 
-  def collectedStatistics = Row(lower, upper, nullCount)
+  def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
 
 private[sql] class ShortColumnStats extends ColumnStats {
-  var upper = Short.MinValue
-  var lower = Short.MaxValue
-  var nullCount = 0
+  protected var upper = Short.MinValue
+  protected var lower = Short.MaxValue
 
   override def gatherStats(row: Row, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
@@ -95,15 +101,16 @@ private[sql] class ShortColumnStats extends ColumnStats {
     } else {
       nullCount += 1
     }
+    count += 1
+    sizeInBytes += SHORT.defaultSize
   }
 
-  def collectedStatistics = Row(lower, upper, nullCount)
+  def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
 
 private[sql] class LongColumnStats extends ColumnStats {
-  var upper = Long.MinValue
-  var lower = Long.MaxValue
-  var nullCount = 0
+  protected var upper = Long.MinValue
+  protected var lower = Long.MaxValue
 
   override def gatherStats(row: Row, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
@@ -113,15 +120,16 @@ private[sql] class LongColumnStats extends ColumnStats {
     } else {
       nullCount += 1
     }
+    count += 1
+    sizeInBytes += LONG.defaultSize
   }
 
-  def collectedStatistics = Row(lower, upper, nullCount)
+  def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
 
 private[sql] class DoubleColumnStats extends ColumnStats {
-  var upper = Double.MinValue
-  var lower = Double.MaxValue
-  var nullCount = 0
+  protected var upper = Double.MinValue
+  protected var lower = Double.MaxValue
 
   override def gatherStats(row: Row, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
@@ -131,15 +139,16 @@ private[sql] class DoubleColumnStats extends ColumnStats {
     } else {
       nullCount += 1
     }
+    count += 1
+    sizeInBytes += DOUBLE.defaultSize
   }
 
-  def collectedStatistics = Row(lower, upper, nullCount)
+  def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
 
 private[sql] class FloatColumnStats extends ColumnStats {
-  var upper = Float.MinValue
-  var lower = Float.MaxValue
-  var nullCount = 0
+  protected var upper = Float.MinValue
+  protected var lower = Float.MaxValue
 
   override def gatherStats(row: Row, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
@@ -149,15 +158,16 @@ private[sql] class FloatColumnStats extends ColumnStats {
     } else {
       nullCount += 1
     }
+    count += 1
+    sizeInBytes += FLOAT.defaultSize
   }
 
-  def collectedStatistics = Row(lower, upper, nullCount)
+  def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
 
 private[sql] class IntColumnStats extends ColumnStats {
-  var upper = Int.MinValue
-  var lower = Int.MaxValue
-  var nullCount = 0
+  protected var upper = Int.MinValue
+  protected var lower = Int.MaxValue
 
   override def gatherStats(row: Row, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
@@ -167,15 +177,16 @@ private[sql] class IntColumnStats extends ColumnStats {
     } else {
       nullCount += 1
     }
+    count += 1
+    sizeInBytes += INT.defaultSize
   }
 
-  def collectedStatistics = Row(lower, upper, nullCount)
+  def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
 
 private[sql] class StringColumnStats extends ColumnStats {
-  var upper: String = null
-  var lower: String = null
-  var nullCount = 0
+  protected var upper: String = null
+  protected var lower: String = null
 
   override def gatherStats(row: Row, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
@@ -185,15 +196,16 @@ private[sql] class StringColumnStats extends ColumnStats {
     } else {
       nullCount += 1
     }
+    count += 1
+    sizeInBytes += STRING.actualSize(row, ordinal)
   }
 
-  def collectedStatistics = Row(lower, upper, nullCount)
+  def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
 
 private[sql] class DateColumnStats extends ColumnStats {
-  var upper: Date = null
-  var lower: Date = null
-  var nullCount = 0
+  protected var upper: Date = null
+  protected var lower: Date = null
 
   override def gatherStats(row: Row, ordinal: Int) {
     if (!row.isNullAt(ordinal)) {
@@ -203,15 +215,16 @@ private[sql] class DateColumnStats extends ColumnStats {
     } else {
       nullCount += 1
     }
+    count += 1
+    sizeInBytes += DATE.defaultSize
   }
 
-  def collectedStatistics = Row(lower, upper, nullCount)
+  def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
 
 private[sql] class TimestampColumnStats extends ColumnStats {
-  var upper: Timestamp = null
-  var lower: Timestamp = null
-  var nullCount = 0
+  protected var upper: Timestamp = null
+  protected var lower: Timestamp = null
 
   override def gatherStats(row: Row, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
@@ -221,7 +234,9 @@ private[sql] class TimestampColumnStats extends ColumnStats {
     } else {
       nullCount += 1
     }
+    count += 1
+    sizeInBytes += TIMESTAMP.defaultSize
   }
 
-  def collectedStatistics = Row(lower, upper, nullCount)
+  def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
 }
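Editor's note: every concrete stats class now follows the same shape — type-specific bound tracking plus the `count`, `nullCount`, and `sizeInBytes` counters hoisted into the `ColumnStats` trait, with `count` and `sizeInBytes` bumped for null and non-null rows alike. A standalone sketch of that accumulation pattern in plain Scala, with no catalyst dependencies; `IntStats` is a hypothetical stand-in for `IntColumnStats`:

// Hypothetical stand-in mirroring IntColumnStats without the catalyst Row API.
class IntStats {
  private var upper = Int.MinValue
  private var lower = Int.MaxValue
  private var count = 0
  private var nullCount = 0
  private var sizeInBytes = 0L

  def gatherStats(value: Option[Int]): Unit = {
    value match {
      case Some(v) =>
        if (v > upper) upper = v
        if (v < lower) lower = v
      case None =>
        nullCount += 1
    }
    // As in the diff: every row bumps count and the size estimate,
    // null or not (the real code adds INT.defaultSize here).
    count += 1
    sizeInBytes += 4
  }

  def collectedStatistics: (Int, Int, Int, Int, Long) =
    (lower, upper, nullCount, count, sizeInBytes)
}

// Usage: three rows, one of them null.
val s = new IntStats
Seq(Some(3), None, Some(7)).foreach(s.gatherStats)
assert(s.collectedStatistics == (3, 7, 1, 3, 12L))

Presumably these per-column `sizeInBytes` values are what the rest of this commit rolls up into the table-level `Statistics(sizeInBytes)` that the planner consumes, which is the "propagates them properly" part of the commit message.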
