
Commit 0b4fc14

lianchengyhuai authored and committed
[SC-5403] Backport ObjectHashAggregate operator to branch-2.1
## What changes were proposed in this pull request?

This PR backports the following four open source commits to our branch-2.1:

- apache@2e80990
- apache@205e6d5
- apache@e0deee1
- apache@2e80990

## How was this patch tested?

(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests)

Author: Cheng Lian <lian@databricks.com>
Author: Yin Huai <yhuai@databricks.com>

Closes apache#143 from yhuai/branch-2.1-with-object-hash-1.
1 parent 5b5159e commit 0b4fc14

File tree

10 files changed, +1552 -11 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala

Lines changed: 23 additions & 8 deletions
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec}
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -66,14 +67,28 @@ object AggUtils {
         resultExpressions = resultExpressions,
         child = child)
     } else {
-      SortAggregateExec(
-        requiredChildDistributionExpressions = requiredChildDistributionExpressions,
-        groupingExpressions = groupingExpressions,
-        aggregateExpressions = aggregateExpressions,
-        aggregateAttributes = aggregateAttributes,
-        initialInputBufferOffset = initialInputBufferOffset,
-        resultExpressions = resultExpressions,
-        child = child)
+      val objectHashEnabled = child.sqlContext.conf.useObjectHashAggregation
+      val useObjectHash = ObjectHashAggregateExec.supportsAggregate(aggregateExpressions)
+
+      if (objectHashEnabled && useObjectHash) {
+        ObjectHashAggregateExec(
+          requiredChildDistributionExpressions = requiredChildDistributionExpressions,
+          groupingExpressions = groupingExpressions,
+          aggregateExpressions = aggregateExpressions,
+          aggregateAttributes = aggregateAttributes,
+          initialInputBufferOffset = initialInputBufferOffset,
+          resultExpressions = resultExpressions,
+          child = child)
+      } else {
+        SortAggregateExec(
+          requiredChildDistributionExpressions = requiredChildDistributionExpressions,
+          groupingExpressions = groupingExpressions,
+          aggregateExpressions = aggregateExpressions,
+          aggregateAttributes = aggregateAttributes,
+          initialInputBufferOffset = initialInputBufferOffset,
+          resultExpressions = resultExpressions,
+          child = child)
+      }
     }
   }
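
The new branch is taken only when both conditions hold: the session configuration enables object hash aggregation, and ObjectHashAggregateExec.supportsAggregate accepts the aggregate expressions (upstream this means at least one function is a TypedImperativeAggregate, e.g. percentile_approx). A minimal spark-shell sketch of checking which operator gets planned, assuming the upstream configuration key spark.sql.execution.useObjectHashAggregateExec behind SQLConf.useObjectHashAggregation (verify against SQLConf in this branch):

import org.apache.spark.sql.functions.expr

// Assumed conf key, taken from the upstream commits; check SQLConf in this branch.
spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", "true")

val df = spark.range(0, 1000).selectExpr("id % 10 AS k", "CAST(id AS DOUBLE) AS v")

// percentile_approx is backed by a TypedImperativeAggregate upstream, so the
// planner should choose ObjectHashAggregate here instead of SortAggregate.
df.groupBy("k").agg(expr("percentile_approx(v, 0.5)")).explain()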

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala

Lines changed: 330 additions & 0 deletions

@@ -0,0 +1,330 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.aggregate

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.{BaseOrdering, GenerateOrdering}
import org.apache.spark.sql.execution.UnsafeKVExternalSorter
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.KVIterator
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter

class ObjectAggregationIterator(
    outputAttributes: Seq[Attribute],
    groupingExpressions: Seq[NamedExpression],
    aggregateExpressions: Seq[AggregateExpression],
    aggregateAttributes: Seq[Attribute],
    initialInputBufferOffset: Int,
    resultExpressions: Seq[NamedExpression],
    newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
    originalInputAttributes: Seq[Attribute],
    inputRows: Iterator[InternalRow],
    fallbackCountThreshold: Int)
  extends AggregationIterator(
    groupingExpressions,
    originalInputAttributes,
    aggregateExpressions,
    aggregateAttributes,
    initialInputBufferOffset,
    resultExpressions,
    newMutableProjection) with Logging {

  // Indicates whether we have fallen back to sort-based aggregation or not.
  private[this] var sortBased: Boolean = false

  private[this] var aggBufferIterator: Iterator[AggregationBufferEntry] = _

  // Hacking the aggregation mode to call AggregateFunction.merge to merge two aggregation buffers
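  // (Partial is rewritten to PartialMerge, and Complete to Final), so the function
  // returned by generateProcessRow treats its second argument as another aggregation
  // buffer and merges it into the first, rather than treating it as an input row.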
  private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit = {
    val newExpressions = aggregateExpressions.map {
      case agg @ AggregateExpression(_, Partial, _, _) =>
        agg.copy(mode = PartialMerge)
      case agg @ AggregateExpression(_, Complete, _, _) =>
        agg.copy(mode = Final)
      case other => other
    }
    val newFunctions = initializeAggregateFunctions(newExpressions, 0)
    val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes)
    generateProcessRow(newExpressions, newFunctions, newInputAttributes)
  }

  // A safe projection used to deep-clone input rows, to prevent false sharing.
  private[this] val safeProjection: Projection =
    FromUnsafeProjection(outputAttributes.map(_.dataType))

  /**
   * Start processing input rows. Note that this happens eagerly at construction time:
   * all input rows are consumed (and, if necessary, spilled) before the first call to
   * `next()`.
   */
  processInputs()

  override final def hasNext: Boolean = {
    aggBufferIterator.hasNext
  }

  override final def next(): UnsafeRow = {
    val entry = aggBufferIterator.next()
    generateOutput(entry.groupingKey, entry.aggregationBuffer)
  }

  /**
   * Generate an output row when there is no input and there is no grouping expression.
   */
  def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
    if (groupingExpressions.isEmpty) {
      val defaultAggregationBuffer = createNewAggregationBuffer()
      generateOutput(UnsafeRow.createFromByteArray(0, 0), defaultAggregationBuffer)
    } else {
      throw new IllegalStateException(
        "This method should not be called when groupingExpressions is not empty.")
    }
  }

  // Creates a new aggregation buffer and initializes buffer values. This function should only be
  // called in two cases:
  //
  // - when creating the aggregation buffer for a new group in the hash map, and
  // - when creating the re-used buffer for sort-based aggregation
  private def createNewAggregationBuffer(): SpecificInternalRow = {
    val bufferFieldTypes = aggregateFunctions.flatMap(_.aggBufferAttributes.map(_.dataType))
    val buffer = new SpecificInternalRow(bufferFieldTypes)
    initAggregationBuffer(buffer)
    buffer
  }

  private def initAggregationBuffer(buffer: SpecificInternalRow): Unit = {
    // Initializes declarative aggregates' buffer values
    expressionAggInitialProjection.target(buffer)(EmptyRow)
    // Initializes imperative aggregates' buffer values
    aggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer))
  }
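
  // Looks up the aggregation buffer for `groupingKey` in `hashMap`, creating and
  // registering a fresh buffer (keyed by a copy of the grouping key) the first time
  // a group is seen.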
  private def getAggregationBufferByKey(
      hashMap: ObjectAggregationMap, groupingKey: UnsafeRow): InternalRow = {
    var aggBuffer = hashMap.getAggregationBuffer(groupingKey)

    if (aggBuffer == null) {
      aggBuffer = createNewAggregationBuffer()
      hashMap.putAggregationBuffer(groupingKey.copy(), aggBuffer)
    }

    aggBuffer
  }

  // This function is used to read and process input rows. It first uses hash-based
  // aggregation, putting groups and their buffers in `hashMap`. If `hashMap` grows too
  // large, its contents are dumped into an external sorter, and the remaining input rows
  // are aggregated sort-based, merging the sorted buffers with the sorted input.
  private def processInputs(): Unit = {
    // In-memory map to store aggregation buffers for hash-based aggregation.
    val hashMap = new ObjectAggregationMap()

    // If the in-memory map cannot hold all aggregation buffers, fall back to sort-based
    // aggregation backed by sorted physical storage.
    var sortBasedAggregationStore: SortBasedAggregator = null

    if (groupingExpressions.isEmpty) {
      // If there are no grouping expressions, we can just reuse the same buffer over and over again.
      val groupingKey = groupingProjection.apply(null)
      val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey)
      while (inputRows.hasNext) {
        val newInput = safeProjection(inputRows.next())
        processRow(buffer, newInput)
      }
    } else {
      while (inputRows.hasNext && !sortBased) {
        val newInput = safeProjection(inputRows.next())
        val groupingKey = groupingProjection.apply(newInput)
        val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey)
        processRow(buffer, newInput)

        // The hash map has grown too large: its contents will be dumped to an external
        // sorter below, and all remaining input goes through the sort-based path.
        if (hashMap.size >= fallbackCountThreshold) {
          logInfo(
            s"Aggregation hash map reaches threshold " +
              s"capacity ($fallbackCountThreshold entries), spilling and falling back to sort" +
              s" based aggregation. You may change the threshold by adjusting the option " +
              SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key
          )

          // Falls back to sort-based aggregation
          sortBased = true
        }
      }

      if (sortBased) {
        val sortIteratorFromHashMap = hashMap
          .dumpToExternalSorter(groupingAttributes, aggregateFunctions)
          .sortedIterator()
        sortBasedAggregationStore = new SortBasedAggregator(
          sortIteratorFromHashMap,
          StructType.fromAttributes(originalInputAttributes),
          StructType.fromAttributes(groupingAttributes),
          processRow,
          mergeAggregationBuffers,
          createNewAggregationBuffer())

        while (inputRows.hasNext) {
          // NOTE: The input row is always an UnsafeRow
          val unsafeInputRow = inputRows.next().asInstanceOf[UnsafeRow]
          val groupingKey = groupingProjection.apply(unsafeInputRow)
          sortBasedAggregationStore.addInput(groupingKey, unsafeInputRow)
        }
      }
    }

    if (sortBased) {
      aggBufferIterator = sortBasedAggregationStore.destructiveIterator()
    } else {
      aggBufferIterator = hashMap.iterator
    }
  }
}

/**
 * A class used to handle sort-based aggregation, used together with [[ObjectHashAggregateExec]].
 *
 * @param initialAggBufferIterator iterator that points to sorted input aggregation buffers
 * @param inputSchema The schema of input row
 * @param groupingSchema The schema of grouping key
 * @param processRow Function to update the aggregation buffer with input rows
 * @param mergeAggregationBuffers Function used to merge the input aggregation buffers into existing
 *                                aggregation buffers
 * @param makeEmptyAggregationBuffer Creates an empty aggregation buffer
 *
 * @todo Try to eliminate this class by refactoring and reusing code paths in [[SortAggregateExec]].
 */
class SortBasedAggregator(
    initialAggBufferIterator: KVIterator[UnsafeRow, UnsafeRow],
    inputSchema: StructType,
    groupingSchema: StructType,
    processRow: (InternalRow, InternalRow) => Unit,
    mergeAggregationBuffers: (InternalRow, InternalRow) => Unit,
    makeEmptyAggregationBuffer: => InternalRow) {

  // External sorter used to sort the input (grouping key + input row) by the grouping key.
  private val inputSorter = createExternalSorterForInput()
  private val groupingKeyOrdering: BaseOrdering = GenerateOrdering.create(groupingSchema)

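  // Aggregation is deferred: addInput() only buffers (grouping key, input row) pairs
  // into the external sorter; the actual merging happens in destructiveIterator().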
  def addInput(groupingKey: UnsafeRow, inputRow: UnsafeRow): Unit = {
    inputSorter.insertKV(groupingKey, inputRow)
  }

  /**
   * Returns a destructive iterator of AggregationBufferEntry.
   * Notice: it is illegal to call any method after `destructiveIterator()` has been called.
   */
  def destructiveIterator(): Iterator[AggregationBufferEntry] = {
    new Iterator[AggregationBufferEntry] {
      val inputIterator = inputSorter.sortedIterator()
      var hasNextInput: Boolean = inputIterator.next()
      var hasNextAggBuffer: Boolean = initialAggBufferIterator.next()
      private var result: AggregationBufferEntry = _
      private var groupingKey: UnsafeRow = _

      override def hasNext(): Boolean = {
        result != null || findNextSortedGroup()
      }

      override def next(): AggregationBufferEntry = {
        // hasNext() populates `result`; hand it off and clear it for the next group.
        val returnResult = result
        result = null
        returnResult
      }

      // Two-way merges initialAggBufferIterator and inputIterator
      private def findNextSortedGroup(): Boolean = {
        if (hasNextInput || hasNextAggBuffer) {
          // Find the smaller key of inputIterator and initialAggBufferIterator
          groupingKey = findGroupingKey()
          result = new AggregationBufferEntry(groupingKey, makeEmptyAggregationBuffer)

          // Firstly, update the aggregation buffer with input rows.
          while (hasNextInput &&
            groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0) {
            // Since `inputIterator.getValue` is an `UnsafeRow` whose underlying buffer will be
            // overwritten when `inputIterator` steps forward, we need to do a deep copy here.
            processRow(result.aggregationBuffer, inputIterator.getValue.copy())
            hasNextInput = inputIterator.next()
          }

          // Secondly, merge the aggregation buffer with existing aggregation buffers.
          // NOTE: the ordering of these two while-blocks matters; mergeAggregationBuffers()
          // should only be called after processRow.
          while (hasNextAggBuffer &&
            groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0) {
            mergeAggregationBuffers(
              result.aggregationBuffer,
              // Since `initialAggBufferIterator.getValue` is an `UnsafeRow` whose underlying
              // buffer will be overwritten when `initialAggBufferIterator` steps forward, we
              // need to do a deep copy here.
              initialAggBufferIterator.getValue.copy()
            )
            hasNextAggBuffer = initialAggBufferIterator.next()
          }

          true
        } else {
          false
        }
      }

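      // Picks the smaller of the two iterators' current grouping keys (either side may
      // already be exhausted) and copies it into the reusable `groupingKey` row.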
      private def findGroupingKey(): UnsafeRow = {
        var newGroupingKey: UnsafeRow = null
        if (!hasNextInput) {
          newGroupingKey = initialAggBufferIterator.getKey
        } else if (!hasNextAggBuffer) {
          newGroupingKey = inputIterator.getKey
        } else {
          val compareResult =
            groupingKeyOrdering.compare(inputIterator.getKey, initialAggBufferIterator.getKey)
          if (compareResult <= 0) {
            newGroupingKey = inputIterator.getKey
          } else {
            newGroupingKey = initialAggBufferIterator.getKey
          }
        }

        if (groupingKey == null) {
          groupingKey = newGroupingKey.copy()
        } else {
          groupingKey.copyFrom(newGroupingKey)
        }
        groupingKey
      }
    }
  }

  private def createExternalSorterForInput(): UnsafeKVExternalSorter = {
    new UnsafeKVExternalSorter(
      groupingSchema,
      inputSchema,
      SparkEnv.get.blockManager,
      SparkEnv.get.serializerManager,
      TaskContext.get().taskMemoryManager().pageSizeBytes,
      SparkEnv.get.conf.getLong(
        "spark.shuffle.spill.numElementsForceSpillThreshold",
        UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
      null
    )
  }
}
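
For reference, a minimal sketch of exercising the sort-based fallback path above from a spark-shell session. Both configuration keys are assumptions taken from the upstream commits (spark.sql.execution.useObjectHashAggregateExec, and spark.sql.objectHashAggregate.sortBased.fallbackThreshold behind SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD referenced in processInputs); verify them against SQLConf in this branch:

import org.apache.spark.sql.functions.expr

// Assumed conf keys; force the fallback by setting the threshold below the
// number of distinct groups.
spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", "true")
spark.conf.set("spark.sql.objectHashAggregate.sortBased.fallbackThreshold", "16")

// 100 distinct groups exceed the 16-entry threshold, so processInputs() should
// log "Aggregation hash map reaches threshold capacity (16 entries), ..." and
// finish the query through SortBasedAggregator.
spark.range(0, 100000)
  .selectExpr("id % 100 AS k", "CAST(id AS DOUBLE) AS v")
  .groupBy("k")
  .agg(expr("percentile_approx(v, 0.5)"))
  .collect()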
