
Commit 01afc74

Actually read data in UnsafeShuffleWriterSuite
1 parent 1929a74 commit 01afc74

1 file changed (+50, −21)

core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java

Lines changed: 50 additions & 21 deletions
@@ -17,15 +17,15 @@
 
 package org.apache.spark.shuffle.unsafe;
 
-import java.io.File;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStream;
+import java.io.*;
 import java.util.*;
 
 import scala.*;
+import scala.collection.Iterator;
 import scala.runtime.AbstractFunction1;
 
+import com.google.common.collect.HashMultiset;
+import com.google.common.io.ByteStreams;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
@@ -40,19 +40,21 @@
 import static org.mockito.Mockito.*;
 
 import org.apache.spark.*;
-import org.apache.spark.serializer.Serializer;
-import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.serializer.DeserializationStream;
+import org.apache.spark.serializer.KryoSerializer;
+import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.storage.*;
 import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
 import org.apache.spark.unsafe.memory.MemoryAllocator;
 import org.apache.spark.unsafe.memory.TaskMemoryManager;
 import org.apache.spark.util.Utils;
-import org.apache.spark.serializer.KryoSerializer;
-import org.apache.spark.scheduler.MapStatus;
 
 public class UnsafeShuffleWriterSuite {
 
@@ -64,6 +66,7 @@ public class UnsafeShuffleWriterSuite {
   File tempDir;
   long[] partitionSizesInMergedFile;
   final LinkedList<File> spillFilesCreated = new LinkedList<File>();
+  final Serializer serializer = new KryoSerializer(new SparkConf());
 
   @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
   @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
@@ -147,8 +150,7 @@ public Tuple2<TempLocalBlockId, File> answer(
 
     when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
 
-    when(shuffleDep.serializer()).thenReturn(
-      Option.<Serializer>apply(new KryoSerializer(new SparkConf())));
+    when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
     when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
   }
 
@@ -174,6 +176,27 @@ private void assertSpillFilesWereCleanedUp() {
     }
   }
 
+  private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
+    final ArrayList<Tuple2<Object, Object>> recordsList = new ArrayList<Tuple2<Object, Object>>();
+    long startOffset = 0;
+    for (int i = 0; i < NUM_PARTITITONS; i++) {
+      final long partitionSize = partitionSizesInMergedFile[i];
+      if (partitionSize > 0) {
+        InputStream in = new FileInputStream(mergedOutputFile);
+        ByteStreams.skipFully(in, startOffset);
+        DeserializationStream recordsStream = serializer.newInstance().deserializeStream(
+          new LimitedInputStream(in, partitionSize));
+        Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
+        while (records.hasNext()) {
+          recordsList.add(records.next());
+        }
+        recordsStream.close();
+        startOffset += partitionSize;
+      }
+    }
+    return recordsList;
+  }
+
   @Test(expected=IllegalStateException.class)
   public void mustCallWriteBeforeSuccessfulStop() {
     createWriter(false).stop(true);
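The new readRecordsFromFile() helper depends on the layout of the merged output file: the writer concatenates each partition's serialized bytes back to back, and partitionSizesInMergedFile records how many bytes each partition occupies, so partition i starts at the sum of the preceding sizes. A minimal sketch of that offset arithmetic, separate from the commit (the class and method names here are illustrative, not part of the suite):

// Illustrative only; not part of the commit.
class PartitionLayout {
  /** Returns {startOffset, lengthInBytes} for each partition of a merged shuffle file. */
  static long[][] partitionRanges(long[] partitionSizes) {
    final long[][] ranges = new long[partitionSizes.length][2];
    long startOffset = 0;
    for (int i = 0; i < partitionSizes.length; i++) {
      ranges[i][0] = startOffset;        // first byte of partition i in the merged file
      ranges[i][1] = partitionSizes[i];  // number of bytes partition i occupies
      startOffset += partitionSizes[i];  // next partition begins where this one ends
    }
    return ranges;
  }
}

The helper walks the file the same way, opening a fresh FileInputStream per partition, skipping to the start offset with ByteStreams.skipFully, and bounding the deserializer with a LimitedInputStream of the partition's size; a simple approach that suits a test.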
@@ -215,19 +238,26 @@ public void writeWithoutSpilling() throws Exception {
       sumOfPartitionSizes += size;
     }
     Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
-
+    Assert.assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
     assertSpillFilesWereCleanedUp();
   }
 
   private void testMergingSpills(boolean transferToEnabled) throws IOException {
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(3, 3));
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(4, 4));
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
+    }
+    writer.insertRecordIntoSorter(dataToWrite.get(0));
+    writer.insertRecordIntoSorter(dataToWrite.get(1));
+    writer.insertRecordIntoSorter(dataToWrite.get(2));
+    writer.insertRecordIntoSorter(dataToWrite.get(3));
     writer.forceSorterToSpill();
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(4, 4));
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+    writer.insertRecordIntoSorter(dataToWrite.get(4));
+    writer.insertRecordIntoSorter(dataToWrite.get(5));
     writer.closeAndWriteOutput();
     final Option<MapStatus> mapStatus = writer.stop(true);
     Assert.assertTrue(mapStatus.isDefined());
@@ -239,7 +269,9 @@ private void testMergingSpills(boolean transferToEnabled) throws IOException {
       sumOfPartitionSizes += size;
     }
     Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
-
+    Assert.assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
     assertSpillFilesWereCleanedUp();
   }
 
@@ -263,7 +295,4 @@ public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
     writer.stop(false);
     assertSpillFilesWereCleanedUp();
   }
-
-  // TODO: actually try to read the shuffle output?
-
 }
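Both new assertions compare HashMultisets rather than lists because the sorter groups records by partition, so the order read back from the merged file need not match insertion order; only the count of each record must survive the round trip. A standalone illustration of the Guava multiset semantics the assertions rely on (this demo class is not part of the commit; it assumes Guava on the classpath, as the suite already does):

import com.google.common.collect.HashMultiset;
import java.util.Arrays;

// Multiset equality ignores order but respects duplicate counts, which is
// exactly what the read-back check needs: records may come back reordered,
// but a dropped or duplicated record must fail the comparison.
public class MultisetEqualityDemo {
  public static void main(String[] args) {
    // Same elements, same counts, different order: equal.
    System.out.println(HashMultiset.create(Arrays.asList(1, 2, 2))
        .equals(HashMultiset.create(Arrays.asList(2, 1, 2))));  // true
    // A lost duplicate changes a count: not equal.
    System.out.println(HashMultiset.create(Arrays.asList(1, 2, 2))
        .equals(HashMultiset.create(Arrays.asList(1, 2))));     // false
  }
}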
