Skip to content

Commit 5b05966

Browse files
jiangxb1987 and cloud-fan
authored and committed
[SPARK-24564][TEST] Add test suite for RecordBinaryComparator
## What changes were proposed in this pull request? Add a new test suite to test RecordBinaryComparator. ## How was this patch tested? New test suite. Author: Xingbo Jiang <xingbo.jiang@databricks.com> Closes #21570 from jiangxb1987/rbc-test.
1 parent 6a97e8e commit 5b05966

File tree

2 files changed

+266
-0
lines changed

2 files changed

+266
-0
lines changed

core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717

1818
package org.apache.spark.memory;
1919

20+
import com.google.common.annotations.VisibleForTesting;
21+
22+
import org.apache.spark.unsafe.memory.MemoryBlock;
23+
2024
import java.io.IOException;
2125

2226
public class TestMemoryConsumer extends MemoryConsumer {
@@ -43,6 +47,12 @@ void free(long size) {
4347
used -= size;
4448
taskMemoryManager.releaseExecutionMemory(size, this);
4549
}
50+
51+
@VisibleForTesting
52+
public void freePage(MemoryBlock page) {
53+
used -= page.size();
54+
taskMemoryManager.freePage(page, this);
55+
}
4656
}
4757

4858

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package test.org.apache.spark.sql.execution.sort;
19+
20+
import org.apache.spark.SparkConf;
21+
import org.apache.spark.memory.TaskMemoryManager;
22+
import org.apache.spark.memory.TestMemoryConsumer;
23+
import org.apache.spark.memory.TestMemoryManager;
24+
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
25+
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
26+
import org.apache.spark.sql.execution.RecordBinaryComparator;
27+
import org.apache.spark.unsafe.Platform;
28+
import org.apache.spark.unsafe.UnsafeAlignedOffset;
29+
import org.apache.spark.unsafe.array.LongArray;
30+
import org.apache.spark.unsafe.memory.MemoryBlock;
31+
import org.apache.spark.unsafe.types.UTF8String;
32+
import org.apache.spark.util.collection.unsafe.sort.*;
33+
34+
import org.junit.After;
import org.junit.Assert;
35+
import org.junit.Before;
36+
import org.junit.Test;
37+
38+
/**
39+
* Test the RecordBinaryComparator, which compares two UnsafeRows by their binary form.
40+
*/
41+
public class RecordBinaryComparatorSuite {
42+
43+
private final TaskMemoryManager memoryManager = new TaskMemoryManager(
44+
new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
45+
private final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager);
46+
47+
private final int uaoSize = UnsafeAlignedOffset.getUaoSize();
48+
49+
private MemoryBlock dataPage;
50+
private long pageCursor;
51+
52+
private LongArray array;
53+
private int pos;
54+
55+
@Before
56+
public void beforeEach() {
57+
// Only compare between two input rows.
58+
array = consumer.allocateArray(2);
59+
pos = 0;
60+
61+
dataPage = memoryManager.allocatePage(4096, consumer);
62+
pageCursor = dataPage.getBaseOffset();
63+
}
64+
65+
@After
66+
public void afterEach() {
67+
consumer.freePage(dataPage);
68+
dataPage = null;
69+
pageCursor = 0;
70+
71+
consumer.freeArray(array);
72+
array = null;
73+
pos = 0;
74+
}
75+
76+
private void insertRow(UnsafeRow row) {
77+
Object recordBase = row.getBaseObject();
78+
long recordOffset = row.getBaseOffset();
79+
int recordLength = row.getSizeInBytes();
80+
81+
Object baseObject = dataPage.getBaseObject();
82+
assert(pageCursor + recordLength <= dataPage.getBaseOffset() + dataPage.size());
83+
long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, pageCursor);
84+
UnsafeAlignedOffset.putSize(baseObject, pageCursor, recordLength);
85+
pageCursor += uaoSize;
86+
Platform.copyMemory(recordBase, recordOffset, baseObject, pageCursor, recordLength);
87+
pageCursor += recordLength;
88+
89+
assert(pos < 2);
90+
array.set(pos, recordAddress);
91+
pos++;
92+
}
93+
94+
private int compare(int index1, int index2) {
95+
Object baseObject = dataPage.getBaseObject();
96+
97+
long recordAddress1 = array.get(index1);
98+
long baseOffset1 = memoryManager.getOffsetInPage(recordAddress1) + uaoSize;
99+
int recordLength1 = UnsafeAlignedOffset.getSize(baseObject, baseOffset1 - uaoSize);
100+
101+
long recordAddress2 = array.get(index2);
102+
long baseOffset2 = memoryManager.getOffsetInPage(recordAddress2) + uaoSize;
103+
int recordLength2 = UnsafeAlignedOffset.getSize(baseObject, baseOffset2 - uaoSize);
104+
105+
return binaryComparator.compare(baseObject, baseOffset1, recordLength1, baseObject,
106+
baseOffset2, recordLength2);
107+
}
108+
109+
private final RecordComparator binaryComparator = new RecordBinaryComparator();
110+
111+
// Compute the most compact size for UnsafeRow's backing data.
112+
private int computeSizeInBytes(int originalSize) {
113+
// All the UnsafeRows in this suite contains less than 64 columns, so the bitSetSize shall
114+
// always be 8.
115+
return 8 + (originalSize + 7) / 8 * 8;
116+
}
117+
118+
// Compute the relative offset of variable-length values.
119+
private long relativeOffset(int numFields) {
120+
// All the UnsafeRows in this suite contains less than 64 columns, so the bitSetSize shall
121+
// always be 8.
122+
return 8 + numFields * 8L;
123+
}
124+
125+
@Test
126+
public void testBinaryComparatorForSingleColumnRow() throws Exception {
127+
int numFields = 1;
128+
129+
UnsafeRow row1 = new UnsafeRow(numFields);
130+
byte[] data1 = new byte[100];
131+
row1.pointTo(data1, computeSizeInBytes(numFields * 8));
132+
row1.setInt(0, 11);
133+
134+
UnsafeRow row2 = new UnsafeRow(numFields);
135+
byte[] data2 = new byte[100];
136+
row2.pointTo(data2, computeSizeInBytes(numFields * 8));
137+
row2.setInt(0, 42);
138+
139+
insertRow(row1);
140+
insertRow(row2);
141+
142+
assert(compare(0, 0) == 0);
143+
assert(compare(0, 1) < 0);
144+
}
145+
146+
@Test
147+
public void testBinaryComparatorForMultipleColumnRow() throws Exception {
148+
int numFields = 5;
149+
150+
UnsafeRow row1 = new UnsafeRow(numFields);
151+
byte[] data1 = new byte[100];
152+
row1.pointTo(data1, computeSizeInBytes(numFields * 8));
153+
for (int i = 0; i < numFields; i++) {
154+
row1.setDouble(i, i * 3.14);
155+
}
156+
157+
UnsafeRow row2 = new UnsafeRow(numFields);
158+
byte[] data2 = new byte[100];
159+
row2.pointTo(data2, computeSizeInBytes(numFields * 8));
160+
for (int i = 0; i < numFields; i++) {
161+
row2.setDouble(i, 198.7 / (i + 1));
162+
}
163+
164+
insertRow(row1);
165+
insertRow(row2);
166+
167+
assert(compare(0, 0) == 0);
168+
assert(compare(0, 1) < 0);
169+
}
170+
171+
@Test
172+
public void testBinaryComparatorForArrayColumn() throws Exception {
173+
int numFields = 1;
174+
175+
UnsafeRow row1 = new UnsafeRow(numFields);
176+
byte[] data1 = new byte[100];
177+
UnsafeArrayData arrayData1 = UnsafeArrayData.fromPrimitiveArray(new int[]{11, 42, -1});
178+
row1.pointTo(data1, computeSizeInBytes(numFields * 8 + arrayData1.getSizeInBytes()));
179+
row1.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData1.getSizeInBytes());
180+
Platform.copyMemory(arrayData1.getBaseObject(), arrayData1.getBaseOffset(), data1,
181+
row1.getBaseOffset() + relativeOffset(numFields), arrayData1.getSizeInBytes());
182+
183+
UnsafeRow row2 = new UnsafeRow(numFields);
184+
byte[] data2 = new byte[100];
185+
UnsafeArrayData arrayData2 = UnsafeArrayData.fromPrimitiveArray(new int[]{22});
186+
row2.pointTo(data2, computeSizeInBytes(numFields * 8 + arrayData2.getSizeInBytes()));
187+
row2.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData2.getSizeInBytes());
188+
Platform.copyMemory(arrayData2.getBaseObject(), arrayData2.getBaseOffset(), data2,
189+
row2.getBaseOffset() + relativeOffset(numFields), arrayData2.getSizeInBytes());
190+
191+
insertRow(row1);
192+
insertRow(row2);
193+
194+
assert(compare(0, 0) == 0);
195+
assert(compare(0, 1) > 0);
196+
}
197+
198+
@Test
199+
public void testBinaryComparatorForMixedColumns() throws Exception {
200+
int numFields = 4;
201+
202+
UnsafeRow row1 = new UnsafeRow(numFields);
203+
byte[] data1 = new byte[100];
204+
UTF8String str1 = UTF8String.fromString("Milk tea");
205+
row1.pointTo(data1, computeSizeInBytes(numFields * 8 + str1.numBytes()));
206+
row1.setInt(0, 11);
207+
row1.setDouble(1, 3.14);
208+
row1.setInt(2, -1);
209+
row1.setLong(3, (relativeOffset(numFields) << 32) | (long) str1.numBytes());
210+
Platform.copyMemory(str1.getBaseObject(), str1.getBaseOffset(), data1,
211+
row1.getBaseOffset() + relativeOffset(numFields), str1.numBytes());
212+
213+
UnsafeRow row2 = new UnsafeRow(numFields);
214+
byte[] data2 = new byte[100];
215+
UTF8String str2 = UTF8String.fromString("Java");
216+
row2.pointTo(data2, computeSizeInBytes(numFields * 8 + str2.numBytes()));
217+
row2.setInt(0, 11);
218+
row2.setDouble(1, 3.14);
219+
row2.setInt(2, -1);
220+
row2.setLong(3, (relativeOffset(numFields) << 32) | (long) str2.numBytes());
221+
Platform.copyMemory(str2.getBaseObject(), str2.getBaseOffset(), data2,
222+
row2.getBaseOffset() + relativeOffset(numFields), str2.numBytes());
223+
224+
insertRow(row1);
225+
insertRow(row2);
226+
227+
assert(compare(0, 0) == 0);
228+
assert(compare(0, 1) > 0);
229+
}
230+
231+
@Test
232+
public void testBinaryComparatorForNullColumns() throws Exception {
233+
int numFields = 3;
234+
235+
UnsafeRow row1 = new UnsafeRow(numFields);
236+
byte[] data1 = new byte[100];
237+
row1.pointTo(data1, computeSizeInBytes(numFields * 8));
238+
for (int i = 0; i < numFields; i++) {
239+
row1.setNullAt(i);
240+
}
241+
242+
UnsafeRow row2 = new UnsafeRow(numFields);
243+
byte[] data2 = new byte[100];
244+
row2.pointTo(data2, computeSizeInBytes(numFields * 8));
245+
for (int i = 0; i < numFields - 1; i++) {
246+
row2.setNullAt(i);
247+
}
248+
row2.setDouble(numFields - 1, 3.14);
249+
250+
insertRow(row1);
251+
insertRow(row2);
252+
253+
assert(compare(0, 0) == 0);
254+
assert(compare(0, 1) > 0);
255+
}
256+
}

0 commit comments

Comments
 (0)