Skip to content

Commit eb4d642

Browse files
authored
Merge pull request #36 from ThoughtWorksInc/pending-buffer
Use PendingBuffer instead of OpenCL.Buffer
2 parents 7e54f28 + 8cd868b commit eb4d642

File tree

9 files changed

+182
-29
lines changed

9 files changed

+182
-29
lines changed

AsynchronousSemaphore/build.sbt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
libraryDependencies += "org.scalaz" %% "scalaz-core" % "7.2.10"
2+
3+
libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.1" % Test
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package com.thoughtworks.deeplearning
2+
3+
import java.util.concurrent.atomic.AtomicReference
4+
5+
import scala.annotation.tailrec
6+
import scala.collection.immutable.Queue
7+
import scalaz.{ContT, Trampoline}
8+
import scalaz.Free.Trampoline
9+
10+
object AsynchronousSemaphore {
11+
sealed trait State
12+
final case class Available(restNumberOfPermits: Int) extends State
13+
final case class Unavailable(waiters: Queue[Unit => Trampoline[Unit]]) extends State
14+
15+
@inline
16+
def apply(numberOfPermits: Int): AsynchronousSemaphore = {
17+
numberOfPermits.ensuring(_ > 0)
18+
new AtomicReference[State](Available(numberOfPermits)) with AsynchronousSemaphore {
19+
override protected def state: AtomicReference[State] = this
20+
}
21+
}
22+
}
23+
24+
/**
25+
* @author 杨博 (Yang Bo) <pop.atry@gmail.com>
26+
*/
27+
trait AsynchronousSemaphore {
28+
import AsynchronousSemaphore._
29+
protected def state: AtomicReference[State]
30+
31+
final def acquire(): ContT[Trampoline, Unit, Unit] = {
32+
ContT[Trampoline, Unit, Unit]({ waiter: (Unit => Trampoline[Unit]) =>
33+
@tailrec
34+
def retry(): Trampoline[Unit] = {
35+
state.get() match {
36+
case oldState @ Available(1) =>
37+
if (state.compareAndSet(oldState, Unavailable(Queue.empty))) {
38+
waiter(())
39+
} else {
40+
retry()
41+
}
42+
case oldState @ Available(restNumberOfPermits) if restNumberOfPermits > 1 =>
43+
if (state.compareAndSet(oldState, Available(restNumberOfPermits - 1))) { // TODO
44+
waiter(())
45+
} else {
46+
retry()
47+
}
48+
case oldState @ Unavailable(waiters) =>
49+
if (state.compareAndSet(oldState, Unavailable(waiters.enqueue(waiter)))) {
50+
Trampoline.done(())
51+
} else {
52+
retry()
53+
}
54+
}
55+
}
56+
retry()
57+
})
58+
}
59+
60+
@tailrec
61+
final def release(): Trampoline[Unit] = {
62+
state.get() match {
63+
case oldState @ Unavailable(waiters) =>
64+
val (head, tail) = waiters.dequeue
65+
if (state.compareAndSet(oldState, Unavailable(tail))) {
66+
head(())
67+
} else {
68+
release()
69+
}
70+
case oldState @ Available(restNumberOfPermits) =>
71+
if (state.compareAndSet(oldState, Available(restNumberOfPermits + 1))) {
72+
Trampoline.done(())
73+
} else {
74+
release()
75+
}
76+
}
77+
}
78+
}

DifferentiableKernel/src/main/scala/com/thoughtworks/deeplearning/DifferentiableKernel.scala

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
package com.thoughtworks.deeplearning
22

