Skip to content

Header propagation #22

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions core/src/main/scala/chimp/McpHandler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ import io.circe.*
import io.circe.syntax.*
import org.slf4j.LoggerFactory
import sttp.apispec.circe.*
import sttp.tapir.*
import sttp.tapir.docs.apispec.schema.TapirSchemaToJsonSchema
import sttp.monad.MonadError
import sttp.monad.syntax.*
import sttp.tapir.*
import sttp.tapir.docs.apispec.schema.TapirSchemaToJsonSchema

/** The MCP server handles JSON-RPC requests for tool listing, invocation, and initialization.
*
Expand Down Expand Up @@ -55,18 +55,18 @@ class McpHandler[F[_]](tools: List[ServerTool[?, F]], name: String = "Chimp MCP
/** Handles the 'tools/call' JSON-RPC method. Attempts to decode the tool name and arguments, then dispatches to the tool logic. Provides
* detailed error messages for decode failures.
*/
private def handleToolsCall(params: Option[io.circe.Json], id: RequestId)(using MonadError[F]): F[JSONRPCMessage] =
// Extract tool name and arguments in a functional, idiomatic way
private def handleToolsCall(params: Option[io.circe.Json], id: RequestId, headerValue: Option[String])(using
MonadError[F]
): F[JSONRPCMessage] =
val toolNameOpt = params.flatMap(_.hcursor.downField("name").as[String].toOption)
val argumentsOpt = params.flatMap(_.hcursor.downField("arguments").focus)
(toolNameOpt, argumentsOpt) match
case (Some(toolName), Some(args)) =>
toolsByName.get(toolName) match
case Some(tool) =>
def inputSnippet = args.noSpaces.take(200) // for error reporting
// Use Circe's Decoder for argument decoding
def inputSnippet = args.noSpaces.take(200)
tool.inputDecoder.decodeJson(args) match
case Right(decodedInput) => handleDecodedInput(tool, decodedInput, id)
case Right(decodedInput) => handleDecodedInput(tool, decodedInput, id, headerValue)
case Left(decodingError) =>
protocolError(
id,
Expand All @@ -80,9 +80,11 @@ class McpHandler[F[_]](tools: List[ServerTool[?, F]], name: String = "Chimp MCP
protocolError(id, JSONRPCErrorCodes.InvalidParams.code, "Missing tool name").unit

/** Handles a successfully decoded tool input, dispatching to the tool's logic. */
private def handleDecodedInput[T](tool: ServerTool[T, F], decodedInput: T, id: RequestId)(using MonadError[F]): F[JSONRPCMessage] =
private def handleDecodedInput[T](tool: ServerTool[T, F], decodedInput: T, id: RequestId, headerValue: Option[String])(using
MonadError[F]
): F[JSONRPCMessage] =
tool
.logic(decodedInput)
.logic(decodedInput, headerValue)
.map:
case Right(result) =>
val callResult = ToolCallResult(
Expand All @@ -97,28 +99,26 @@ class McpHandler[F[_]](tools: List[ServerTool[?, F]], name: String = "Chimp MCP
)
JSONRPCMessage.Response(id = id, result = callResult.asJson)

/** Handles a JSON-RPC request, dispatching to the appropriate handler. Logs requests and responses. */
def handleJsonRpc(request: Json)(using MonadError[F]): F[Json] =
def handleJsonRpc(request: Json, headerValue: Option[String])(using MonadError[F]): F[Json] =
logger.debug(s"Request: $request")
val responseF: F[JSONRPCMessage] = request.as[JSONRPCMessage] match
case Left(err) => protocolError(RequestId("null"), JSONRPCErrorCodes.ParseError.code, s"Parse error: ${err.message}").unit
case Right(JSONRPCMessage.Request(_, method, params: Option[io.circe.Json], id)) =>
method match
case "tools/list" => handleToolsList(id).unit
case "tools/call" => handleToolsCall(params, id)
case "tools/call" => handleToolsCall(params, id, headerValue)
case "initialize" => handleInitialize(id).unit
case other => protocolError(id, JSONRPCErrorCodes.MethodNotFound.code, s"Unknown method: $other").unit
case Right(JSONRPCMessage.BatchRequest(requests)) =>
// For each sub-request, process as a single request using flatMap/fold (no .sequence)
def processBatch(reqs: List[JSONRPCMessage], acc: List[JSONRPCMessage]): F[List[JSONRPCMessage]] =
reqs match
case Nil => acc.reverse.unit
case head :: tail =>
head match
case JSONRPCMessage.Notification(_, _, _) =>
processBatch(tail, acc) // skip notifications
processBatch(tail, acc)
case _ =>
handleJsonRpc(head.asJson).flatMap { respJson =>
handleJsonRpc(head.asJson, headerValue).flatMap { respJson =>
val msg = respJson
.as[JSONRPCMessage]
.getOrElse(
Expand All @@ -127,7 +127,6 @@ class McpHandler[F[_]](tools: List[ServerTool[?, F]], name: String = "Chimp MCP
processBatch(tail, msg :: acc)
}
processBatch(requests, Nil).map { responses =>
// Per JSON-RPC spec, notifications (no id) should not be included in the response
val filtered = responses.collect {
case r @ JSONRPCMessage.Response(_, id, _) => r
case e @ JSONRPCMessage.Error(_, id, _) => e
Expand Down
40 changes: 28 additions & 12 deletions core/src/main/scala/chimp/mcpEndpoint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,40 @@ private val logger = LoggerFactory.getLogger(classOf[McpHandler[_]])
* The list of tools to expose.
* @param path
* The path components at which to expose the MCP server.
* @param headerName
* The optional name of the header to read. If None, no header is read.
*
* @tparam F
* The effect type. Might be `Identity` for a endpoints with synchronous logic.
*/
def mcpEndpoint[F[_]](tools: List[ServerTool[?, F]], path: List[String]): ServerEndpoint[Any, F] =
def mcpEndpoint[F[_]](tools: List[ServerTool[?, F]], path: List[String], headerName: Option[String] = None): ServerEndpoint[Any, F] =
val mcpHandler = new McpHandler(tools)
val e = infallibleEndpoint.post
val base = infallibleEndpoint.post
.in(path.foldLeft(emptyInput)((inputSoFar, pathComponent) => inputSoFar / pathComponent))
.in(jsonBody[Json])
.out(jsonBody[Json])

ServerEndpoint.public(
e,
me =>
json =>
given MonadError[F] = me
mcpHandler
.handleJsonRpc(json)
.map: responseJson =>
Right(responseJson.deepDropNullValues)
)
headerName match {
case Some(name) =>
val endpoint = base.prependIn(header[Option[String]](name))
ServerEndpoint.public(
endpoint,
me => { (input: (Option[String], Json)) =>
val (headerValue, json) = input
given MonadError[F] = me
mcpHandler
.handleJsonRpc(json, headerValue)
.map(responseJson => Right(responseJson.deepDropNullValues))
}
)
case None =>
ServerEndpoint.public(
base,
me => { (json: Json) =>
given MonadError[F] = me
mcpHandler
.handleJsonRpc(json, None)
.map(responseJson => Right(responseJson.deepDropNullValues))
}
)
}
6 changes: 3 additions & 3 deletions core/src/main/scala/chimp/protocol/model.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
// NOTE: RequestId and ProgressToken use newtype wrappers for spec accuracy and to avoid ambiguous implicits.
package chimp.protocol

import io.circe.syntax.*
import io.circe.{Codec, Decoder, Encoder, Json}
import io.circe.syntax._

// --- JSON-RPC base types ---
// Use newtype wrappers for union types to avoid ambiguous implicits
Expand Down Expand Up @@ -46,8 +46,8 @@ enum JSONRPCMessage:
case BatchResponse(responses: List[JSONRPCMessage])

object JSONRPCMessage {
import io.circe._
import io.circe.syntax._
import io.circe.*
import io.circe.syntax.*

given Decoder[JSONRPCMessage] = Decoder.instance { c =>
val jsonrpc = c.downField("jsonrpc").as[String].getOrElse("2.0")
Expand Down
11 changes: 6 additions & 5 deletions core/src/main/scala/chimp/tool.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,19 @@ case class Tool[I](
/** Combine the tool description with the server logic, that should be executed when the tool is invoked. The logic, given the input,
* should return either a tool execution error (`Left`), or a successful textual result (`Right`), using the F-effect.
*/
def serverLogic[F[_]](logic: I => F[Either[String, String]]): ServerTool[I, F] =
def serverLogic[F[_]](logic: (I, Option[String]) => F[Either[String, String]]): ServerTool[I, F] =
ServerTool(name, description, inputSchema, inputDecoder, annotations, logic)

/** Combine the tool description with the server logic, that should be executed when the tool is invoked. The logic, given the input,
* should return either a tool execution error (`Left`), or a successful textual result (`Right`).
*
* Same as [[serverLogic]], but using the identity "effect".
*/
def handle(logic: I => Either[String, String]): ServerTool[I, Identity] =
ServerTool(name, description, inputSchema, inputDecoder, annotations, logic)
def handleWithHeader(logic: (I, Option[String]) => Either[String, String]): ServerTool[I, Identity] =
ServerTool(name, description, inputSchema, inputDecoder, annotations, (i, t) => logic(i, t))

//
def handle(logic: I => Either[String, String]): ServerTool[I, Identity] =
handleWithHeader((i, _) => logic(i))

/** A tool that can be executed by the MCP server. */
case class ServerTool[I, F[_]](
Expand All @@ -60,5 +61,5 @@ case class ServerTool[I, F[_]](
inputSchema: Schema[I],
inputDecoder: Decoder[I],
annotations: Option[ToolAnnotations],
logic: I => F[Either[String, String]]
logic: (I, Option[String]) => F[Either[String, String]]
)
Loading