Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,32 @@
package org.apache.spark.sql.auron.memory

import java.nio.ByteBuffer
import java.util.concurrent.locks.ReentrantLock

import org.apache.spark.internal.Logging
import org.apache.spark.memory.MemoryConsumer

case class OnHeapSpill(hsm: SparkOnHeapSpillManager, id: Int) extends Logging {
private var spillBuf: SpillBuf = new MemBasedSpillBuf
private val lock = new ReentrantLock

def memUsed: Long = spillBuf.memUsed
def diskUsed: Long = spillBuf.diskUsed
def size: Long = spillBuf.size
def diskIOTime: Long = spillBuf.diskIOTime

private def withLock[T](f: => T): T = {
lock.lock()
try {
f
} finally {
lock.unlock()
}
}

def write(buf: ByteBuffer): Unit = {
var needSpill = false
synchronized {
withLock {
spillBuf match {
case _: MemBasedSpillBuf =>
val acquiredMemory = hsm.acquireMemory(buf.capacity())
Expand All @@ -46,13 +58,13 @@ case class OnHeapSpill(hsm: SparkOnHeapSpillManager, id: Int) extends Logging {
spillInternal()
}

synchronized {
withLock {
spillBuf.write(buf)
}
}

def read(buf: ByteBuffer): Int = {
synchronized {
withLock {
val oldMemUsed = memUsed
val startPosition = buf.position()
spillBuf.read(buf)
Expand All @@ -69,7 +81,7 @@ case class OnHeapSpill(hsm: SparkOnHeapSpillManager, id: Int) extends Logging {
}

def release(): Unit = {
synchronized {
withLock {
val oldMemUsed = memUsed
spillBuf = new ReleasedSpillBuf(spillBuf)

Expand All @@ -79,8 +91,20 @@ case class OnHeapSpill(hsm: SparkOnHeapSpillManager, id: Int) extends Logging {
}
}

def spill(): Long = {
synchronized {
def spill(trigger: MemoryConsumer): Long = {
// this might have been locked if the spilling is triggered by OnHeapSpill.write
if (trigger == this.hsm) {
if (lock.tryLock()) {
try {
return spillInternal()
} finally {
lock.unlock()
}
}
return 0L
}

withLock {
spillInternal()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class SparkOnHeapSpillManager(taskContext: TaskContext)
val sortedSpills = spills.seq.sortBy(0 - _.map(_.memUsed).getOrElse(0L))
sortedSpills.foreach {
case Some(spill) if spill.memUsed > 0 =>
totalFreed += spill.spill()
totalFreed += spill.spill(trigger)
if (totalFreed >= size) {
return totalFreed
}
Expand Down
Loading