3+
import java.util.concurrent.Semaphore
4+
5+
import com.thoughtworks.deeplearning.Closeables.{AssertionAutoCloseable, AssertionFinalizer}
36
import com.thoughtworks.deeplearning.DifferentiableKernel.Zero.FloatZero
47
import com.thoughtworks.deeplearning.Memory.Address
58
import com.thoughtworks.deeplearning.OpenCL.CommandQueue.GlobalWorkSizeOnlyDimension
69
import com.thoughtworks.deeplearning.OpenCL.{CommandQueue, Device, Kernel}
710
import com.thoughtworks.deeplearning.OpenCLCodeGenerator.DslType.{DslBuffer, DslDouble, DslFloat, DslInt}
811
import com.thoughtworks.deeplearning.OpenCLCodeGenerator._
912
import com.thoughtworks.each.Monadic._
10-
import com.thoughtworks.raii.RAIITask
13+
import com.thoughtworks.raii.ResourceFactoryT.ResourceT
14+
import com.thoughtworks.raii.{RAIITask, ResourceFactoryT}
1115
import shapeless.labelled._
1216
import shapeless._
1317

1418
import scala.concurrent.ExecutionContext
1519
import scala.util.control.NonFatal
16-
import scalaz.{@@, Monad, Monoid}
20+
import scalaz.{@@, Monad, Monoid, \/, \/-}
1721
import scalaz.Tags.{Multiplication, Parallel}
1822
import scalaz.concurrent.Future
1923
import scalaz.concurrent.Future.{ParallelFuture, futureParallelApplicativeInstance}
@@ -24,13 +28,15 @@ import scala.language.higherKinds
2428

