Skip to content

Commit b2028d6

Browse files
committed
Refactor execution model
1 parent 4782275 commit b2028d6

File tree

9 files changed

+260
-155
lines changed

9 files changed

+260
-155
lines changed

src/main/kotlin/org/jetbrains/kotlinx/jupyter/config.kt

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import kotlinx.serialization.json.JsonPrimitive
1313
import kotlinx.serialization.json.decodeFromJsonElement
1414
import kotlinx.serialization.serializer
1515
import org.jetbrains.kotlinx.jupyter.api.KotlinKernelVersion
16+
import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterSocket
1617
import org.jetbrains.kotlinx.jupyter.common.getNameForUser
1718
import org.jetbrains.kotlinx.jupyter.config.getLogger
1819
import org.jetbrains.kotlinx.jupyter.config.readResourceAsIniFile
@@ -36,12 +37,12 @@ val defaultRuntimeProperties by lazy {
3637
RuntimeKernelProperties(readResourceAsIniFile("runtime.properties"))
3738
}
3839

39-
enum class JupyterSocketInfo(val zmqKernelType: SocketType, val zmqClientType: SocketType) {
40-
HB(SocketType.REP, SocketType.REQ),
41-
SHELL(SocketType.ROUTER, SocketType.REQ),
42-
CONTROL(SocketType.ROUTER, SocketType.REQ),
43-
STDIN(SocketType.ROUTER, SocketType.REQ),
44-
IOPUB(SocketType.PUB, SocketType.SUB);
40+
enum class JupyterSocketInfo(val type: JupyterSocket, val zmqKernelType: SocketType, val zmqClientType: SocketType) {
41+
HB(JupyterSocket.HB, SocketType.REP, SocketType.REQ),
42+
SHELL(JupyterSocket.SHELL, SocketType.ROUTER, SocketType.REQ),
43+
CONTROL(JupyterSocket.CONTROL, SocketType.ROUTER, SocketType.REQ),
44+
STDIN(JupyterSocket.STDIN, SocketType.ROUTER, SocketType.REQ),
45+
IOPUB(JupyterSocket.IOPUB, SocketType.PUB, SocketType.SUB);
4546

4647
val nameForUser = getNameForUser(name)
4748
}

src/main/kotlin/org/jetbrains/kotlinx/jupyter/connection.kt

Lines changed: 20 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
package org.jetbrains.kotlinx.jupyter
22

3-
import kotlinx.coroutines.CoroutineScope
4-
import kotlinx.coroutines.Dispatchers
5-
import kotlinx.coroutines.launch
63
import kotlinx.serialization.decodeFromString
74
import kotlinx.serialization.encodeToString
85
import kotlinx.serialization.json.Json
@@ -15,7 +12,6 @@ import org.jetbrains.kotlinx.jupyter.api.libraries.RawMessage
1512
import org.jetbrains.kotlinx.jupyter.api.libraries.RawMessageCallback
1613
import org.jetbrains.kotlinx.jupyter.api.libraries.header
1714
import org.jetbrains.kotlinx.jupyter.api.libraries.type
18-
import org.jetbrains.kotlinx.jupyter.exceptions.ReplException
1915
import org.jetbrains.kotlinx.jupyter.messaging.InputReply
2016
import org.jetbrains.kotlinx.jupyter.messaging.InputRequest
2117
import org.jetbrains.kotlinx.jupyter.messaging.JupyterConnectionInternal
@@ -29,6 +25,7 @@ import org.jetbrains.kotlinx.jupyter.messaging.emptyJsonObjectStringBytes
2925
import org.jetbrains.kotlinx.jupyter.messaging.jsonObject
3026
import org.jetbrains.kotlinx.jupyter.messaging.makeJsonHeader
3127
import org.jetbrains.kotlinx.jupyter.messaging.makeReplyMessage
28+
import org.jetbrains.kotlinx.jupyter.messaging.makeSimpleMessage
3229
import org.jetbrains.kotlinx.jupyter.messaging.sendMessage
3330
import org.jetbrains.kotlinx.jupyter.messaging.toMessage
3431
import org.jetbrains.kotlinx.jupyter.messaging.toRawMessage
@@ -37,10 +34,10 @@ import org.zeromq.SocketType
3734
import org.zeromq.ZMQ
3835
import java.io.Closeable
3936
import java.io.IOException
37+
import java.io.InputStream
4038
import java.security.SignatureException
4139
import javax.crypto.Mac
4240
import javax.crypto.spec.SecretKeySpec
43-
import kotlin.concurrent.thread
4441
import kotlin.math.min
4542

4643
typealias SocketMessageCallback = JupyterConnectionImpl.Socket.(Message) -> Unit
@@ -106,14 +103,8 @@ class JupyterConnectionImpl(
106103
}
107104
}
108105

