Skip to content

Commit 4782275

Browse files
committed
Extract CommManager to a separate interface
1 parent e64f0ac commit 4782275

File tree

15 files changed

+290
-231
lines changed

15 files changed

+290
-231
lines changed

jupyter-lib/api/src/main/kotlin/org/jetbrains/kotlinx/jupyter/api/Notebook.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package org.jetbrains.kotlinx.jupyter.api
22

3+
import org.jetbrains.kotlinx.jupyter.api.libraries.CommManager
34
import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterConnection
45
import org.jetbrains.kotlinx.jupyter.api.libraries.LibraryResolutionRequest
56

@@ -102,4 +103,6 @@ interface Notebook {
102103
val libraryRequests: Collection<LibraryResolutionRequest>
103104

104105
val connection: JupyterConnection
106+
107+
val commManager: CommManager
105108
}

jupyter-lib/api/src/main/kotlin/org/jetbrains/kotlinx/jupyter/api/libraries/connection.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ interface JupyterConnection {
8282
* Simpler-to-use version of [send].
8383
*/
8484
fun sendReply(socketName: JupyterSocket, parentMessage: RawMessage, type: String, content: JsonObject, metadata: JsonObject? = null)
85+
}
8586

87+
interface CommManager {
8688
/**
8789
* Creates a comm with a given target, generates unique ID for it. Sends comm_open request to frontend
8890
*

src/main/kotlin/org/jetbrains/kotlinx/jupyter/apiImpl.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import org.jetbrains.kotlinx.jupyter.api.Notebook
1212
import org.jetbrains.kotlinx.jupyter.api.RenderersProcessor
1313
import org.jetbrains.kotlinx.jupyter.api.ResultsAccessor
1414
import org.jetbrains.kotlinx.jupyter.api.VariableState
15+
import org.jetbrains.kotlinx.jupyter.api.libraries.CommManager
1516
import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterConnection
1617
import org.jetbrains.kotlinx.jupyter.api.libraries.LibraryResolutionRequest
1718
import org.jetbrains.kotlinx.jupyter.repl.impl.SharedReplContext
@@ -136,7 +137,8 @@ class EvalData(
136137

137138
class NotebookImpl(
138139
private val runtimeProperties: ReplRuntimeProperties,
139-
override val connection: JupyterConnection
140+
override val connection: JupyterConnection,
141+
override val commManager: CommManager,
140142
) : MutableNotebook {
141143
private val cells = hashMapOf<Int, MutableCodeCell>()
142144
override var sharedReplContext: SharedReplContext? = null

src/main/kotlin/org/jetbrains/kotlinx/jupyter/connection.kt

Lines changed: 23 additions & 190 deletions
Original file line numberDiff line numberDiff line change
@@ -10,37 +10,26 @@ import kotlinx.serialization.json.JsonElement
1010
import kotlinx.serialization.json.JsonNull
1111
import kotlinx.serialization.json.JsonObject
1212
import kotlinx.serialization.json.jsonObject
13-
import org.jetbrains.kotlinx.jupyter.api.libraries.Comm
14-
import org.jetbrains.kotlinx.jupyter.api.libraries.CommCloseCallback
15-
import org.jetbrains.kotlinx.jupyter.api.libraries.CommMsgCallback
16-
import org.jetbrains.kotlinx.jupyter.api.libraries.CommOpenCallback
17-
import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterConnection
1813
import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterSocket
1914
import org.jetbrains.kotlinx.jupyter.api.libraries.RawMessage
2015
import org.jetbrains.kotlinx.jupyter.api.libraries.RawMessageCallback
2116
import org.jetbrains.kotlinx.jupyter.api.libraries.header
2217
import org.jetbrains.kotlinx.jupyter.api.libraries.type
2318
import org.jetbrains.kotlinx.jupyter.exceptions.ReplException
24-
import org.jetbrains.kotlinx.jupyter.messaging.CommClose
25-
import org.jetbrains.kotlinx.jupyter.messaging.CommMsg
26-
import org.jetbrains.kotlinx.jupyter.messaging.CommOpen
2719
import org.jetbrains.kotlinx.jupyter.messaging.InputReply
2820
import org.jetbrains.kotlinx.jupyter.messaging.InputRequest
29-
import org.jetbrains.kotlinx.jupyter.messaging.JupyterOutType
21+
import org.jetbrains.kotlinx.jupyter.messaging.JupyterConnectionInternal
22+
import org.jetbrains.kotlinx.jupyter.messaging.JupyterServerSocket
3023
import org.jetbrains.kotlinx.jupyter.messaging.KernelStatus
3124
import org.jetbrains.kotlinx.jupyter.messaging.Message
32-
import org.jetbrains.kotlinx.jupyter.messaging.MessageContent
33-
import org.jetbrains.kotlinx.jupyter.messaging.MessageData
34-
import org.jetbrains.kotlinx.jupyter.messaging.MessageHeader
3525
import org.jetbrains.kotlinx.jupyter.messaging.MessageType
3626
import org.jetbrains.kotlinx.jupyter.messaging.RawMessageImpl
3727
import org.jetbrains.kotlinx.jupyter.messaging.StatusReply
38-
import org.jetbrains.kotlinx.jupyter.messaging.StreamResponse
3928
import org.jetbrains.kotlinx.jupyter.messaging.emptyJsonObjectStringBytes
4029
import org.jetbrains.kotlinx.jupyter.messaging.jsonObject
41-
import org.jetbrains.kotlinx.jupyter.messaging.makeHeader
4230
import org.jetbrains.kotlinx.jupyter.messaging.makeJsonHeader
4331
import org.jetbrains.kotlinx.jupyter.messaging.makeReplyMessage
32+
import org.jetbrains.kotlinx.jupyter.messaging.sendMessage
4433
import org.jetbrains.kotlinx.jupyter.messaging.toMessage
4534
import org.jetbrains.kotlinx.jupyter.messaging.toRawMessage
4635
import org.jetbrains.kotlinx.jupyter.util.EMPTY
@@ -49,9 +38,6 @@ import org.zeromq.ZMQ
4938
import java.io.Closeable
5039
import java.io.IOException
5140
import java.security.SignatureException
52-
import java.util.UUID
53-
import java.util.concurrent.ConcurrentHashMap
54-
import java.util.concurrent.CopyOnWriteArrayList
5541
import javax.crypto.Mac
5642
import javax.crypto.spec.SecretKeySpec
5743
import kotlin.concurrent.thread
@@ -60,27 +46,20 @@ import kotlin.math.min
6046
typealias SocketMessageCallback = JupyterConnectionImpl.Socket.(Message) -> Unit
6147
typealias SocketRawMessageCallback = JupyterConnectionImpl.Socket.(RawMessage) -> Unit
6248

63-
class JupyterConnectionImpl(val config: KernelConfig) : JupyterConnection, Closeable {
49+
class JupyterConnectionImpl(
50+
val config: KernelConfig
51+
) : JupyterConnectionInternal, Closeable {
6452

65-
private var messageId: List<ByteArray> = listOf(byteArrayOf(1))
66-
private var sessionId = ""
67-
private var username = ""
53+
private var _messageId: List<ByteArray> = listOf(byteArrayOf(1))
54+
override val messageId: List<ByteArray> get() = _messageId
6855

69-
private fun makeDefaultHeader(msgType: MessageType): MessageHeader {
70-
return makeHeader(msgType, sessionId = sessionId, username = username)
71-
}
56+
private var _sessionId = ""
57+
override val sessionId: String get() = _sessionId
7258

73-
fun makeSimpleMessage(msgType: MessageType, content: MessageContent): Message {
74-
return Message(
75-
id = messageId,
76-
data = MessageData(
77-
header = makeDefaultHeader(msgType),
78-
content = content
79-
)
80-
)
81-
}
59+
private var _username = ""
60+
override val username: String get() = _username
8261

83-
inner class Socket(private val socket: JupyterSocketInfo, type: SocketType = socket.zmqKernelType) : ZMQ.Socket(context, type) {
62+
inner class Socket(private val socket: JupyterSocketInfo, type: SocketType = socket.zmqKernelType) : ZMQ.Socket(context, type), JupyterServerSocket {
8463
val name: String get() = socket.name
8564
init {
8665
val port = config.ports[socket.ordinal]
@@ -137,15 +116,7 @@ class JupyterConnectionImpl(val config: KernelConfig) : JupyterConnection, Close
137116
sendStatus(KernelStatus.IDLE, incomingMessage)
138117
}
139118

140-
fun sendOut(msg: Message, stream: JupyterOutType, text: String) {
141-
sendMessage(makeReplyMessage(msg, header = makeHeader(MessageType.STREAM, msg), content = StreamResponse(stream.optionName(), text)))
142-
}
143-
144-
fun sendMessage(msg: Message) {
145-
sendRawMessage(msg.toRawMessage())
146-
}
147-
148-
fun sendRawMessage(msg: RawMessage) {
119+
override fun sendRawMessage(msg: RawMessage) {
149120
log.debug("[$name] snd>: $msg")
150121
sendRawMessage(msg, hmac)
151122
}
@@ -166,7 +137,7 @@ class JupyterConnectionImpl(val config: KernelConfig) : JupyterConnection, Close
166137
}
167138
}
168139

169-
val connection: JupyterConnectionImpl = this@JupyterConnectionImpl
140+
override val connection: JupyterConnectionImpl = this@JupyterConnectionImpl
170141
}
171142

172143
inner class StdinInputStream : java.io.InputStream() {
@@ -227,85 +198,14 @@ class JupyterConnectionImpl(val config: KernelConfig) : JupyterConnection, Close
227198
}
228199
}
229200

230-
inner class CommImpl(
231-
override val target: String,
232-
override val id: String
233-
) : Comm {
234-
235-
private val onMessageCallbacks = mutableListOf<CommMsgCallback>()
236-
private val onCloseCallbacks = mutableListOf<CommCloseCallback>()
237-
private var closed = false
238-
239-
private fun assertOpen() {
240-
if (closed) {
241-
throw AssertionError("Comm '$target' has been already closed")
242-
}
243-
}
244-
override fun send(data: JsonObject) {
245-
assertOpen()
246-
iopub.sendMessage(
247-
makeSimpleMessage(
248-
MessageType.COMM_MSG,
249-
CommMsg(id, data)
250-
)
251-
)
252-
}
253-
254-
override fun onMessage(action: CommMsgCallback): CommMsgCallback {
255-
assertOpen()
256-
onMessageCallbacks.add(action)
257-
return action
258-
}
259-
260-
override fun removeMessageCallback(callback: CommMsgCallback) {
261-
onMessageCallbacks.remove(callback)
262-
}
263-
264-
override fun onClose(action: CommCloseCallback): CommCloseCallback {
265-
assertOpen()
266-
onCloseCallbacks.add(action)
267-
return action
268-
}
269-
270-
override fun removeCloseCallback(callback: CommCloseCallback) {
271-
onCloseCallbacks.remove(callback)
272-
}
273-
274-
override fun close(data: JsonObject, notifyClient: Boolean) {
275-
assertOpen()
276-
closed = true
277-
onMessageCallbacks.clear()
278-
279-
removeComm(id)
280-
281-
onCloseCallbacks.forEach { it(data) }
282-
283-
if (notifyClient) {
284-
iopub.sendMessage(
285-
makeSimpleMessage(
286-
MessageType.COMM_CLOSE,
287-
CommClose(id, data)
288-
)
289-
)
290-
}
291-
}
292-
293-
fun messageReceived(data: JsonObject) {
294-
if (closed) return
295-
for (callback in onMessageCallbacks) {
296-
callback(data)
297-
}
298-
}
299-
}
300-
301201
private val hmac = HMAC(config.signatureScheme.replace("-", ""), config.signatureKey)
302202
private val context = ZMQ.context(1)
303203

304-
val heartbeat = Socket(JupyterSocketInfo.HB)
305-
val shell = Socket(JupyterSocketInfo.SHELL)
306-
val control = Socket(JupyterSocketInfo.CONTROL)
307-
val stdin = Socket(JupyterSocketInfo.STDIN)
308-
val iopub = Socket(JupyterSocketInfo.IOPUB)
204+
override val heartbeat = Socket(JupyterSocketInfo.HB)
205+
override val shell = Socket(JupyterSocketInfo.SHELL)
206+
override val control = Socket(JupyterSocketInfo.CONTROL)
207+
override val stdin = Socket(JupyterSocketInfo.STDIN)
208+
override val iopub = Socket(JupyterSocketInfo.IOPUB)
309209

310210
private fun fromSocketName(socket: JupyterSocket): Socket {
311211
return when (socket) {
@@ -338,9 +238,9 @@ class JupyterConnectionImpl(val config: KernelConfig) : JupyterConnection, Close
338238

339239
fun updateSessionInfo(message: Message) {
340240
val header = message.data.header ?: return
341-
header.session?.let { sessionId = it }
342-
header.username?.let { username = it }
343-
messageId = message.id
241+
header.session?.let { _sessionId = it }
242+
header.username?.let { _username = it }
243+
_messageId = message.id
344244
}
345245

346246
override fun send(socketName: JupyterSocket, message: RawMessage) {
@@ -363,73 +263,6 @@ class JupyterConnectionImpl(val config: KernelConfig) : JupyterConnection, Close
363263
send(socketName, message)
364264
}
365265

366-
private val commOpenCallbacks = ConcurrentHashMap<String, CommOpenCallback>()
367-
private val commTargetToIds = ConcurrentHashMap<String, CopyOnWriteArrayList<String>>()
368-
private val commIdToComm = ConcurrentHashMap<String, CommImpl>()
369-
override fun openComm(target: String, data: JsonObject): Comm {
370-
val id = UUID.randomUUID().toString()
371-
val newComm = processCommOpen(target, id, data)
372-
373-
// send comm_open
374-
iopub.sendMessage(
375-
makeSimpleMessage(
376-
MessageType.COMM_OPEN,
377-
CommOpen(newComm.id, newComm.target)
378-
)
379-
)
380-
381-
return newComm
382-
}
383-
384-
fun processCommOpen(target: String, id: String, data: JsonObject): Comm {
385-
val commIds = commTargetToIds.getOrPut(target) { CopyOnWriteArrayList() }
386-
val newComm = CommImpl(target, id)
387-
commIds.add(id)
388-
commIdToComm[id] = newComm
389-
390-
val callback = commOpenCallbacks[target]
391-
callback?.invoke(newComm, data)
392-
393-
return newComm
394-
}
395-
396-
override fun closeComm(id: String, data: JsonObject) {
397-
val comm = commIdToComm[id] ?: return
398-
comm.close(data, notifyClient = true)
399-
}
400-
401-
fun processCommClose(id: String, data: JsonObject) {
402-
val comm = commIdToComm[id] ?: return
403-
comm.close(data, notifyClient = false)
404-
}
405-
406-
fun removeComm(id: String) {
407-
val comm = commIdToComm[id] ?: return
408-
val commIds = commTargetToIds[comm.target]!!
409-
commIds.remove(id)
410-
commIdToComm.remove(id)
411-
}
412-
413-
override fun getComms(target: String?): Collection<Comm> {
414-
return if (target == null) {
415-
commIdToComm.values.toList()
416-
} else {
417-
commTargetToIds[target].orEmpty().mapNotNull { commIdToComm[it] }
418-
}
419-
}
420-
421-
fun processCommMessage(id: String, data: JsonObject) {
422-
commIdToComm[id]?.messageReceived(data)
423-
}
424-
425-
override fun registerCommTarget(target: String, callback: (Comm, JsonObject) -> Unit) {
426-
commOpenCallbacks[target] = callback
427-
}
428-
429-
override fun unregisterCommTarget(target: String) {
430-
commOpenCallbacks.remove(target)
431-
}
432-
433266
val stdinIn = StdinInputStream()
434267

435268
var contextMessage: Message? = null

src/main/kotlin/org/jetbrains/kotlinx/jupyter/ikotlin.kt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import org.jetbrains.kotlinx.jupyter.libraries.EmptyResolutionInfoProvider
44
import org.jetbrains.kotlinx.jupyter.libraries.KERNEL_LIBRARIES
55
import org.jetbrains.kotlinx.jupyter.libraries.ResolutionInfoProvider
66
import org.jetbrains.kotlinx.jupyter.libraries.getDefaultDirectoryResolutionInfoProvider
7+
import org.jetbrains.kotlinx.jupyter.messaging.CommManagerImpl
78
import org.jetbrains.kotlinx.jupyter.messaging.controlMessagesHandler
89
import org.jetbrains.kotlinx.jupyter.messaging.shellMessagesHandler
910
import org.jetbrains.kotlinx.jupyter.repl.creating.DefaultReplFactory
@@ -122,7 +123,8 @@ fun kernelServer(config: KernelConfig, runtimeProperties: ReplRuntimeProperties
122123

123124
val executionCount = AtomicLong(1)
124125

125-
val repl = DefaultReplFactory(config, runtimeProperties, scriptReceivers, conn).createRepl()
126+
val commManager = CommManagerImpl(conn)
127+
val repl = DefaultReplFactory(config, runtimeProperties, scriptReceivers, conn, commManager).createRepl()
126128

127129
val mainThread = Thread.currentThread()
128130

@@ -148,7 +150,7 @@ fun kernelServer(config: KernelConfig, runtimeProperties: ReplRuntimeProperties
148150

149151
conn.shell.onMessage { message ->
150152
conn.updateSessionInfo(message)
151-
shellMessagesHandler(message, repl, executionCount)
153+
shellMessagesHandler(message, repl, commManager, executionCount)
152154
}
153155

154156
val controlThread = thread {

0 commit comments

Comments
 (0)