2529
object DifferentiableKernel {
2630

31+
final case class PendingBuffer[Element](buffer: OpenCL.Buffer[Element], events: List[OpenCL.Event])
32+
2733
private[DifferentiableKernel] trait StaticDslTypeExtractor {
2834
type AbstractType[A] <: DslType
2935

3036
implicit def dslDouble: AbstractType[Double]
3137
implicit def dslFloat: AbstractType[Float]
3238
implicit def dslInt: AbstractType[Int]
33-
implicit def dslBuffer[Element: AbstractType]: AbstractType[OpenCL.Buffer[Element]]
39+
implicit def dslBuffer[Element: AbstractType]: AbstractType[PendingBuffer[Element]]
3440
}
3541

3642
private[DifferentiableKernel] trait StaticDslExpressionExtractor {
@@ -73,14 +79,14 @@ object DifferentiableKernel {
7379

7480
import OpenCLLayer._
7581

76-
def compile(context: OpenCL.Context, device: Device, commandQueue: CommandQueue)(
82+
def compile(context: OpenCL.Context, device: Device, commandQueue: CommandQueue, semaphore: AsynchronousSemaphore)(
7783
implicit compiler: Compiler[OutputElementData, OutputElementDelta, LocalDelta],
7884
outputDataMemory: Memory[OutputElementData],
7985
outputDeltaMemory: Memory[OutputElementDelta],
8086
outputDataType: StaticDslType[OutputElementData],
8187
outputDeltaType: StaticDslType[OutputElementDelta],
8288
executor: ExecutionContext): RAIITask[(Int, compiler.ParameterRecord) => RAIITask[
83-
Tape.Aux[OpenCL.Buffer[OutputElementData], OpenCL.Buffer[OutputElementDelta]]]] = throwableMonadic[RAIITask] {
89+
Tape.Aux[PendingBuffer[OutputElementData], PendingBuffer[OutputElementDelta]]]] = throwableMonadic[RAIITask] {
8490

8591
RAIITask.jump().each
8692

@@ -101,20 +107,32 @@ object DifferentiableKernel {
101107
{ (expectedSize: Int, inputParameterMap: compiler.ParameterRecord) =>
102108
throwableMonadic[RAIITask] {
103109
val kernel = forwardKernelTask.each
104-
val outputBuffer =
105-
RAIITask.managed(context.createBuffer[OutputElementData](expectedSize)(outputDataMemory)).each
110+
val outputBuffer = context.createBuffer[OutputElementData](expectedSize)(outputDataMemory)
111+
106112
compiler.setKernelInputArguments(kernel, 1, inputParameterMap)
107113
kernel.setArg(0, outputBuffer)
108-
val event =
114+
115+
RAIITask.unmanaged(semaphore.acquire()).each
116+
val event = try {
109117
RAIITask
110118
.managed(
111119
commandQueue.enqueueNDRangeKernel(kernel, Seq(GlobalWorkSizeOnlyDimension(Address(expectedSize)))))
112120
.each
121+
} catch {
122+
case e if NonFatal(e) =>
123+
semaphore.release().run
124+
(throw e): OpenCL.Event
125+
}
126+
event.waitForComplete().unsafePerformAsync { _ =>
127+
semaphore.release().run
128+
}
129+
113130
RAIITask.unmanaged(event.waitForComplete()).each
114131
new Tape {
115-
override def data: OpenCL.Buffer[OutputElementData] = outputBuffer
132+
// borrow
133+
override val data: PendingBuffer[OutputElementData] = PendingBuffer(outputBuffer, List(event))
116134

117-
override def backward[OutputDeltaBuffer <: OpenCL.Buffer[OutputElementDelta]](
135+
override def backward[OutputDeltaBuffer <: PendingBuffer[OutputElementDelta]](
118136
outputDeltaTask: RAIITask[OutputDeltaBuffer]): Future[Unit] = {
119137
Future.suspend {
120138
Future.now(()) // TODO: backward
@@ -123,9 +141,9 @@ object DifferentiableKernel {
123141
}
124142

125143
// TODO: Change OutputData and OutputDelta to a pair of OpenCL.Buffer and OpenCL.Event
126-
override type Data = OpenCL.Buffer[OutputElementData]
127-
override type Delta = OpenCL.Buffer[OutputElementDelta]
128-
}: Tape.Aux[OpenCL.Buffer[OutputElementData], OpenCL.Buffer[OutputElementDelta]]
144+
override type Data = PendingBuffer[OutputElementData]
145+
override type Delta = PendingBuffer[OutputElementDelta]
146+
}: Tape.Aux[PendingBuffer[OutputElementData], PendingBuffer[OutputElementDelta]]
129147
}
130148
}
131149
}
@@ -291,10 +309,10 @@ object DifferentiableKernel {
291309
}
292310

293311
def bufferIdentifier[Data, Delta](
294-
key: Witness): OpenCLLayer[OpenCL.Buffer[Data],
295-
OpenCL.Buffer[Delta],
312+
key: Witness): OpenCLLayer[PendingBuffer[Data],
313+
PendingBuffer[Delta],
296314
FieldType[key.T, JacobianMatrix[Data, Delta]] :: HNil] = {
297-
OpenCLLayer[OpenCL.Buffer[Data], OpenCL.Buffer[Delta], FieldType[key.T, JacobianMatrix[Data, Delta]] :: HNil](
315+
OpenCLLayer[PendingBuffer[Data], PendingBuffer[Delta], FieldType[key.T, JacobianMatrix[Data, Delta]] :: HNil](
298316
StaticDslExpression(DslExpression.Identifier(key.value)),
299317
field[key.T](JacobianMatrix.Identity[Data, Delta]()) :: HNil
300318
)
@@ -314,7 +332,7 @@ object DifferentiableKernel {
314332
IndexLocalDelta <: HList,
315333
ElementLocalDelta <: HList,
316334
LocalDelta <: HList](
317-
buffer: OpenCLLayer[OpenCL.Buffer[ElementData], OpenCL.Buffer[ElementDelta], BufferLocalDelta],
335+
buffer: OpenCLLayer[PendingBuffer[ElementData], PendingBuffer[ElementDelta], BufferLocalDelta],
318336
index: OpenCLLayer[Int, Float, IndexLocalDelta])(
319337
implicit elementDataType: StaticDslType[ElementData],
320338
zero: Zero.Aux[IndexLocalDelta],
@@ -355,6 +373,8 @@ object DifferentiableKernel {
355373
def forwardParameter: Parameter
356374

357375
def setArgument(kernel: Kernel, index: Int, input: Input): Unit
376+
377+
def borrowEvents(input: Input): List[OpenCL.Event]
358378
}
359379

360380
object InputCompiler {
@@ -368,14 +388,19 @@ object DifferentiableKernel {
368388
elementDataType: StaticDslType[InputElementData])
369389
: InputCompiler.Aux[Key,
370390
JacobianMatrix.Row[InputElementData, InputElementDelta],
371-
Tape.Aux[OpenCL.Buffer[InputElementData], OpenCL.Buffer[InputElementDelta]]] =
391+
Tape.Aux[PendingBuffer[InputElementData], PendingBuffer[InputElementDelta]]] =
372392
new InputCompiler[Key, JacobianMatrix.Row[InputElementData, InputElementDelta]] {
373393

374-
override type Input = Tape.Aux[OpenCL.Buffer[InputElementData], OpenCL.Buffer[InputElementDelta]]
394+
override type Input = Tape.Aux[PendingBuffer[InputElementData], PendingBuffer[InputElementDelta]]
375395
override def forwardParameter: Parameter = Parameter(witness.value, DslType.DslBuffer(elementDataType))
376396

377397
override def setArgument(kernel: Kernel, index: Int, input: Input): Unit = {
378-
kernel.setArg[OpenCL.Buffer[InputElementData]](index, input.data)
398+
kernel.setArg[OpenCL.Buffer[InputElementData]](index, input.data.buffer)
399+
}
400+
401+
override def borrowEvents(
402+
input: Tape.Aux[PendingBuffer[InputElementData], PendingBuffer[InputElementDelta]]): List[OpenCL.Event] = {
403+
input.data.events
379404
}
380405
}
381406

@@ -395,6 +420,8 @@ object DifferentiableKernel {
395420
Parameter(OutputId, DslType.DslBuffer(outputDataType)) :: forwardInputParameters
396421

397422
def setKernelInputArguments(kernel: Kernel, startIndex: Int, parameters: ParameterRecord)
423+
424+
def borrowEvents(parameters: ParameterRecord): List[OpenCL.Event]
398425
}
399426

400427
object Compiler {
@@ -415,6 +442,8 @@ object DifferentiableKernel {
415442
override def forwardInputParameters: Nil.type = Nil
416443

417444
override def setKernelInputArguments(kernel: Kernel, startIndex: Int, parameters: HNil): Unit = {}
445+
446+
override def borrowEvents(parameters: HNil): List[OpenCL.Event] = Nil
418447
}
419448

420449
implicit def hconsFill[OutputElementData,
@@ -441,6 +470,10 @@ object DifferentiableKernel {
441470
headInputCompiler.setArgument(kernel, startIndex, parameters.head)
442471
tailCompiler.setKernelInputArguments(kernel, startIndex + 1, parameters.tail)
443472
}
473+
474+
override def borrowEvents(parameters: ::[FieldType[Key, Input], TailParameterRecord]): List[OpenCL.Event] = {
475+
headInputCompiler.borrowEvents(parameters.head) ::: tailCompiler.borrowEvents(parameters.tail)
476+
}
444477
}
445478

446479
}

DifferentiableKernel/src/test/scala/com/thoughtworks/deeplearning/DifferentiableKernelSpec.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class DifferentiableKernelSpec extends AsyncFreeSpec with Matchers {
5656

5757
(context, commandQueue)
5858
}
59+
val semaphore = AsynchronousSemaphore(3)
5960

6061
"When fill a buffer with 42.0f" - {
6162
val differentiableKernel = {
@@ -71,12 +72,14 @@ class DifferentiableKernelSpec extends AsyncFreeSpec with Matchers {
7172
RAIITask.unmanaged(
7273
RAIITask.run(
7374
throwableMonadic[RAIITask] {
74-
val layer = differentiableKernel.compile(context, device, commandQueue).each
75+
val layer = differentiableKernel.compile(context, device, commandQueue, semaphore).each
7576
val outputTape = layer(1, HNil).each
76-
val delta = RAIITask.managed(context.createBuffer[Float](1))
77+
val delta = RAIITask.managed(context.createBuffer[Float](1)).map(PendingBuffer(_, Nil))
7778
RAIITask.unmanaged(outputTape.backward(delta)).each
7879
val f = BufferUtils.createFloatBuffer(1)
79-
val event = RAIITask.managed(commandQueue.enqueueReadBuffer(outputTape.data, f)).each
80+
val event = RAIITask
81+
.managed(commandQueue.enqueueReadBuffer(outputTape.data.buffer, f, outputTape.data.events: _*))
82+
.each
8083
RAIITask.unmanaged(event.waitForComplete()).each
8184
f
8285
}
@@ -119,12 +122,14 @@ class DifferentiableKernelSpec extends AsyncFreeSpec with Matchers {
119122
RAIITask.unmanaged(
120123
RAIITask.run(
121124
throwableMonadic[RAIITask] {
122-
val layer = differentiableKernel.compile(context, device, commandQueue).each
125+
val layer = differentiableKernel.compile(context, device, commandQueue, semaphore).each
123126
val outputTape = layer(1, ??? :: HNil).each
124-
val delta = RAIITask.managed(context.createBuffer[Float](1))
127+
val delta = RAIITask.managed(context.createBuffer[Float](1)).map(PendingBuffer(_, Nil))
125128
RAIITask.unmanaged(outputTape.backward(delta)).each
126129
val f = BufferUtils.createFloatBuffer(1)
127-
val event = RAIITask.managed(commandQueue.enqueueReadBuffer(outputTape.data, f)).each
130+
val event = RAIITask
131+
.managed(commandQueue.enqueueReadBuffer(outputTape.data.buffer, f, outputTape.data.events: _*))
132+
.each
128133
RAIITask.unmanaged(event.waitForComplete()).each
129134
f
130135
}

FutureIsomorphism/build.sbt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
addCompilerPlugin("org.spire-math" %% "kind-projector" % "0.9.3")
2+
3+
libraryDependencies += "org.scalaz" %% "scalaz-concurrent" % "7.2.10"
4+
5+
libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.1" % Test
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package com.thoughtworks.deeplearning
2+
3+
import scalaz.{ContT, Trampoline}
4+
import scalaz.Free.Trampoline
5+
import scalaz.concurrent.Future
6+
7+
/**
8+
* @author 杨博 (Yang Bo) &lt;pop.atry@gmail.com&gt;
9+
*/
10+
object FutureIsomorphism extends scalaz.Isomorphism.IsoFunctorTemplate[Future, ContT[Trampoline, Unit, ?]] {
11+
override def to[A](fa: Future[A]): ContT[Trampoline, Unit, A] = ContT[Trampoline, Unit, A] { continue =>
12+
Trampoline.delay(fa.unsafePerformListen(continue))
13+
}
14+
15+
override def from[A](ga: ContT[Trampoline, Unit, A]): Future[A] = {
16+
Future.Async { continue =>
17+
ga(continue).run
18+
}
19+
}
20+
}

OpenCL/src/main/scala/com/thoughtworks/deeplearning/OpenCL.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,10 @@ data: $data""")
495495

496496
final class Event(val handle: Address) extends AssertionAutoCloseable with AssertionFinalizer {
497497

498-
def duplicate = new Event(handle)
498+
def duplicate(): Event = {
499+
checkErrorCode(clRetainEvent(handle.toLong))
500+
new Event(handle)
501+
}
499502

500503
override protected def forceClose(): Unit = {
501504
checkErrorCode(clReleaseEvent(handle.toLong))

build.sbt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ lazy val DifferentiableKernel =
55
project.dependsOn(
66
OpenCL,
77
OpenCLCodeGenerator,
8-
TapeTaskFactory
8+
TapeTaskFactory,
9+
FutureIsomorphism,
10+
AsynchronousSemaphore
911
)
1012

1113
lazy val OpenCLCodeGenerator = project.dependsOn(Memory)
@@ -40,6 +42,10 @@ lazy val TapeTask = project.dependsOn(Tape, ProjectRef(file("RAII.scala"), "RAII
4042

4143
lazy val LogRecords = project
4244

45+
lazy val AsynchronousSemaphore = project
46+
47+
lazy val FutureIsomorphism = project
48+
4349
//lazy val DifferentiableDouble =
4450
// project.dependsOn(Layer,
4551
// CumulativeTape,

0 commit comments

Comments
 (0)