Commit e056332

SPARK-12619 Combine small files in a hadoop directory into single split
1 parent 1537e55 commit e056332

2 files changed: 319 additions, 2 deletions

org/apache/spark/util/SimpleCombiner.java (new file; path implied by the package declaration)

Lines changed: 312 additions & 0 deletions
@@ -0,0 +1,312 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.ObjectWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.InputFormat;
import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.RecordReader;
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.util.ReflectionUtils;

/**
 * Wraps a file-based {@link InputFormat} and combines the small splits it produces into larger
 * {@link InputSplits} groups that share at least one host location, until each group reaches
 * the configured byte threshold (SPARK-12619).
 */
public class SimpleCombiner<K, V> implements InputFormat<K, V> {

  private static final Logger LOG = LoggerFactory.getLogger(SimpleCombiner.class);

  // apply only if the input format is file-based
  public static boolean accepts(Class<? extends InputFormat> inputFormatter) {
    return FileInputFormat.class.isAssignableFrom(inputFormatter) ||
        org.apache.hadoop.mapred.FileInputFormat.class.isAssignableFrom(inputFormatter);
  }

  private final InputFormat<K, V> delegate;
  private final long threshold;

  public SimpleCombiner(InputFormat<K, V> delegate, long threshold) {
    this.delegate = delegate;
    this.threshold = threshold;
  }

  @Override
  public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException {
    InputSplit[] original = delegate.getSplits(job, numSplits);
    LOG.info("Start combining " + original.length + " splits with threshold " + threshold);

    long start = System.currentTimeMillis();
    // First pass: greedily add each split to an existing group that shares a location and is
    // still below the threshold; otherwise start a new group for it.
    List<TaggedPair<Set<String>, List<InputSplit>>> splits = new ArrayList<>();
    for (InputSplit split : original) {
      final long length = split.getLength();
      final String[] locations = split.getLocations();

      boolean added = false;
      for (TaggedPair<Set<String>, List<InputSplit>> entry : splits) {
        if (entry.t > threshold) {
          continue;
        }
        Set<String> set = entry.l;
        if (containsAny(set, locations)) {
          set.retainAll(Arrays.asList(locations));
          entry.r.add(split);
          entry.t += length;
          added = true;
          break;
        }
      }
      if (!added) {
        splits.add(new TaggedPair<Set<String>, List<InputSplit>>(
            length,
            new HashSet<>(Arrays.asList(locations)),
            new ArrayList<>(Arrays.asList(split))));
      }
    }

    // Emit the groups that already reached the threshold.
    List<InputSplit> combined = new ArrayList<>();
    Iterator<TaggedPair<Set<String>, List<InputSplit>>> iterator = splits.iterator();
    while (iterator.hasNext()) {
      TaggedPair<Set<String>, List<InputSplit>> entry = iterator.next();
      if (entry.t >= threshold) {
        combined.add(
            new InputSplits(entry.t,
                entry.l.toArray(new String[entry.l.size()]),
                entry.r.toArray(new InputSplit[entry.r.size()])));
        iterator.remove();
      }
    }

    // Second pass: merge the remaining under-sized groups that still share a location,
    // emitting a combined split whenever the running total exceeds the threshold.
    TaggedPair<Set<String>, List<InputSplit>> current = null;
    iterator = splits.iterator();
    while (iterator.hasNext()) {
      TaggedPair<Set<String>, List<InputSplit>> entry = iterator.next();
      if (current == null) {
        iterator.remove();
        current = entry;
        continue;
      }
      if (containsAny(current.l, entry.l)) {
        iterator.remove();
        current.t += entry.t;
        current.r.addAll(entry.r);
        current.l.retainAll(entry.l);
      }
      if (current.t > threshold) {
        combined.add(
            new InputSplits(current.t,
                current.l.toArray(new String[current.l.size()]),
                current.r.toArray(new InputSplit[current.r.size()])));
        current = null;
      }
    }
    if (current != null) {
      combined.add(
          new InputSplits(current.t,
              current.l.toArray(new String[current.l.size()]),
              current.r.toArray(new InputSplit[current.r.size()])));
    }
    // Whatever could not be merged becomes a combined split of its own, even if undersized.
    for (TaggedPair<Set<String>, List<InputSplit>> entry : splits) {
      combined.add(new InputSplits(entry.t,
          entry.l.toArray(new String[entry.l.size()]),
          entry.r.toArray(new InputSplit[entry.r.size()])));
    }
    LOG.info("Combined to " + combined.size() + " splits, took " +
        (System.currentTimeMillis() - start) + " msec");
    return combined.toArray(new InputSplit[combined.size()]);
  }

  private boolean containsAny(Set<String> set, String[] targets) {
    return containsAny(set, Arrays.asList(targets));
  }

  private boolean containsAny(Set<String> set, Collection<String> targets) {
    for (String target : targets) {
      if (set.contains(target)) {
        return true;
      }
    }
    // An empty set means the group's common locations were drained by retainAll,
    // so it accepts splits from any host.
    return set.isEmpty();
  }

  @Override
  public RecordReader<K, V> getRecordReader(final InputSplit split, final JobConf job, final Reporter reporter)
      throws IOException {

    final InputSplit[] splits = ((InputSplits) split).splits;

    // Chains one record reader per underlying split, moving on when the current one is exhausted.
    return new RecordReader<K, V>() {

      private int index;
      private long pos;
      private RecordReader<K, V> reader = nextReader();

      private RecordReader<K, V> nextReader() throws IOException {
        return delegate.getRecordReader(splits[index++], job, reporter);
      }

      @Override
      @SuppressWarnings("unchecked")
      public boolean next(K key, V value) throws IOException {
        while (!reader.next(key, value)) {
          if (index < splits.length) {
            // accumulate the position consumed by the finished reader before switching
            pos += reader.getPos();
            reader.close();
            reader = nextReader();
            continue;
          }
          return false;
        }
        return true;
      }

      @Override
      public K createKey() {
        return reader.createKey();
      }

      @Override
      public V createValue() {
        return reader.createValue();
      }

      @Override
      public long getPos() throws IOException {
        // bytes consumed by finished readers plus the current reader's position
        return pos + reader.getPos();
      }

      @Override
      public void close() throws IOException {
        reader.close();
      }

      @Override
      public float getProgress() throws IOException {
        return (index - 1 + reader.getProgress()) / splits.length;
      }
    };
  }

  public static class InputSplits implements InputSplit, Configurable {

    private long length;
    private InputSplit[] splits;
    private String[] locations;

    private transient Configuration conf;

    public InputSplits() {
    }

    public InputSplits(long length, String[] locations, InputSplit[] splits) {
      this.length = length;
      this.locations = locations;
      this.splits = splits;
    }

    @Override
    public void setConf(Configuration conf) {
      this.conf = conf;
    }

    @Override
    public Configuration getConf() {
      return conf;
    }

    @Override
    public long getLength() throws IOException {
      return length;
    }

    @Override
    public String[] getLocations() throws IOException {
      return locations;
    }

    @Override
    public void write(DataOutput out) throws IOException {
      out.writeLong(length);
      out.writeInt(locations.length);
      for (String location : locations) {
        Text.writeString(out, location);
      }
      out.writeInt(splits.length);
      for (InputSplit split : splits) {
        Text.writeString(out, split.getClass().getName());
        split.write(out);
      }
    }

    @Override
    public void readFields(DataInput in) throws IOException {
      length = in.readLong();
      locations = new String[in.readInt()];
      for (int i = 0; i < locations.length; i++) {
        locations[i] = Text.readString(in);
      }
      splits = new InputSplit[in.readInt()];
      try {
        for (int i = 0; i < splits.length; i++) {
          Class<?> clazz = ObjectWritable.loadClass(conf, Text.readString(in));
          splits[i] = (InputSplit) ReflectionUtils.newInstance(clazz, conf);
          splits[i].readFields(in);
        }
      } catch (Exception e) {
        throw new IOException(e);
      }
    }

    @Override
    public String toString() {
      return "length = " + length + ", locations = " + Arrays.toString(locations) +
          ", splits = " + Arrays.toString(splits);
    }
  }

  private static class TaggedPair<L, R> {
    private long t;  // running byte total for the group
    private final L l;
    private final R r;

    private TaggedPair(long t, L l, R r) {
      this.t = t;
      this.l = l;
      this.r = r;
    }

    static <L, R> TaggedPair<L, R> of(long t, L l, R r) {
      return new TaggedPair<>(t, l, r);
    }
  }
}
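
A minimal sketch of exercising the combiner on its own, written in Scala against the old mapred API. The input directory and the 64 MB threshold are illustrative placeholders, not values from the commit; only the SimpleCombiner API above is taken from the change:

import org.apache.hadoop.mapred.{FileInputFormat, JobConf, TextInputFormat}
import org.apache.spark.util.SimpleCombiner

val job = new JobConf()
FileInputFormat.setInputPaths(job, "/data/many-small-files")  // hypothetical directory of small files
val delegate = new TextInputFormat()
delegate.configure(job)  // TextInputFormat reads its settings from the JobConf

// Group splits that share a host until each group holds roughly 64 MB.
val combiner = new SimpleCombiner(delegate, 64L * 1024 * 1024)
val splits = combiner.getSplits(job, 1)
splits.foreach(println)  // each element is a SimpleCombiner.InputSplits group

Each combined split keeps only the host locations common to its members, so locality-aware scheduling still has something to work with.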

core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala

Lines changed: 7 additions & 2 deletions
@@ -45,7 +45,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.executor.DataReadMethod
 import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD
-import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, NextIterator, Utils}
+import org.apache.spark.util.{SerializableConfiguration, SimpleCombiner, ShutdownHookManager, NextIterator, Utils}
 import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation}
 import org.apache.spark.storage.StorageLevel

@@ -189,7 +189,12 @@
       case c: Configurable => c.setConf(conf)
       case _ =>
     }
-    newInputFormat
+    val threshold = conf.getLong("hive.split.combine.threshold", -1)
+    if (threshold > 0 && SimpleCombiner.accepts(inputFormatClass)) {
+      new SimpleCombiner(newInputFormat, threshold)
+    } else {
+      newInputFormat
+    }
   }

   override def getPartitions: Array[Partition] = {
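
A usage sketch from the Spark side, assuming a job that reads a directory of many small files; the path and the 128 MB threshold are placeholders. With the default of -1 (or any value <= 0), the original splits are left untouched:

import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setAppName("combine-small-files"))
// Propagated into the JobConf that HadoopRDD.getInputFormat receives.
sc.hadoopConfiguration.setLong("hive.split.combine.threshold", 128L * 1024 * 1024)

// sc.textFile uses the old-API TextInputFormat, which is file-based, so it is
// wrapped in SimpleCombiner and small files end up sharing partitions.
val lines = sc.textFile("/data/many-small-files")  // hypothetical path
println(lines.partitions.length)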
