package org.apache.spark.shuffle.unsafe;

-import java.io.File;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStream;
+import java.io.*;

import java.util.*;

import scala.*;
+import scala.collection.Iterator;
import scala.runtime.AbstractFunction1;

+import com.google.common.collect.HashMultiset;
+import com.google.common.io.ByteStreams;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;

import static org.mockito.Mockito.*;

import org.apache.spark.*;
-import org.apache.spark.serializer.Serializer;
-import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.serializer.DeserializationStream;
+import org.apache.spark.serializer.KryoSerializer;
+import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.storage.*;
import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
import org.apache.spark.unsafe.memory.MemoryAllocator;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;
-import org.apache.spark.serializer.KryoSerializer;
-import org.apache.spark.scheduler.MapStatus;

public class UnsafeShuffleWriterSuite {

@@ -64,6 +66,7 @@ public class UnsafeShuffleWriterSuite {
  File tempDir;
  long[] partitionSizesInMergedFile;
  final LinkedList<File> spillFilesCreated = new LinkedList<File>();
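+  // Shared with the mocked ShuffleDependency below so that readRecordsFromFile()
+  // deserializes with the same serializer the writer used.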
+  final Serializer serializer = new KryoSerializer(new SparkConf());

  @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
  @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
@@ -147,8 +150,7 @@ public Tuple2<TempLocalBlockId, File> answer(

    when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());

-    when(shuffleDep.serializer()).thenReturn(
-      Option.<Serializer>apply(new KryoSerializer(new SparkConf())));
+    when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
    when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
  }

@@ -174,6 +176,27 @@ private void assertSpillFilesWereCleanedUp() {
    }
  }

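+  // Reads back every record in the merged shuffle output file by walking the
+  // per-partition sizes in partitionSizesInMergedFile and deserializing each
+  // partition's byte range in turn.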
+  private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
+    final ArrayList<Tuple2<Object, Object>> recordsList = new ArrayList<Tuple2<Object, Object>>();
+    long startOffset = 0;
+    for (int i = 0; i < NUM_PARTITITONS; i++) {
+      final long partitionSize = partitionSizesInMergedFile[i];
+      if (partitionSize > 0) {
+        InputStream in = new FileInputStream(mergedOutputFile);
+        ByteStreams.skipFully(in, startOffset);
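+        // Cap the stream at this partition's size so the deserializer stops at
+        // the partition boundary instead of reading into the next partition.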
+        DeserializationStream recordsStream = serializer.newInstance().deserializeStream(
+          new LimitedInputStream(in, partitionSize));
+        Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
+        while (records.hasNext()) {
+          recordsList.add(records.next());
+        }
+        recordsStream.close();
+        startOffset += partitionSize;
+      }
+    }
+    return recordsList;
+  }
+
  @Test(expected=IllegalStateException.class)
  public void mustCallWriteBeforeSuccessfulStop() {
    createWriter(false).stop(true);
@@ -215,19 +238,26 @@ public void writeWithoutSpilling() throws Exception {
      sumOfPartitionSizes += size;
    }
    Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
-
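+    // Read the output back and compare as multisets: the shuffle may reorder
+    // records, but it must write exactly the records that were inserted.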
+    Assert.assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
    assertSpillFilesWereCleanedUp();
  }

  private void testMergingSpills(boolean transferToEnabled) throws IOException {
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(3, 3));
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(4, 4));
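+    // Build the full record list up front so the final assertion can check the
+    // merged output against it; indices 4 and 5 are inserted after the forced spill.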
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
+    }
+    writer.insertRecordIntoSorter(dataToWrite.get(0));
+    writer.insertRecordIntoSorter(dataToWrite.get(1));
+    writer.insertRecordIntoSorter(dataToWrite.get(2));
+    writer.insertRecordIntoSorter(dataToWrite.get(3));
    writer.forceSorterToSpill();
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(4, 4));
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+    writer.insertRecordIntoSorter(dataToWrite.get(4));
+    writer.insertRecordIntoSorter(dataToWrite.get(5));
    writer.closeAndWriteOutput();
    final Option<MapStatus> mapStatus = writer.stop(true);
    Assert.assertTrue(mapStatus.isDefined());
@@ -239,7 +269,9 @@ private void testMergingSpills(boolean transferToEnabled) throws IOException {
      sumOfPartitionSizes += size;
    }
    Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
-
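+    // The merged file must contain every record from both spills, duplicates included.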
+    Assert.assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
    assertSpillFilesWereCleanedUp();
  }
@@ -263,7 +295,4 @@ public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
    writer.stop(false);
    assertSpillFilesWereCleanedUp();
  }
-
-  // TODO: actually try to read the shuffle output?
-
}