109-
fun sendStatus(status: KernelStatus, msg: Message) {
110-
connection.iopub.sendMessage(makeReplyMessage(msg, MessageType.STATUS, content = StatusReply(status)))
111-
}
112-
113106
fun sendWrapped(incomingMessage: Message, msg: Message) {
114-
sendStatus(KernelStatus.BUSY, incomingMessage)
115-
sendMessage(msg)
116-
sendStatus(KernelStatus.IDLE, incomingMessage)
107+
doWrappedInBusyIdle(incomingMessage) { sendMessage(msg) }
117108
}
118109

119110
override fun sendRawMessage(msg: RawMessage) {
@@ -140,7 +131,7 @@ class JupyterConnectionImpl(
140131
override val connection: JupyterConnectionImpl = this@JupyterConnectionImpl
141132
}
142133

143-
inner class StdinInputStream : java.io.InputStream() {
134+
inner class StdinInputStream : InputStream() {
144135
private var currentBuf: ByteArray? = null
145136
private var currentBufPos = 0
146137

@@ -263,57 +254,26 @@ class JupyterConnectionImpl(
263254
send(socketName, message)
264255
}
265256

266-
val stdinIn = StdinInputStream()
267-
268-
var contextMessage: Message? = null
269-
270-
private val currentExecutions = HashSet<Thread>()
271-
private val coroutineScope = CoroutineScope(Dispatchers.Default)
272-
273-
data class ConnectionExecutionResult<T>(
274-
val result: T?,
275-
val throwable: Throwable?,
276-
val isInterrupted: Boolean,
277-
)
278-
279-
fun <T> runExecution(body: () -> T, classLoader: ClassLoader): ConnectionExecutionResult<T> {
280-
var execRes: T? = null
281-
var execException: Throwable? = null
282-
val execThread = thread(contextClassLoader = classLoader) {
283-
try {
284-
execRes = body()
285-
} catch (e: Throwable) {
286-
execException = e
287-
}
288-
}
289-
currentExecutions.add(execThread)
290-
execThread.join()
291-
currentExecutions.remove(execThread)
292-
293-
val exception = execException
294-
val isInterrupted = exception is ThreadDeath ||
295-
(exception is ReplException && exception.cause is ThreadDeath)
296-
return ConnectionExecutionResult(execRes, exception, isInterrupted)
257+
override fun sendStatus(status: KernelStatus, incomingMessage: Message?) {
258+
val message = if (incomingMessage != null) makeReplyMessage(incomingMessage, MessageType.STATUS, content = StatusReply(status))
259+
else makeSimpleMessage(MessageType.STATUS, content = StatusReply(status))
260+
iopub.sendMessage(message)
297261
}
298262

299-
/**
300-
* We cannot use [Thread.interrupt] here because we have no way
301-
* to control the code user executes. [Thread.interrupt] will do nothing for
302-
* the simple calculation (like `while (true) 1`). Consider replacing with
303-
* something more smart in the future.
304-
*/
305-
fun interruptExecution() {
306-
@Suppress("deprecation")
307-
while (currentExecutions.isNotEmpty()) {
308-
val execution = currentExecutions.firstOrNull()
309-
execution?.stop()
310-
currentExecutions.remove(execution)
263+
override fun doWrappedInBusyIdle(incomingMessage: Message?, action: () -> Unit) {
264+
sendStatus(KernelStatus.BUSY, incomingMessage)
265+
try {
266+
action()
267+
} finally {
268+
sendStatus(KernelStatus.IDLE, incomingMessage)
311269
}
312270
}
313271

314-
fun launchJob(runnable: suspend CoroutineScope.() -> Unit) {
315-
coroutineScope.launch(block = runnable)
316-
}
272+
override val stdinIn = StdinInputStream()
273+
274+
var contextMessage: Message? = null
275+
276+
override val executor: JupyterExecutor = JupyterExecutorImpl()
317277

318278
override fun close() {
319279
heartbeat.close()
@@ -399,7 +359,7 @@ fun ZMQ.Socket.receiveRawMessage(start: ByteArray, hmac: HMAC): RawMessage {
399359
)
400360
}
401361

402-
object DisabledStdinInputStream : java.io.InputStream() {
362+
object DisabledStdinInputStream : InputStream() {
403363
override fun read(): Int {
404364
throw IOException("Input from stdin is unsupported by the client")
405365
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package org.jetbrains.kotlinx.jupyter
2+
3+
import kotlinx.coroutines.CoroutineScope
4+
import kotlinx.coroutines.Dispatchers
5+
import kotlinx.coroutines.launch
6+
import org.jetbrains.kotlinx.jupyter.exceptions.ReplException
7+
import java.util.Collections
8+
import java.util.concurrent.ConcurrentHashMap
9+
import kotlin.concurrent.thread
10+
11+
sealed interface ExecutionResult<out T> {
12+
class Success<out T>(val result: T) : ExecutionResult<T>
13+
class Failure(val throwable: Throwable) : ExecutionResult<Nothing>
14+
object Interrupted : ExecutionResult<Nothing>
15+
}
16+
17+
interface JupyterExecutor {
18+
fun <T> runExecution(classLoader: ClassLoader? = null, body: () -> T): ExecutionResult<T>
19+
fun interruptExecutions()
20+
21+
fun launchJob(runnable: suspend CoroutineScope.() -> Unit)
22+
}
23+
24+
class JupyterExecutorImpl : JupyterExecutor {
25+
private val currentExecutions: MutableSet<Thread> = Collections.newSetFromMap(ConcurrentHashMap())
26+
private val coroutineScope = CoroutineScope(Dispatchers.Default)
27+
28+
override fun <T> runExecution(classLoader: ClassLoader?, body: () -> T): ExecutionResult<T> {
29+
var execRes: T? = null
30+
var execException: Throwable? = null
31+
val execThread = thread(contextClassLoader = classLoader ?: Thread.currentThread().contextClassLoader) {
32+
try {
33+
execRes = body()
34+
} catch (e: Throwable) {
35+
execException = e
36+
}
37+
}
38+
currentExecutions.add(execThread)
39+
execThread.join()
40+
currentExecutions.remove(execThread)
41+
42+
val exception = execException
43+
44+
return if (exception == null) {
45+
ExecutionResult.Success(execRes!!)
46+
} else {
47+
val isInterrupted = exception is ThreadDeath ||
48+
(exception is ReplException && exception.cause is ThreadDeath)
49+
if (isInterrupted) ExecutionResult.Interrupted
50+
else ExecutionResult.Failure(exception)
51+
}
52+
}
53+
54+
/**
55+
* We cannot use [Thread.interrupt] here because we have no way
56+
* to control the code user executes. [Thread.interrupt] will do nothing for
57+
* the simple calculation (like `while (true) 1`). Consider replacing with
58+
* something more smart in the future.
59+
*/
60+
override fun interruptExecutions() {
61+
@Suppress("deprecation")
62+
while (currentExecutions.isNotEmpty()) {
63+
val execution = currentExecutions.firstOrNull()
64+
execution?.stop()
65+
currentExecutions.remove(execution)
66+
}
67+
}
68+
69+
override fun launchJob(runnable: suspend CoroutineScope.() -> Unit) {
70+
coroutineScope.launch(block = runnable)
71+
}
72+
}

src/main/kotlin/org/jetbrains/kotlinx/jupyter/messaging/CommManagerImpl.kt

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package org.jetbrains.kotlinx.jupyter.messaging
22

33
import kotlinx.serialization.json.JsonObject
4+
import kotlinx.serialization.json.JsonPrimitive
45
import org.jetbrains.kotlinx.jupyter.api.libraries.Comm
56
import org.jetbrains.kotlinx.jupyter.api.libraries.CommCloseCallback
67
import org.jetbrains.kotlinx.jupyter.api.libraries.CommManager
@@ -11,9 +12,9 @@ import java.util.concurrent.ConcurrentHashMap
1112
import java.util.concurrent.CopyOnWriteArrayList
1213

1314
interface CommManagerInternal : CommManager {
14-
fun processCommOpen(target: String, id: String, data: JsonObject): Comm
15-
fun processCommMessage(id: String, data: JsonObject)
16-
fun processCommClose(id: String, data: JsonObject)
15+
fun processCommOpen(message: Message, content: CommOpen): Comm?
16+
fun processCommMessage(message: Message, content: CommMsg)
17+
fun processCommClose(message: Message, content: CommClose)
1718
}
1819

1920
class CommManagerImpl(private val connection: JupyterConnectionInternal) : CommManagerInternal {
@@ -25,26 +26,51 @@ class CommManagerImpl(private val connection: JupyterConnectionInternal) : CommM
2526

2627
override fun openComm(target: String, data: JsonObject): Comm {
2728
val id = UUID.randomUUID().toString()
28-
val newComm = processCommOpen(target, id, data)
29+
val newComm = registerNewComm(target, id)
2930

3031
// send comm_open
3132
iopub.sendSimpleMessage(
3233
MessageType.COMM_OPEN,
33-
CommOpen(newComm.id, newComm.target)
34+
CommOpen(newComm.id, newComm.target, data)
3435
)
3536

3637
return newComm
3738
}
3839

39-
override fun processCommOpen(target: String, id: String, data: JsonObject): Comm {
40+
override fun processCommOpen(message: Message, content: CommOpen): Comm? {
41+
val target = content.targetName
42+
val id = content.commId
43+
val data = content.data
44+
45+
val callback = commOpenCallbacks[target]
46+
if (callback == null) {
47+
// If no callback is registered, we should send `comm_close` immediately in response.
48+
iopub.sendSimpleMessage(
49+
MessageType.COMM_CLOSE,
50+
CommClose(id, commFailureJson("Target $target was not registered"))
51+
)
52+
return null
53+
}
54+
55+
val newComm = registerNewComm(target, id)
56+
try {
57+
callback(newComm, data)
58+
} catch (e: Throwable) {
59+
iopub.sendSimpleMessage(
60+
MessageType.COMM_CLOSE,
61+
CommClose(id, commFailureJson("Unable to crete comm $id (with target $target), exception was thrown: ${e.stackTraceToString()}"))
62+
)
63+
removeComm(id)
64+
}
65+
66+
return newComm
67+
}
68+
69+
private fun registerNewComm(target: String, id: String): Comm {
4070
val commIds = commTargetToIds.getOrPut(target) { CopyOnWriteArrayList() }
4171
val newComm = CommImpl(target, id)
4272
commIds.add(id)
4373
commIdToComm[id] = newComm
44-
45-
val callback = commOpenCallbacks[target]
46-
callback?.invoke(newComm, data)
47-
4874
return newComm
4975
}
5076

@@ -53,9 +79,9 @@ class CommManagerImpl(private val connection: JupyterConnectionInternal) : CommM
5379
comm.close(data, notifyClient = true)
5480
}
5581

56-
override fun processCommClose(id: String, data: JsonObject) {
57-
val comm = commIdToComm[id] ?: return
58-
comm.close(data, notifyClient = false)
82+
override fun processCommClose(message: Message, content: CommClose) {
83+
val comm = commIdToComm[content.commId] ?: return
84+
comm.close(content.data, notifyClient = false)
5985
}
6086

6187
fun removeComm(id: String) {
@@ -73,8 +99,8 @@ class CommManagerImpl(private val connection: JupyterConnectionInternal) : CommM
7399
}
74100
}
75101

76-
override fun processCommMessage(id: String, data: JsonObject) {
77-
commIdToComm[id]?.messageReceived(data)
102+
override fun processCommMessage(message: Message, content: CommMsg) {
103+
commIdToComm[content.commId]?.messageReceived(message, content.data)
78104
}
79105

80106
override fun registerCommTarget(target: String, callback: (Comm, JsonObject) -> Unit) {
@@ -144,11 +170,24 @@ class CommManagerImpl(private val connection: JupyterConnectionInternal) : CommM
144170
}
145171
}
146172

147-
fun messageReceived(data: JsonObject) {
173+
fun messageReceived(message: Message, data: JsonObject) {
148174
if (closed) return
149-
for (callback in onMessageCallbacks) {
150-
callback(data)
175+
176+
connection.doWrappedInBusyIdle(message) {
177+
for (callback in onMessageCallbacks) {
178+
callback(data)
179+
}
151180
}
152181
}
153182
}
183+
184+
companion object {
185+
private fun commFailureJson(errorMessage: String): JsonObject {
186+
return JsonObject(
187+
mapOf(
188+
"error" to JsonPrimitive(errorMessage)
189+
)
190+
)
191+
}
192+
}
154193
}

0 commit comments

Comments
 (0)