Skip to content

Commit db44a30

Browse files
committed
JIT hax.
1 parent 3868f6c commit db44a30

File tree

5 files changed

+336
-4
lines changed

5 files changed

+336
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala

Lines changed: 332 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,335 @@ class JoinedRow extends Row {
153153
s"[${row.mkString(",")}]"
154154
}
155155
}
156+
157+
/**
158+
* JIT HACK: Replace with macros
159+
*/
160+
class JoinedRow2 extends Row {
161+
private[this] var row1: Row = _
162+
private[this] var row2: Row = _
163+
164+
def this(left: Row, right: Row) = {
165+
this()
166+
row1 = left
167+
row2 = right
168+
}
169+
170+
/** Updates this JoinedRow to used point at two new base rows. Returns itself. */
171+
def apply(r1: Row, r2: Row): Row = {
172+
row1 = r1
173+
row2 = r2
174+
this
175+
}
176+
177+
/** Updates this JoinedRow by updating its left base row. Returns itself. */
178+
def withLeft(newLeft: Row): Row = {
179+
row1 = newLeft
180+
this
181+
}
182+
183+
/** Updates this JoinedRow by updating its right base row. Returns itself. */
184+
def withRight(newRight: Row): Row = {
185+
row2 = newRight
186+
this
187+
}
188+
189+
def iterator = row1.iterator ++ row2.iterator
190+
191+
def length = row1.length + row2.length
192+
193+
def apply(i: Int) =
194+
if (i < row1.size) row1(i) else row2(i - row1.size)
195+
196+
def isNullAt(i: Int) =
197+
if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
198+
199+
def getInt(i: Int): Int =
200+
if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
201+
202+
def getLong(i: Int): Long =
203+
if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
204+
205+
def getDouble(i: Int): Double =
206+
if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
207+
208+
def getBoolean(i: Int): Boolean =
209+
if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
210+
211+
def getShort(i: Int): Short =
212+
if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
213+
214+
def getByte(i: Int): Byte =
215+
if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
216+
217+
def getFloat(i: Int): Float =
218+
if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
219+
220+
def getString(i: Int): String =
221+
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
222+
223+
def copy() = {
224+
val totalSize = row1.size + row2.size
225+
val copiedValues = new Array[Any](totalSize)
226+
var i = 0
227+
while(i < totalSize) {
228+
copiedValues(i) = apply(i)
229+
i += 1
230+
}
231+
new GenericRow(copiedValues)
232+
}
233+
234+
override def toString() = {
235+
val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
236+
s"[${row.mkString(",")}]"
237+
}
238+
}
239+
240+
/**
241+
* JIT HACK: Replace with macros
242+
*/
243+
class JoinedRow3 extends Row {
244+
private[this] var row1: Row = _
245+
private[this] var row2: Row = _
246+
247+
def this(left: Row, right: Row) = {
248+
this()
249+
row1 = left
250+
row2 = right
251+
}
252+
253+
/** Updates this JoinedRow to used point at two new base rows. Returns itself. */
254+
def apply(r1: Row, r2: Row): Row = {
255+
row1 = r1
256+
row2 = r2
257+
this
258+
}
259+
260+
/** Updates this JoinedRow by updating its left base row. Returns itself. */
261+
def withLeft(newLeft: Row): Row = {
262+
row1 = newLeft
263+
this
264+
}
265+
266+
/** Updates this JoinedRow by updating its right base row. Returns itself. */
267+
def withRight(newRight: Row): Row = {
268+
row2 = newRight
269+
this
270+
}
271+
272+
def iterator = row1.iterator ++ row2.iterator
273+
274+
def length = row1.length + row2.length
275+
276+
def apply(i: Int) =
277+
if (i < row1.size) row1(i) else row2(i - row1.size)
278+
279+
def isNullAt(i: Int) =
280+
if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
281+
282+
def getInt(i: Int): Int =
283+
if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
284+
285+
def getLong(i: Int): Long =
286+
if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
287+
288+
def getDouble(i: Int): Double =
289+
if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
290+
291+
def getBoolean(i: Int): Boolean =
292+
if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
293+
294+
def getShort(i: Int): Short =
295+
if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
296+
297+
def getByte(i: Int): Byte =
298+
if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
299+
300+
def getFloat(i: Int): Float =
301+
if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
302+
303+
def getString(i: Int): String =
304+
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
305+
306+
def copy() = {
307+
val totalSize = row1.size + row2.size
308+
val copiedValues = new Array[Any](totalSize)
309+
var i = 0
310+
while(i < totalSize) {
311+
copiedValues(i) = apply(i)
312+
i += 1
313+
}
314+
new GenericRow(copiedValues)
315+
}
316+
317+
override def toString() = {
318+
val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
319+
s"[${row.mkString(",")}]"
320+
}
321+
}
322+
323+
/**
324+
* JIT HACK: Replace with macros
325+
*/
326+
class JoinedRow4 extends Row {
327+
private[this] var row1: Row = _
328+
private[this] var row2: Row = _
329+
330+
def this(left: Row, right: Row) = {
331+
this()
332+
row1 = left
333+
row2 = right
334+
}
335+
336+
/** Updates this JoinedRow to used point at two new base rows. Returns itself. */
337+
def apply(r1: Row, r2: Row): Row = {
338+
row1 = r1
339+
row2 = r2
340+
this
341+
}
342+
343+
/** Updates this JoinedRow by updating its left base row. Returns itself. */
344+
def withLeft(newLeft: Row): Row = {
345+
row1 = newLeft
346+
this
347+
}
348+
349+
/** Updates this JoinedRow by updating its right base row. Returns itself. */
350+
def withRight(newRight: Row): Row = {
351+
row2 = newRight
352+
this
353+
}
354+
355+
def iterator = row1.iterator ++ row2.iterator
356+
357+
def length = row1.length + row2.length
358+
359+
def apply(i: Int) =
360+
if (i < row1.size) row1(i) else row2(i - row1.size)
361+
362+
def isNullAt(i: Int) =
363+
if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
364+
365+
def getInt(i: Int): Int =
366+
if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
367+
368+
def getLong(i: Int): Long =
369+
if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
370+
371+
def getDouble(i: Int): Double =
372+
if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
373+
374+
def getBoolean(i: Int): Boolean =
375+
if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
376+
377+
def getShort(i: Int): Short =
378+
if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
379+
380+
def getByte(i: Int): Byte =
381+
if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
382+
383+
def getFloat(i: Int): Float =
384+
if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
385+
386+
def getString(i: Int): String =
387+
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
388+
389+
def copy() = {
390+
val totalSize = row1.size + row2.size
391+
val copiedValues = new Array[Any](totalSize)
392+
var i = 0
393+
while(i < totalSize) {
394+
copiedValues(i) = apply(i)
395+
i += 1
396+
}
397+
new GenericRow(copiedValues)
398+
}
399+
400+
override def toString() = {
401+
val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
402+
s"[${row.mkString(",")}]"
403+
}
404+
}
405+
406+
/**
407+
* JIT HACK: Replace with macros
408+
*/
409+
class JoinedRow5 extends Row {
410+
private[this] var row1: Row = _
411+
private[this] var row2: Row = _
412+
413+
def this(left: Row, right: Row) = {
414+
this()
415+
row1 = left
416+
row2 = right
417+
}
418+
419+
/** Updates this JoinedRow to used point at two new base rows. Returns itself. */
420+
def apply(r1: Row, r2: Row): Row = {
421+
row1 = r1
422+
row2 = r2
423+
this
424+
}
425+
426+
/** Updates this JoinedRow by updating its left base row. Returns itself. */
427+
def withLeft(newLeft: Row): Row = {
428+
row1 = newLeft
429+
this
430+
}
431+
432+
/** Updates this JoinedRow by updating its right base row. Returns itself. */
433+
def withRight(newRight: Row): Row = {
434+
row2 = newRight
435+
this
436+
}
437+
438+
def iterator = row1.iterator ++ row2.iterator
439+
440+
def length = row1.length + row2.length
441+
442+
def apply(i: Int) =
443+
if (i < row1.size) row1(i) else row2(i - row1.size)
444+
445+
def isNullAt(i: Int) =
446+
if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size)
447+
448+
def getInt(i: Int): Int =
449+
if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
450+
451+
def getLong(i: Int): Long =
452+
if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
453+
454+
def getDouble(i: Int): Double =
455+
if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
456+
457+
def getBoolean(i: Int): Boolean =
458+
if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
459+
460+
def getShort(i: Int): Short =
461+
if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
462+
463+
def getByte(i: Int): Byte =
464+
if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
465+
466+
def getFloat(i: Int): Float =
467+
if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)
468+
469+
def getString(i: Int): String =
470+
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
471+
472+
def copy() = {
473+
val totalSize = row1.size + row2.size
474+
val copiedValues = new Array[Any](totalSize)
475+
var i = 0
476+
while(i < totalSize) {
477+
copiedValues(i) = apply(i)
478+
i += 1
479+
}
480+
new GenericRow(copiedValues)
481+
}
482+
483+
override def toString() = {
484+
val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]())
485+
s"[${row.mkString(",")}]"
486+
}
487+
}

sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ case class Aggregate(
175175
private[this] val resultProjection =
176176
new InterpretedMutableProjection(
177177
resultExpressions, computedSchema ++ namedGroups.map(_._2))
178-
private[this] val joinedRow = new JoinedRow
178+
private[this] val joinedRow = new JoinedRow4
179179

180180
override final def hasNext: Boolean = hashTableIter.hasNext
181181

sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ case class GeneratedAggregate(
185185
(namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq)
186186
log.info(s"Result Projection: ${resultExpressions.mkString(",")}")
187187

188-
val joinedRow = new JoinedRow
188+
val joinedRow = new JoinedRow3
189189

190190
if (groupingExpressions.isEmpty) {
191191
// TODO: Codegening anything other than the updateProjection is probably over kill.

sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ trait HashJoin {
9292
private[this] var currentMatchPosition: Int = -1
9393

9494
// Mutable per row objects.
95-
private[this] val joinRow = new JoinedRow
95+
private[this] val joinRow = new JoinedRow2
9696

9797
private[this] val joinKeys = streamSideKeyGenerator()
9898

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ case class ParquetTableScan(
139139
partOutput.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow))
140140

141141
new Iterator[Row] {
142-
private[this] val joinedRow = new JoinedRow(Row(partitionRowValues:_*), null)
142+
private[this] val joinedRow = new JoinedRow5(Row(partitionRowValues:_*), null)
143143

144144
def hasNext = iter.hasNext
145145

0 commit comments

Comments
 (0)