1
1
package com .thoughtworks .deeplearning
2
2
3
+ import java .util .concurrent .Semaphore
4
+
5
+ import com .thoughtworks .deeplearning .Closeables .{AssertionAutoCloseable , AssertionFinalizer }
3
6
import com .thoughtworks .deeplearning .DifferentiableKernel .Zero .FloatZero
4
7
import com .thoughtworks .deeplearning .Memory .Address
5
8
import com .thoughtworks .deeplearning .OpenCL .CommandQueue .GlobalWorkSizeOnlyDimension
6
9
import com .thoughtworks .deeplearning .OpenCL .{CommandQueue , Device , Kernel }
7
10
import com .thoughtworks .deeplearning .OpenCLCodeGenerator .DslType .{DslBuffer , DslDouble , DslFloat , DslInt }
8
11
import com .thoughtworks .deeplearning .OpenCLCodeGenerator ._
9
12
import com .thoughtworks .each .Monadic ._
10
- import com .thoughtworks .raii .RAIITask
13
+ import com .thoughtworks .raii .ResourceFactoryT .ResourceT
14
+ import com .thoughtworks .raii .{RAIITask , ResourceFactoryT }
11
15
import shapeless .labelled ._
12
16
import shapeless ._
13
17
14
18
import scala .concurrent .ExecutionContext
15
19
import scala .util .control .NonFatal
16
- import scalaz .{@@ , Monad , Monoid }
20
+ import scalaz .{@@ , Monad , Monoid , \ / , \ /- }
17
21
import scalaz .Tags .{Multiplication , Parallel }
18
22
import scalaz .concurrent .Future
19
23
import scalaz .concurrent .Future .{ParallelFuture , futureParallelApplicativeInstance }
@@ -24,13 +28,15 @@ import scala.language.higherKinds
24
28
25
29
object DifferentiableKernel {
26
30
31
/** An OpenCL buffer bundled with a list of OpenCL events.
  *
  * @param buffer the underlying OpenCL buffer holding `Element` values
  * @param events events associated with `buffer`; NOTE(review): presumably the operations that must
  *               complete before the buffer contents are ready to use — confirm with producers/consumers
  * @tparam Element the element type stored in the buffer
  */
+ final case class PendingBuffer [Element ](buffer : OpenCL .Buffer [Element ], events : List [OpenCL .Event ])
32
+
27
33
/** Associates each supported Scala type `A` with a subtype of `DslType`.
  *
  * Declares the abstract type constructor `AbstractType` together with the implicit instances that an
  * implementation must provide for `Double`, `Float`, `Int`, and for buffers whose element type
  * already has an instance.
  */
private [DifferentiableKernel ] trait StaticDslTypeExtractor {
28
34
type AbstractType [A ] <: DslType
29
35
30
36
implicit def dslDouble : AbstractType [Double ]
31
37
implicit def dslFloat : AbstractType [Float ]
32
38
implicit def dslInt : AbstractType [Int ]
33
- implicit def dslBuffer [Element : AbstractType ]: AbstractType [OpenCL . Buffer [Element ]]
39
+ implicit def dslBuffer [Element : AbstractType ]: AbstractType [PendingBuffer [Element ]]
34
40
}
35
41
36
42
private [DifferentiableKernel ] trait StaticDslExpressionExtractor {
@@ -73,14 +79,14 @@ object DifferentiableKernel {
73
79
74
80
import OpenCLLayer ._
75
81
76
- def compile (context : OpenCL .Context , device : Device , commandQueue : CommandQueue )(
82
+ def compile (context : OpenCL .Context , device : Device , commandQueue : CommandQueue , semaphore : AsynchronousSemaphore )(
77
83
implicit compiler : Compiler [OutputElementData , OutputElementDelta , LocalDelta ],
78
84
outputDataMemory : Memory [OutputElementData ],
79
85
outputDeltaMemory : Memory [OutputElementDelta ],
80
86
outputDataType : StaticDslType [OutputElementData ],
81
87
outputDeltaType : StaticDslType [OutputElementDelta ],
82
88
executor : ExecutionContext ): RAIITask [(Int , compiler.ParameterRecord ) => RAIITask [
83
- Tape .Aux [OpenCL . Buffer [OutputElementData ], OpenCL . Buffer [OutputElementDelta ]]]] = throwableMonadic[RAIITask ] {
89
+ Tape .Aux [PendingBuffer [OutputElementData ], PendingBuffer [OutputElementDelta ]]]] = throwableMonadic[RAIITask ] {
84
90
85
91
RAIITask .jump().each
86
92
@@ -101,20 +107,32 @@ object DifferentiableKernel {
101
107
{ (expectedSize : Int , inputParameterMap : compiler.ParameterRecord ) =>
102
108
throwableMonadic[RAIITask ] {
103
109
val kernel = forwardKernelTask.each
104
- val outputBuffer =
105
- RAIITask .managed(context.createBuffer[ OutputElementData ](expectedSize)(outputDataMemory)).each
110
+ val outputBuffer = context.createBuffer[ OutputElementData ](expectedSize)(outputDataMemory)
111
+
106
112
compiler.setKernelInputArguments(kernel, 1 , inputParameterMap)
107
113
kernel.setArg(0 , outputBuffer)
108
- val event =
114
+
115
+ RAIITask .unmanaged(semaphore.acquire()).each
116
+ val event = try {
109
117
RAIITask
110
118
.managed(
111
119
commandQueue.enqueueNDRangeKernel(kernel, Seq (GlobalWorkSizeOnlyDimension (Address (expectedSize)))))
112
120
.each
121
+ } catch {
122
+ case e if NonFatal (e) =>
123
+ semaphore.release().run
124
+ (throw e): OpenCL .Event
125
+ }
126
+ event.waitForComplete().unsafePerformAsync { _ =>
127
+ semaphore.release().run
128
+ }
129
+
113
130
RAIITask .unmanaged(event.waitForComplete()).each
114
131
new Tape {
115
- override def data : OpenCL .Buffer [OutputElementData ] = outputBuffer
132
+ // borrow
133
+ override val data : PendingBuffer [OutputElementData ] = PendingBuffer (outputBuffer, List (event))
116
134
117
- override def backward [OutputDeltaBuffer <: OpenCL . Buffer [OutputElementDelta ]](
135
+ override def backward [OutputDeltaBuffer <: PendingBuffer [OutputElementDelta ]](
118
136
outputDeltaTask : RAIITask [OutputDeltaBuffer ]): Future [Unit ] = {
119
137
Future .suspend {
120
138
Future .now(()) // TODO: backward
@@ -123,9 +141,9 @@ object DifferentiableKernel {
123
141
}
124
142
125
143
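// NOTE(review): Data and Delta below are PendingBuffer values, each pairing an OpenCL.Buffer with a list of OpenCL.Event; the earlier TODO asking for that pairing appears to be done — confirm and drop this note.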
126
- override type Data = OpenCL . Buffer [OutputElementData ]
127
- override type Delta = OpenCL . Buffer [OutputElementDelta ]
128
- }: Tape .Aux [OpenCL . Buffer [OutputElementData ], OpenCL . Buffer [OutputElementDelta ]]
144
+ override type Data = PendingBuffer [OutputElementData ]
145
+ override type Delta = PendingBuffer [OutputElementDelta ]
146
+ }: Tape .Aux [PendingBuffer [OutputElementData ], PendingBuffer [OutputElementDelta ]]
129
147
}
130
148
}
131
149
}
@@ -291,10 +309,10 @@ object DifferentiableKernel {
291
309
}
292
310
293
311
def bufferIdentifier [Data , Delta ](
294
- key : Witness ): OpenCLLayer [OpenCL . Buffer [Data ],
295
- OpenCL . Buffer [Delta ],
312
+ key : Witness ): OpenCLLayer [PendingBuffer [Data ],
313
+ PendingBuffer [Delta ],
296
314
FieldType [key.T , JacobianMatrix [Data , Delta ]] :: HNil ] = {
297
- OpenCLLayer [OpenCL . Buffer [Data ], OpenCL . Buffer [Delta ], FieldType [key.T , JacobianMatrix [Data , Delta ]] :: HNil ](
315
+ OpenCLLayer [PendingBuffer [Data ], PendingBuffer [Delta ], FieldType [key.T , JacobianMatrix [Data , Delta ]] :: HNil ](
298
316
StaticDslExpression (DslExpression .Identifier (key.value)),
299
317
field[key.T ](JacobianMatrix .Identity [Data , Delta ]()) :: HNil
300
318
)
@@ -314,7 +332,7 @@ object DifferentiableKernel {
314
332
IndexLocalDelta <: HList ,
315
333
ElementLocalDelta <: HList ,
316
334
LocalDelta <: HList ](
317
- buffer : OpenCLLayer [OpenCL . Buffer [ElementData ], OpenCL . Buffer [ElementDelta ], BufferLocalDelta ],
335
+ buffer : OpenCLLayer [PendingBuffer [ElementData ], PendingBuffer [ElementDelta ], BufferLocalDelta ],
318
336
index : OpenCLLayer [Int , Float , IndexLocalDelta ])(
319
337
implicit elementDataType : StaticDslType [ElementData ],
320
338
zero : Zero .Aux [IndexLocalDelta ],
@@ -355,6 +373,8 @@ object DifferentiableKernel {
355
373
/** The `Parameter` describing this input.
  *
  * NOTE(review): presumably the entry this input contributes to the forward kernel's parameter list — confirm.
  */
def forwardParameter : Parameter
356
374
357
375
/** Binds `input` to `kernel` as the argument at position `index`.
  *
  * @param kernel the kernel whose argument is being set
  * @param index  the position of the argument to set
  * @param input  the value to bind at that position
  */
def setArgument (kernel : Kernel , index : Int , input : Input ): Unit
376
+
377
/** Returns the OpenCL events carried by `input`.
  *
  * NOTE(review): the name "borrow" suggests the caller does not take ownership of the events — confirm.
  */
+ def borrowEvents (input : Input ): List [OpenCL .Event ]
358
378
}
359
379
360
380
object InputCompiler {
@@ -368,14 +388,19 @@ object DifferentiableKernel {
368
388
elementDataType : StaticDslType [InputElementData ])
369
389
: InputCompiler .Aux [Key ,
370
390
JacobianMatrix .Row [InputElementData , InputElementDelta ],
371
- Tape .Aux [OpenCL . Buffer [InputElementData ], OpenCL . Buffer [InputElementDelta ]]] =
391
+ Tape .Aux [PendingBuffer [InputElementData ], PendingBuffer [InputElementDelta ]]] =
372
392
new InputCompiler [Key , JacobianMatrix .Row [InputElementData , InputElementDelta ]] {
373
393
374
- override type Input = Tape .Aux [OpenCL . Buffer [InputElementData ], OpenCL . Buffer [InputElementDelta ]]
394
+ override type Input = Tape .Aux [PendingBuffer [InputElementData ], PendingBuffer [InputElementDelta ]]
375
395
/** Builds the parameter for this input from `witness.value` and a buffer type whose elements are `elementDataType`. */
override def forwardParameter : Parameter = Parameter (witness.value, DslType .DslBuffer (elementDataType))
376
396
377
397
/** Sets the kernel argument at `index` to the buffer held by the data of `input`.
  *
  * Only the buffer is passed to the kernel; nothing else carried by the data is consulted here.
  */
override def setArgument (kernel : Kernel , index : Int , input : Input ): Unit = {
378
- kernel.setArg[OpenCL .Buffer [InputElementData ]](index, input.data)
398
+ kernel.setArg[OpenCL .Buffer [InputElementData ]](index, input.data.buffer)
399
+ }
400
+
401
/** Returns the OpenCL events attached to the data of `input`.
  *
  * The parameter is typed with the `Input` alias, matching `setArgument` and the abstract declaration
  * being overridden; the alias denotes the same `Tape.Aux` type that was previously spelled out.
  *
  * @param input the tape whose pending data buffer carries the events
  * @return the event list stored in `input.data`, returned as-is
  */
override def borrowEvents(input: Input): List[OpenCL.Event] = input.data.events
380
405
}
381
406
@@ -395,6 +420,8 @@ object DifferentiableKernel {
395
420
Parameter (OutputId , DslType .DslBuffer (outputDataType)) :: forwardInputParameters
396
421
397
422
/** Binds every input in `parameters` to consecutive arguments of `kernel`, beginning at `startIndex`.
  *
  * The `Unit` result type is written out explicitly instead of relying on procedure syntax, so the
  * declaration matches the result type its overrides already state.
  *
  * @param kernel     the kernel whose arguments are being set
  * @param startIndex the argument position used for the first input
  * @param parameters the record of inputs to bind
  */
def setKernelInputArguments(kernel: Kernel, startIndex: Int, parameters: ParameterRecord): Unit
423
+
424
/** Collects the OpenCL events carried by the inputs in `parameters` into a single list. */
+ def borrowEvents (parameters : ParameterRecord ): List [OpenCL .Event ]
398
425
}
399
426
400
427
object Compiler {
@@ -415,6 +442,8 @@ object DifferentiableKernel {
415
442
/** This case contributes no input parameters, so the list is empty. */
override def forwardInputParameters : Nil .type = Nil
416
443
417
444
/** An `HNil` parameter record has no inputs, so no kernel argument is set. */
override def setKernelInputArguments (kernel : Kernel , startIndex : Int , parameters : HNil ): Unit = {}
445
+
446
/** An `HNil` parameter record has no inputs and therefore no events. */
+ override def borrowEvents (parameters : HNil ): List [OpenCL .Event ] = Nil
418
447
}
419
448
420
449
implicit def hconsFill [OutputElementData ,
@@ -441,6 +470,10 @@ object DifferentiableKernel {
441
470
headInputCompiler.setArgument(kernel, startIndex, parameters.head)
442
471
tailCompiler.setKernelInputArguments(kernel, startIndex + 1 , parameters.tail)
443
472
}
473
+
474
/** Gathers the events of the head parameter followed by those of the remaining parameters.
  *
  * @param parameters a non-empty parameter record: the head field plus the tail record
  * @return the head input's events, then every event reported for the tail
  */
override def borrowEvents(parameters: ::[FieldType[Key, Input], TailParameterRecord]): List[OpenCL.Event] = {
  // Head first, then tail: the same evaluation order as the single-expression form.
  val headEvents = headInputCompiler.borrowEvents(parameters.head)
  val tailEvents = tailCompiler.borrowEvents(parameters.tail)
  headEvents ::: tailEvents
}
444
477
}
445
478
446
479
}
0 commit comments