Commit 9cc98f5

Move more code to Java; fix bugs in UnsafeRowConverter length type.

The length type is an int, not a long, but the code was inconsistent about this. I also now use byte arrays instead of long arrays in some places in order to avoid off-by-a-factor-of-8 errors.
1 parent: c8792de
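The off-by-a-factor-of-8 hazard mentioned above comes from computing sizes in bytes while storing rows in a `long[]`, so every capacity check and fill bound has to be divided by 8. A minimal sketch of the two conventions, with illustrative buffer names (not code from this commit):

```java
import java.util.Arrays;

// Illustrative only: why byte[] buffers avoid the "/ 8" bookkeeping that long[] buffers need.
public final class LengthUnitsSketch {
  public static void main(String[] args) {
    final int sizeInBytes = 40;  // row sizes are always computed in bytes

    // long[]-backed buffer: every bound must be converted from bytes to 8-byte words.
    long[] longBuffer = new long[1024];
    if (sizeInBytes / 8 > longBuffer.length) {
      longBuffer = new long[sizeInBytes / 8];  // forgetting a "/ 8" here allocates 8x too much
    }
    Arrays.fill(longBuffer, 0, sizeInBytes / 8, 0L);

    // byte[]-backed buffer: sizes and indices share one unit, so there is nothing to forget.
    byte[] byteBuffer = new byte[1024];
    if (sizeInBytes > byteBuffer.length) {
      byteBuffer = new byte[sizeInBytes];
    }
    Arrays.fill(byteBuffer, 0, sizeInBytes, (byte) 0);
    System.out.println("zeroed " + sizeInBytes + " bytes in both buffers");
  }
}
```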

4 files changed: +186 -84 lines changed

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java

Lines changed: 4 additions & 4 deletions
```diff
@@ -151,7 +151,7 @@ public long freeMemory() {
     return memoryFreed;
   }
 
-  private void ensureSpaceInDataPage(int requiredSpace) throws Exception {
+  private void ensureSpaceInDataPage(int requiredSpace) throws IOException {
     // TODO: merge these steps to first calculate total memory requirements for this insert,
     // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the
     // data page.
@@ -176,7 +176,7 @@ private void ensureSpaceInDataPage(int requiredSpace) throws Exception {
     }
     if (requiredSpace > PAGE_SIZE) {
       // TODO: throw a more specific exception?
-      throw new Exception("Required space " + requiredSpace + " is greater than page size (" +
+      throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
         PAGE_SIZE + ")");
     } else if (requiredSpace > spaceInCurrentPage) {
       if (spillingEnabled) {
@@ -187,7 +187,7 @@ private void ensureSpaceInDataPage(int requiredSpace) throws Exception {
         final long memoryAcquiredAfterSpill = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
         if (memoryAcquiredAfterSpill != PAGE_SIZE) {
           shuffleMemoryManager.release(memoryAcquiredAfterSpill);
-          throw new Exception("Can't allocate memory!");
+          throw new IOException("Can't allocate memory!");
         }
       }
     }
@@ -202,7 +202,7 @@ public void insertRecord(
       Object recordBaseObject,
       long recordBaseOffset,
       int lengthInBytes,
-      long prefix) throws Exception {
+      long prefix) throws IOException {
     // Need 4 bytes to store the record length.
     ensureSpaceInDataPage(lengthInBytes + 4);
 
```
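A side note on the `ensureSpaceInDataPage(lengthInBytes + 4)` call above: each record is stored in the data page as a 4-byte int length header followed by the payload, which is also why the record length fixed by this commit is an int rather than a long. A hedged sketch of that layout using a plain ByteBuffer (illustrative; the sorter writes through PlatformDependent, not ByteBuffer):

```java
import java.nio.ByteBuffer;

// Illustrative sketch of the record layout implied by "lengthInBytes + 4":
// [ int length (4 bytes) | record payload (length bytes) ]
final class RecordLayoutSketch {
  static void writeRecord(ByteBuffer page, byte[] record) {
    page.putInt(record.length);  // 4-byte int length header, hence the int (not long) length type
    page.put(record);            // payload follows immediately
  }

  static byte[] readRecord(ByteBuffer page) {
    final int length = page.getInt();  // read the 4-byte header first
    final byte[] record = new byte[length];
    page.get(record);
    return record;
  }

  public static void main(String[] args) {
    final ByteBuffer page = ByteBuffer.allocate(64);
    writeRecord(page, "hello".getBytes());
    page.flip();
    System.out.println(new String(readRecord(page)));  // prints "hello"
  }
}
```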

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 7 additions & 0 deletions
```diff
@@ -17,6 +17,11 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
+import java.math.BigDecimal;
+import java.sql.Date;
+import java.util.*;
+import javax.annotation.Nullable;
+
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.util.ObjectPool;
 import org.apache.spark.unsafe.PlatformDependent;
@@ -55,6 +60,8 @@
  */
 public final class UnsafeRow extends MutableRow {
 
+  /** Hack for if we want to pass around an UnsafeRow which also carries around its backing data */
+  @Nullable public byte[] backingArray;
   private Object baseObject;
   private long baseOffset;
```
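Why `backingArray` helps: an UnsafeRow is normally just a pointer into memory owned by something else (a sorter page, for example), so it becomes invalid the moment that memory is freed or overwritten. Holding a strong reference to a copied byte[] ties the row's lifetime to its data. A simplified sketch of the pattern, with `PointerRow` standing in for the real UnsafeRow:

```java
import java.util.Arrays;

// Simplified stand-in for UnsafeRow: a row that merely points at externally owned memory.
final class PointerRow {
  private Object baseObject;   // whoever owns this memory controls the row's validity
  private long baseOffset;
  byte[] backingArray;         // optional strong reference that keeps the data alive

  void pointTo(Object baseObject, long baseOffset) {
    this.baseObject = baseObject;
    this.baseOffset = baseOffset;
  }
}

final class BackingArrayDemo {
  public static void main(String[] args) {
    byte[] sorterPage = new byte[]{1, 2, 3, 4};  // pretend this is sorter-owned memory
    PointerRow row = new PointerRow();

    // Copy the record out of the page, then point the row at the copy and pin it:
    byte[] copy = Arrays.copyOf(sorterPage, sorterPage.length);
    row.backingArray = copy;   // the row now carries its own data
    row.pointTo(copy, 0L);

    sorterPage = null;         // the sorter can free its page; the row stays valid
    System.out.println(row.backingArray.length);  // prints 4
  }
}
```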

org/apache/spark/sql/execution/UnsafeExternalRowSorter.java (new file)

Lines changed: 171 additions & 0 deletions
```diff
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+import scala.Function1;
+import scala.collection.AbstractIterator;
+import scala.collection.Iterator;
+import scala.math.Ordering;
+
+import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
+import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
+
+final class UnsafeExternalRowSorter {
+
+  private final StructType schema;
+  private final UnsafeRowConverter rowConverter;
+  private final RowComparator rowComparator;
+  private final PrefixComparator prefixComparator;
+  private final Function1<Row, Long> prefixComputer;
+
+  public UnsafeExternalRowSorter(
+      StructType schema,
+      Ordering<Row> ordering,
+      PrefixComparator prefixComparator,
+      // TODO: if possible, avoid this boxing of the return value
+      Function1<Row, Long> prefixComputer) {
+    this.schema = schema;
+    this.rowConverter = new UnsafeRowConverter(schema);
+    this.rowComparator = new RowComparator(ordering, schema);
+    this.prefixComparator = prefixComparator;
+    this.prefixComputer = prefixComputer;
+  }
+
+  public Iterator<Row> sort(Iterator<Row> inputIterator) throws IOException {
+    final SparkEnv sparkEnv = SparkEnv.get();
+    final TaskContext taskContext = TaskContext.get();
+    byte[] rowConversionBuffer = new byte[1024 * 8];
+    final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
+      taskContext.taskMemoryManager(),
+      sparkEnv.shuffleMemoryManager(),
+      sparkEnv.blockManager(),
+      taskContext,
+      rowComparator,
+      prefixComparator,
+      4096,
+      sparkEnv.conf()
+    );
+    try {
+      while (inputIterator.hasNext()) {
+        final Row row = inputIterator.next();
+        final int sizeRequirement = rowConverter.getSizeRequirement(row);
+        if (sizeRequirement > rowConversionBuffer.length) {
+          rowConversionBuffer = new byte[sizeRequirement];
+        } else {
+          // Zero out the buffer that's used to hold the current row. This is necessary in order
+          // to ensure that rows hash properly, since garbage data from the previous row could
+          // otherwise end up as padding in this row. As a performance optimization, we only zero
+          // out the portion of the buffer that we'll actually write to.
+          Arrays.fill(rowConversionBuffer, 0, sizeRequirement, (byte) 0);
+        }
+        final int bytesWritten =
+          rowConverter.writeRow(row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET);
+        assert (bytesWritten == sizeRequirement);
+        final long prefix = prefixComputer.apply(row);
+        sorter.insertRecord(
+          rowConversionBuffer,
+          PlatformDependent.BYTE_ARRAY_OFFSET,
+          sizeRequirement,
+          prefix
+        );
+      }
+      final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
+      return new AbstractIterator<Row>() {
+
+        private final int numFields = schema.length();
+        private final UnsafeRow row = new UnsafeRow();
+
+        @Override
+        public boolean hasNext() {
+          return sortedIterator.hasNext();
+        }
+
+        @Override
+        public Row next() {
+          try {
+            sortedIterator.loadNext();
+            if (hasNext()) {
+              row.pointTo(
+                sortedIterator.getBaseObject(), sortedIterator.getBaseOffset(), numFields, schema);
+              return row;
+            } else {
+              final byte[] rowDataCopy = new byte[sortedIterator.getRecordLength()];
+              PlatformDependent.copyMemory(
+                sortedIterator.getBaseObject(),
+                sortedIterator.getBaseOffset(),
+                rowDataCopy,
+                PlatformDependent.BYTE_ARRAY_OFFSET,
+                sortedIterator.getRecordLength()
+              );
+              row.backingArray = rowDataCopy;
+              row.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, schema);
+              sorter.freeMemory();
+              return row;
+            }
+          } catch (IOException e) {
+            // TODO: we need to ensure that files are cleaned properly after an exception,
+            // so we need better cleanup methods than freeMemory().
+            sorter.freeMemory();
+            // Scala iterators don't declare any checked exceptions, so we need to use this hack
+            // to re-throw the exception:
+            PlatformDependent.throwException(e);
+          }
+          throw new RuntimeException("Exception should have been re-thrown in next()");
+        };
+      };
+    } catch (IOException e) {
+      // TODO: we need to ensure that files are cleaned properly after an exception,
+      // so we need better cleanup methods than freeMemory().
+      sorter.freeMemory();
+      throw e;
+    }
+  }
+
+  private static final class RowComparator extends RecordComparator {
+    private final StructType schema;
+    private final Ordering<Row> ordering;
+    private final int numFields;
+    private final UnsafeRow row1 = new UnsafeRow();
+    private final UnsafeRow row2 = new UnsafeRow();
+
+    public RowComparator(Ordering<Row> ordering, StructType schema) {
+      this.schema = schema;
+      this.numFields = schema.length();
+      this.ordering = ordering;
+    }
+
+    @Override
+    public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
+      row1.pointTo(baseObj1, baseOff1, numFields, schema);
+      row2.pointTo(baseObj2, baseOff2, numFields, schema);
+      return ordering.compare(row1, row2);
+    }
+  }
+}
```
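The `PlatformDependent.throwException(e)` call in `next()` exists because Scala's `Iterator.next()` declares no checked exceptions, so the Java implementation cannot rethrow the `IOException` directly. Such helpers are commonly built on the generics-based "sneaky throw" trick sketched below; this shows the general technique, not necessarily PlatformDependent's exact implementation:

```java
// Sketch of the "sneaky throw" technique: rethrow a checked exception from a method
// whose signature does not declare it. Illustrative only.
final class SneakyThrowSketch {
  @SuppressWarnings("unchecked")
  private static <T extends Throwable> void sneakyThrow(Throwable t) throws T {
    // The cast is erased at runtime, so the compiler treats the checked
    // exception as the freely chosen type parameter T.
    throw (T) t;
  }

  public static void throwException(Throwable t) {
    // Instantiating T as RuntimeException removes the checked "throws" obligation.
    SneakyThrowSketch.<RuntimeException>sneakyThrow(t);
  }

  public static void main(String[] args) {
    try {
      throwException(new java.io.IOException("checked, but rethrown unchecked"));
    } catch (Exception e) {
      System.out.println("caught again: " + e);  // prints the original IOException
    }
  }
}
```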

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 4 additions & 80 deletions
```diff
@@ -17,12 +17,9 @@
 
 package org.apache.spark.sql.execution
 
-import java.util.Arrays
-
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.unsafe.PlatformDependent
-import org.apache.spark.util.collection.unsafe.sort.{RecordComparator, PrefixComparator, UnsafeExternalSorter}
-import org.apache.spark.{TaskContext, SparkEnv, HashPartitioner, SparkConf}
+import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
+import org.apache.spark.{SparkEnv, HashPartitioner}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -272,87 +269,14 @@ case class UnsafeExternalSort(
     if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
 
   protected override def doExecute(): RDD[Row] = attachTree(this, "sort") {
-    // TODO(josh): This code is unreadably messy; this should be split into a separate file
-    // and written in Java.
     assert (codegenEnabled)
     def doSort(iterator: Iterator[Row]): Iterator[Row] = {
       val ordering = newOrdering(sortOrder, child.output)
-      val rowConverter = new UnsafeRowConverter(schema.map(_.dataType).toArray)
-      var rowConversionScratchSpace = new Array[Long](1024)
       val prefixComparator = new PrefixComparator {
        override def compare(prefix1: Long, prefix2: Long): Int = 0
      }
-      val recordComparator = new RecordComparator {
-        private[this] val row1 = new UnsafeRow
-        private[this] val row2 = new UnsafeRow
-        override def compare(
-            baseObj1: scala.Any, baseOff1: Long, baseObj2: scala.Any, baseOff2: Long): Int = {
-          row1.pointTo(baseObj1, baseOff1, numFields, schema)
-          row2.pointTo(baseObj2, baseOff2, numFields, schema)
-          ordering.compare(row1, row2)
-        }
-      }
-      val sorter = new UnsafeExternalSorter(
-        TaskContext.get.taskMemoryManager(),
-        SparkEnv.get.shuffleMemoryManager,
-        SparkEnv.get.blockManager,
-        TaskContext.get,
-        recordComparator,
-        prefixComparator,
-        4096,
-        SparkEnv.get.conf
-      )
-      while (iterator.hasNext) {
-        val row: Row = iterator.next()
-        val sizeRequirement = rowConverter.getSizeRequirement(row)
-        if (sizeRequirement / 8 > rowConversionScratchSpace.length) {
-          rowConversionScratchSpace = new Array[Long](sizeRequirement / 8)
-        } else {
-          // Zero out the buffer that's used to hold the current row. This is necessary in order
-          // to ensure that rows hash properly, since garbage data from the previous row could
-          // otherwise end up as padding in this row. As a performance optimization, we only zero
-          // out the portion of the buffer that we'll actually write to.
-          Arrays.fill(rowConversionScratchSpace, 0, sizeRequirement / 8, 0)
-        }
-        val bytesWritten =
-          rowConverter.writeRow(row, rowConversionScratchSpace, PlatformDependent.LONG_ARRAY_OFFSET)
-        assert (bytesWritten == sizeRequirement)
-        val prefix: Long = 0 // dummy prefix until we implement prefix calculation
-        sorter.insertRecord(
-          rowConversionScratchSpace,
-          PlatformDependent.LONG_ARRAY_OFFSET,
-          sizeRequirement,
-          prefix
-        )
-      }
-      val sortedIterator = sorter.getSortedIterator
-      // TODO: need to avoid memory leaks on exceptions, etc. by wrapping in resource cleanup blocks
-      // TODO: need to clean up spill files after success or failure.
-      new Iterator[Row] {
-        private[this] val row = new UnsafeRow()
-        override def hasNext: Boolean = sortedIterator.hasNext
-
-        override def next(): Row = {
-          sortedIterator.loadNext()
-          if (hasNext) {
-            row.pointTo(
-              sortedIterator.getBaseObject, sortedIterator.getBaseOffset, numFields, schema)
-            row
-          } else {
-            val rowDataCopy = new Array[Byte](sortedIterator.getRecordLength)
-            PlatformDependent.copyMemory(
-              sortedIterator.getBaseObject,
-              sortedIterator.getBaseOffset,
-              rowDataCopy,
-              PlatformDependent.BYTE_ARRAY_OFFSET,
-              sortedIterator.getRecordLength
-            )
-            row.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, schema)
-            sorter.freeMemory()
-            row
-          }
-        }
-      }
+      def prefixComputer(row: Row): Long = 0
+      new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer).sort(iterator)
     }
     child.execute().mapPartitions(doSort, preservesPartitioning = true)
   }
```
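Note that both `prefixComparator` and `prefixComputer` here are still stubs that return 0, so every comparison falls through to the full row comparator; the prefix machinery only pays off once prefixes encode the leading sort column. As an illustration (not part of this commit, and assuming `PrefixComparator` is an abstract class with a `compare(long, long)` method, as the stub above suggests), a comparator for a leading signed-long column could look like:

```java
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;

// Hypothetical prefix comparator for rows whose leading sort column is a signed long:
// most orderings can then be decided from the 8-byte prefix without loading either record.
final class LongPrefixComparator extends PrefixComparator {
  @Override
  public int compare(long prefix1, long prefix2) {
    return Long.compare(prefix1, prefix2);
  }
}
```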
