Skip to content

Add client roots addition/removal API and listRoots handler #118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions api/kotlin-sdk.api
Original file line number Diff line number Diff line change
Expand Up @@ -2688,6 +2688,8 @@ public final class io/modelcontextprotocol/kotlin/sdk/WithMeta$Companion {
public class io/modelcontextprotocol/kotlin/sdk/client/Client : io/modelcontextprotocol/kotlin/sdk/shared/Protocol {
public fun <init> (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/client/ClientOptions;)V
public synthetic fun <init> (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/client/ClientOptions;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun addRoot (Ljava/lang/String;Ljava/lang/String;)V
public final fun addRoots (Ljava/util/List;)V
protected final fun assertCapability (Ljava/lang/String;Ljava/lang/String;)V
protected fun assertCapabilityForMethod (Lio/modelcontextprotocol/kotlin/sdk/Method;)V
protected fun assertNotificationCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V
Expand Down Expand Up @@ -2715,6 +2717,8 @@ public class io/modelcontextprotocol/kotlin/sdk/client/Client : io/modelcontextp
public static synthetic fun ping$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
public final fun readResource (Lio/modelcontextprotocol/kotlin/sdk/ReadResourceRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static synthetic fun readResource$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/ReadResourceRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
public final fun removeRoot (Ljava/lang/String;)Z
public final fun removeRoots (Ljava/util/List;)I
public final fun sendRootsListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public final fun setLoggingLevel (Lio/modelcontextprotocol/kotlin/sdk/LoggingLevel;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static synthetic fun setLoggingLevel$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/LoggingLevel;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.modelcontextprotocol.kotlin.sdk.client

import io.github.oshai.kotlinlogging.KotlinLogging
import io.modelcontextprotocol.kotlin.sdk.CallToolRequest
import io.modelcontextprotocol.kotlin.sdk.CallToolResult
import io.modelcontextprotocol.kotlin.sdk.CallToolResultBase
Expand All @@ -21,6 +22,7 @@ import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesRequest
import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesResult
import io.modelcontextprotocol.kotlin.sdk.ListResourcesRequest
import io.modelcontextprotocol.kotlin.sdk.ListResourcesResult
import io.modelcontextprotocol.kotlin.sdk.ListRootsResult
import io.modelcontextprotocol.kotlin.sdk.ListToolsRequest
import io.modelcontextprotocol.kotlin.sdk.ListToolsResult
import io.modelcontextprotocol.kotlin.sdk.LoggingLevel
Expand All @@ -29,6 +31,7 @@ import io.modelcontextprotocol.kotlin.sdk.Method
import io.modelcontextprotocol.kotlin.sdk.PingRequest
import io.modelcontextprotocol.kotlin.sdk.ReadResourceRequest
import io.modelcontextprotocol.kotlin.sdk.ReadResourceResult
import io.modelcontextprotocol.kotlin.sdk.Root
import io.modelcontextprotocol.kotlin.sdk.RootsListChangedNotification
import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS
import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities
Expand All @@ -44,6 +47,8 @@ import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import kotlin.coroutines.cancellation.CancellationException

private val logger = KotlinLogging.logger {}

/**
* Options for configuring the MCP client.
*
Expand Down Expand Up @@ -89,6 +94,19 @@ public open class Client(

private val capabilities: ClientCapabilities = options.capabilities

private val roots = mutableMapOf<String, Root>()

init {
logger.debug { "Initializing MCP client with capabilities: $capabilities" }

// Internal handlers for roots
if (capabilities.roots != null) {
setRequestHandler<ListToolsRequest>(Method.Defined.RootsList) { _, _ ->
handleListRoots()
}
}
}

protected fun assertCapability(capability: String, method: String) {
val caps = serverCapabilities
val hasCapability = when (capability) {
Expand Down Expand Up @@ -449,6 +467,97 @@ public open class Client(
return request<ListToolsResult>(request, options)
}

/**
* Registers a single root.
*
* @param uri The URI of the root.
* @param name A human-readable name for the root.
* @throws IllegalStateException If the client does not support roots.
*/
public fun addRoot(
uri: String,
name: String,
) {
if (capabilities.roots == null) {
logger.error { "Failed to add root '$name': Client does not support roots capability" }
throw IllegalStateException("Client does not support roots capability.")
}
logger.info { "Adding root: $name ($uri)" }
roots[uri] = Root(uri, name)
}

/**
* Registers multiple roots at once.
*
* @param rootsToAdd A list of [Root] objects to register.
* @throws IllegalStateException If the client does not support roots.
*/
public fun addRoots(rootsToAdd: List<Root>) {
if (capabilities.roots == null) {
logger.error { "Failed to add roots: Client does not support roots capability" }
throw IllegalStateException("Client does not support roots capability.")
}
logger.info { "Adding ${rootsToAdd.size} roots" }
for (r in rootsToAdd) {
logger.info { "Adding root: ${r.name} (${r.uri})" }
roots[r.uri] = r
}
}

/**
* Removes a single root by URI.
*
* @param uri The URI of the root to remove.
* @return True if the root was removed, false if it wasn't found.
* @throws IllegalStateException If the client does not support roots.
*/
public fun removeRoot(uri: String): Boolean {
if (capabilities.roots == null) {
logger.error { "Failed to remove root '$uri': Client does not support roots capability" }
throw IllegalStateException("Client does not support roots capability.")
}
logger.info { "Removing root: $uri" }
val removed = roots.remove(uri) != null
logger.debug {
if (removed) {
"Root removed: $uri"
} else {
"Root not found: $uri"
}
}
return removed
}

/**
* Removes multiple roots at once.
*
* @param uris A list of root URIs to remove.
* @return The number of roots that were successfully removed.
* @throws IllegalStateException If the client does not support roots.
*/
public fun removeRoots(uris: List<String>): Int {
if (capabilities.roots == null) {
logger.error { "Failed to remove roots: Client does not support roots capability" }
throw IllegalStateException("Client does not support roots capability.")
}
logger.info { "Removing ${uris.size} roots" }
var removedCount = 0
for (uri in uris) {
logger.debug { "Removing root: $uri" }
if (roots.remove(uri) != null) {
removedCount++
}
}
logger.info {
if (removedCount > 0) {
"Removed $removedCount roots"
} else {
"No roots were removed"
}
}
return removedCount
}

/**
* Notifies the server that the list of roots has changed.
* Typically used if the client is managing some form of hierarchical structure.
Expand All @@ -458,4 +567,11 @@ public open class Client(
public suspend fun sendRootsListChanged() {
notification(RootsListChangedNotification())
}

// --- Internal Handlers ---

private suspend fun handleListRoots(): ListRootsResult {
val rootList = roots.values.toList()
return ListRootsResult(rootList)
}
}
181 changes: 174 additions & 7 deletions src/jvmTest/kotlin/client/ClientTest.kt
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package client

import io.mockk.coEvery
import io.mockk.spyk
import io.modelcontextprotocol.kotlin.sdk.ClientCapabilities
import io.modelcontextprotocol.kotlin.sdk.CreateMessageRequest
import io.modelcontextprotocol.kotlin.sdk.CreateMessageResult
import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject
import io.modelcontextprotocol.kotlin.sdk.Implementation
import io.mockk.coEvery
import io.mockk.spyk
import io.modelcontextprotocol.kotlin.sdk.InMemoryTransport
import io.modelcontextprotocol.kotlin.sdk.InitializeRequest
import io.modelcontextprotocol.kotlin.sdk.InitializeResult
Expand All @@ -23,10 +23,17 @@ import io.modelcontextprotocol.kotlin.sdk.LoggingLevel
import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification
import io.modelcontextprotocol.kotlin.sdk.Method
import io.modelcontextprotocol.kotlin.sdk.Role
import io.modelcontextprotocol.kotlin.sdk.Root
import io.modelcontextprotocol.kotlin.sdk.RootsListChangedNotification
import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS
import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities
import io.modelcontextprotocol.kotlin.sdk.TextContent
import io.modelcontextprotocol.kotlin.sdk.Tool
import io.modelcontextprotocol.kotlin.sdk.client.Client
import io.modelcontextprotocol.kotlin.sdk.client.ClientOptions
import io.modelcontextprotocol.kotlin.sdk.server.Server
import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.cancel
Expand All @@ -35,13 +42,9 @@ import kotlinx.coroutines.launch
import kotlinx.coroutines.test.runTest
import kotlinx.serialization.json.buildJsonObject
import kotlinx.serialization.json.put
import io.modelcontextprotocol.kotlin.sdk.client.Client
import io.modelcontextprotocol.kotlin.sdk.client.ClientOptions
import org.junit.jupiter.api.Test
import io.modelcontextprotocol.kotlin.sdk.server.Server
import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
import org.junit.jupiter.api.assertInstanceOf
import org.junit.jupiter.api.assertThrows
import kotlin.coroutines.cancellation.CancellationException
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
Expand Down Expand Up @@ -628,4 +631,168 @@ class ClientTest {
assertEquals(null, receivedAsResponse.error)
}

@Test
fun `listRoots returns list of roots`() = runTest {
val client = Client(
Implementation(name = "test client", version = "1.0"),
ClientOptions(
capabilities = ClientCapabilities(
roots = ClientCapabilities.Roots(null)
)
)
)

val clientRoots = listOf(
Root(uri = "file:///test-root", name = "testRoot")
)

client.addRoots(clientRoots)

val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair()

val server = Server(
serverInfo = Implementation(name = "test server", version = "1.0"),
options = ServerOptions(
capabilities = ServerCapabilities()
)
)

listOf(
launch { client.connect(clientTransport) },
launch { server.connect(serverTransport) }
).joinAll()

val clientCapabilities = server.clientCapabilities
assertEquals(ClientCapabilities.Roots(null), clientCapabilities?.roots)

val listRootsResult = server.listRoots()

assertEquals(listRootsResult.roots, clientRoots)
}

@Test
fun `addRoot should throw when roots capability is not supported`() = runTest {
val client = Client(
Implementation(name = "test client", version = "1.0"),
ClientOptions(
capabilities = ClientCapabilities()
)
)

// Verify that adding a root throws an exception
val exception = assertThrows<IllegalStateException> {
client.addRoot(uri = "file:///test-root1", name = "testRoot1")
}
assertEquals("Client does not support roots capability.", exception.message)
}

@Test
fun `removeRoot should throw when roots capability is not supported`() = runTest {
val client = Client(
Implementation(name = "test client", version = "1.0"),
ClientOptions(
capabilities = ClientCapabilities()
)
)

// Verify that removing a root throws an exception
val exception = assertThrows<IllegalStateException> {
client.removeRoot(uri = "file:///test-root1")
}
assertEquals("Client does not support roots capability.", exception.message)
}

@Test
fun `removeRoot should remove a root`() = runTest {
val client = Client(
Implementation(name = "test client", version = "1.0"),
ClientOptions(
capabilities = ClientCapabilities(
roots = ClientCapabilities.Roots(null)
)
)
)

// Add some roots
client.addRoots(
listOf(
Root(uri = "file:///test-root1", name = "testRoot1"),
Root(uri = "file:///test-root2", name = "testRoot2"),
)
)

// Remove a root
val result = client.removeRoot("file:///test-root1")

// Verify the root was removed
assertTrue(result, "Root should be removed successfully")
}

@Test
fun `removeRoots should remove multiple roots`() = runTest {
val client = Client(
Implementation(name = "test client", version = "1.0"),
ClientOptions(
capabilities = ClientCapabilities(
roots = ClientCapabilities.Roots(null)
)
)
)

// Add some roots
client.addRoots(
listOf(
Root(uri = "file:///test-root1", name = "testRoot1"),
Root(uri = "file:///test-root2", name = "testRoot2"),
)
)

// Remove multiple roots
val result = client.removeRoots(
listOf("file:///test-root1", "file:///test-root2")
)

// Verify the root was removed
assertEquals(2, result, "Both roots should be removed")
}

@Test
fun `sendRootsListChanged should notify server`() = runTest {
val client = Client(
Implementation(name = "test client", version = "1.0"),
ClientOptions(
capabilities = ClientCapabilities(
roots = ClientCapabilities.Roots(listChanged = true)
)
)
)

val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair()

val server = Server(
serverInfo = Implementation(name = "test server", version = "1.0"),
options = ServerOptions(
capabilities = ServerCapabilities()
)
)

// Track notifications
var rootListChangedNotificationReceived = false
server.setNotificationHandler<RootsListChangedNotification>(Method.Defined.NotificationsRootsListChanged) {
rootListChangedNotificationReceived = true
CompletableDeferred(Unit)
}

listOf(
launch { client.connect(clientTransport) },
launch { server.connect(serverTransport) }
).joinAll()

client.sendRootsListChanged()

assertTrue(
rootListChangedNotificationReceived,
"Notification should be sent when sendRootsListChanged is called"
)
}
}