Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cmd/extproc/mainlib/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ func Main(ctx context.Context, args []string, stderr io.Writer) (err error) {
chatCompletionMetrics := metrics.NewChatCompletion(meter, metricsRequestHeaderAttributes)
completionMetrics := metrics.NewCompletion(meter, metricsRequestHeaderAttributes)
embeddingsMetrics := metrics.NewEmbeddings(meter, metricsRequestHeaderAttributes)
mcpMetrics := metrics.NewMCP(meter)
mcpMetrics := metrics.NewMCP(meter, metricsRequestHeaderAttributes)

tracing, err := tracing.NewTracingFromEnv(ctx, os.Stdout, spanRequestHeaderAttributes)
if err != nil {
Expand Down Expand Up @@ -264,13 +264,13 @@ func Main(ctx context.Context, args []string, stderr io.Writer) (err error) {
seed, fallbackSeed, _ := strings.Cut(flags.mcpSessionEncryptionSeed, ",")
mcpSessionCrypto := mcpproxy.DefaultSessionCrypto(seed, fallbackSeed)
var mcpProxyMux *http.ServeMux
var mcpProxy *mcpproxy.MCPProxy
mcpProxy, mcpProxyMux, err = mcpproxy.NewMCPProxy(l.With("component", "mcp-proxy"), mcpMetrics,
var mcpProxyConfig *mcpproxy.ProxyConfig
mcpProxyConfig, mcpProxyMux, err = mcpproxy.NewMCPProxy(l.With("component", "mcp-proxy"), mcpMetrics,
tracing.MCPTracer(), mcpSessionCrypto)
if err != nil {
return fmt.Errorf("failed to create MCP proxy: %w", err)
}
if err = extproc.StartConfigWatcher(ctx, flags.configPath, mcpProxy, l, time.Second*5); err != nil {
if err = extproc.StartConfigWatcher(ctx, flags.configPath, mcpProxyConfig, l, time.Second*5); err != nil {
return fmt.Errorf("failed to start config watcher: %w", err)
}

Expand Down
36 changes: 36 additions & 0 deletions internal/lang/maps.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright Envoy AI Gateway Authors
// SPDX-License-Identifier: Apache-2.0
// The full text of the Apache license is available in the LICENSE file at
// the root of the repo.

package lang

import (
"fmt"
"maps"
"slices"
"strings"
)

// CaseInsensitiveValue retrieves a value from the meta map in a case-insensitive manner.
// If the same key is present in different cases, the first one in alphabetical order
// that matches is returned.
// If the key is not found, it returns an empty string.
func CaseInsensitiveValue(m map[string]any, key string) string {
if m == nil {
return ""
}

if v, ok := m[key]; ok {
return fmt.Sprintf("%v", v)
}

keys := slices.Sorted(maps.Keys(m))
for _, k := range keys {
if strings.EqualFold(k, key) {
return fmt.Sprintf("%v", m[k])
}
}

return ""
}
57 changes: 57 additions & 0 deletions internal/lang/maps_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright Envoy AI Gateway Authors
// SPDX-License-Identifier: Apache-2.0
// The full text of the Apache license is available in the LICENSE file at
// the root of the repo.

package lang

import "testing"

func TestCaseInsensitiveValue(t *testing.T) {
tests := []struct {
name string
m map[string]any
key string
want string
}{
{
name: "nil map",
m: nil,
key: "anything",
want: "",
},
{
name: "exact match returns value",
m: map[string]any{"Foo": "bar", "foo": "should-not-be-used"},
key: "Foo",
want: "bar",
},
{
name: "case-insensitive match when exact not present",
m: map[string]any{"FOO": "baz"},
key: "foo",
want: "baz",
},
{
name: "multiple case variants - alphabetical first chosen",
m: map[string]any{"ALPHA": 2, "Alpha": 1},
key: "alpha",
want: "2", // ALPHA is alphabetically first
},
{
name: "nil value formatted",
m: map[string]any{"key": nil},
key: "key",
want: "<nil>",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := CaseInsensitiveValue(tc.m, tc.key)
if got != tc.want {
t.Fatalf("CaseInsensitiveValue(%v, %q) = %q; want %q", tc.m, tc.key, got, tc.want)
}
})
}
}
110 changes: 57 additions & 53 deletions internal/mcpproxy/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) {
errType metrics.MCPErrorType
requestMethod string
span tracing.MCPSpan
params mcp.Params
)
defer func() {
if m.l.Enabled(ctx, slog.LevelDebug) {
Expand All @@ -119,17 +120,17 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) {
if span != nil {
span.EndSpanOnError(string(errType), err)
}
m.metrics.RecordMethodErrorCount(ctx)
m.metrics.RecordRequestErrorDuration(ctx, &startAt, errType)
m.metrics.RecordMethodErrorCount(ctx, params)
m.metrics.RecordRequestErrorDuration(ctx, &startAt, errType, params)
return
}

if span != nil {
span.EndSpan()
}
m.metrics.RecordRequestDuration(ctx, &startAt)
m.metrics.RecordRequestDuration(ctx, &startAt, params)
// TODO: should we special case when this request is "Response" where method is empty?
m.metrics.RecordMethodCount(ctx, requestMethod)
m.metrics.RecordMethodCount(ctx, requestMethod, params)
}()
if sessionID := r.Header.Get(sessionIDHeader); sessionID != "" {
s, err = m.sessionFromID(secureClientToGatewaySessionID(sessionID), secureClientToGatewayEventID(r.Header.Get(lastEventIDHeader)))
Expand Down Expand Up @@ -189,37 +190,37 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) {

switch msg.Method {
case "notifications/roots/list_changed":
p := &mcp.RootsListChangedParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
params = &mcp.RootsListChangedParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
if err != nil {
errType = metrics.MCPErrorInvalidParam
onErrorResponse(w, http.StatusBadRequest, "invalid params")
return
}
err = m.handleNotificationsRootsListChanged(ctx, s, w, msg, span)
case "completion/complete":
p := &mcp.CompleteParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
params = &mcp.CompleteParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
if err != nil {
errType = metrics.MCPErrorInvalidParam
onErrorResponse(w, http.StatusBadRequest, "invalid params")
return
}
err = m.handleCompletionComplete(ctx, s, w, msg, p, span)
err = m.handleCompletionComplete(ctx, s, w, msg, params.(*mcp.CompleteParams), span)
case "notifications/progress":
m.metrics.RecordProgress(ctx)
p := &mcp.ProgressNotificationParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
params = &mcp.ProgressNotificationParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
m.metrics.RecordProgress(ctx, params)
if err != nil {
errType = metrics.MCPErrorInvalidParam
onErrorResponse(w, http.StatusBadRequest, "invalid params")
return
}
err = m.handleClientToServerNotificationsProgress(ctx, s, w, msg, p, span)
err = m.handleClientToServerNotificationsProgress(ctx, s, w, msg, params.(*mcp.ProgressNotificationParams), span)
case "initialize":
// The very first request from the client to establish a session.
p := &mcp.InitializeParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
params = &mcp.InitializeParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
if err != nil {
errType = metrics.MCPErrorInvalidParam
m.l.Error("Failed to unmarshal initialize params", slog.String("error", err.Error()))
Expand All @@ -235,107 +236,107 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) {
onErrorResponse(w, http.StatusInternalServerError, "missing route header")
return
}
err = m.handleInitializeRequest(ctx, w, msg, p, route, extractSubject(r), span)
err = m.handleInitializeRequest(ctx, w, msg, params.(*mcp.InitializeParams), route, extractSubject(r), span)
case "notifications/initialized":
// According to the MCP spec, when the server receives a JSON-RPC response or notification from the client
// and accepts it, the server MUST return HTTP 202 Accepted with an empty body.
// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server
w.WriteHeader(http.StatusAccepted)
case "logging/setLevel":
p := &mcp.SetLoggingLevelParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
params = &mcp.SetLoggingLevelParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
if err != nil {
errType = metrics.MCPErrorInvalidParam
m.l.Error("Failed to unmarshal set logging level params", slog.String("error", err.Error()))
onErrorResponse(w, http.StatusBadRequest, "invalid set logging level params")
return
}
err = m.handleSetLoggingLevel(ctx, s, w, msg, p, span)
err = m.handleSetLoggingLevel(ctx, s, w, msg, params.(*mcp.SetLoggingLevelParams), span)
case "ping":
// Ping is intentionally not traced as it's a lightweight health check.
err = m.handlePing(ctx, w, msg)
case "prompts/list":
p := &mcp.ListPromptsParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
params = &mcp.ListPromptsParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
if err != nil {
errType = metrics.MCPErrorInvalidParam
onErrorResponse(w, http.StatusBadRequest, "invalid params")
return
}
err = m.handlePromptListRequest(ctx, s, w, msg, p, span)
err = m.handlePromptListRequest(ctx, s, w, msg, params.(*mcp.ListPromptsParams), span)
case "prompts/get":
p := &mcp.GetPromptParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
params = &mcp.GetPromptParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
if err != nil {
errType = metrics.MCPErrorInvalidParam
onErrorResponse(w, http.StatusBadRequest, "invalid params")
return
}
err = m.handlePromptGetRequest(ctx, s, w, msg, p)
err = m.handlePromptGetRequest(ctx, s, w, msg, params.(*mcp.GetPromptParams))
case "tools/call":
p := &mcp.CallToolParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
params = &mcp.CallToolParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
if err != nil {
errType = metrics.MCPErrorInvalidParam
m.l.Error("Failed to unmarshal params", slog.String("method", msg.Method), slog.String("error", err.Error()))
onErrorResponse(w, http.StatusBadRequest, "invalid params")
return
}
err = m.handleToolCallRequest(ctx, s, w, msg, p, span)
err = m.handleToolCallRequest(ctx, s, w, msg, params.(*mcp.CallToolParams), span)
case "tools/list":
p := &mcp.ListToolsParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
params = &mcp.ListToolsParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
if err != nil {
errType = metrics.MCPErrorInvalidParam
onErrorResponse(w, http.StatusBadRequest, "invalid params")
return
}
err = m.handleToolsListRequest(ctx, s, w, msg, p, span)
err = m.handleToolsListRequest(ctx, s, w, msg, params.(*mcp.ListToolsParams), span)
case "resources/list":
p := &mcp.ListResourcesParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
params = &mcp.ListResourcesParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
if err != nil {
errType = metrics.MCPErrorInvalidParam
onErrorResponse(w, http.StatusBadRequest, "invalid params")
return
}
err = m.handleResourceListRequest(ctx, s, w, msg, p, span)
err = m.handleResourceListRequest(ctx, s, w, msg, params.(*mcp.ListResourcesParams), span)
case "resources/read":
p := &mcp.ReadResourceParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
params = &mcp.ReadResourceParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
if err != nil {
errType = metrics.MCPErrorInvalidParam
onErrorResponse(w, http.StatusBadRequest, "invalid params")
return
}
err = m.handleResourceReadRequest(ctx, s, w, msg, p)
err = m.handleResourceReadRequest(ctx, s, w, msg, params.(*mcp.ReadResourceParams))
case "resources/templates/list":
p := &mcp.ListResourceTemplatesParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
params = &mcp.ListResourceTemplatesParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
if err != nil {
errType = metrics.MCPErrorInvalidParam
onErrorResponse(w, http.StatusBadRequest, "invalid params")
return
}
err = m.handleResourcesTemplatesListRequest(ctx, s, w, msg, p, span)
err = m.handleResourcesTemplatesListRequest(ctx, s, w, msg, params.(*mcp.ListResourceTemplatesParams), span)
case "resources/subscribe":
p := &mcp.SubscribeParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
params = &mcp.SubscribeParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
if err != nil {
errType = metrics.MCPErrorInvalidParam
onErrorResponse(w, http.StatusBadRequest, "invalid params")
return
}
err = m.handleResourcesSubscribeRequest(ctx, s, w, msg, p, span)
err = m.handleResourcesSubscribeRequest(ctx, s, w, msg, params.(*mcp.SubscribeParams), span)
case "resources/unsubscribe":
p := &mcp.UnsubscribeParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
params = &mcp.UnsubscribeParams{}
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
if err != nil {
errType = metrics.MCPErrorInvalidParam
onErrorResponse(w, http.StatusBadRequest, "invalid params")
return
}
err = m.handleResourcesUnsubscribeRequest(ctx, s, w, msg, p, span)
err = m.handleResourcesUnsubscribeRequest(ctx, s, w, msg, params.(*mcp.UnsubscribeParams), span)
case "notifications/cancelled":
// The responsibility of cancelling the operation on server side is optional, so we just ignore it for now.
// https://modelcontextprotocol.io/specification/2025-06-18/basic/utilities/cancellation#behavior-requirements
Expand Down Expand Up @@ -371,8 +372,7 @@ func errorType(err error) metrics.MCPErrorType {

// handleInitializeRequest handles the "initialize" JSON-RPC method.
func (m *MCPProxy) handleInitializeRequest(ctx context.Context, w http.ResponseWriter, req *jsonrpc.Request, p *mcp.InitializeParams, route, subject string, span tracing.MCPSpan) error {
m.metrics.RecordClientCapabilities(ctx, p.Capabilities)

m.metrics.RecordClientCapabilities(ctx, p.Capabilities, p)
s, err := m.newSession(ctx, p, route, subject, span)
if err != nil {
m.l.Error("failed to create new session", slog.String("error", err.Error()))
Expand Down Expand Up @@ -789,19 +789,23 @@ func (m *MCPProxy) recordResponse(ctx context.Context, rawMsg jsonrpc.Message) {
case "notifications/resources/list_changed":
case "notifications/resources/updated":
case "notifications/progress":
m.metrics.RecordProgress(ctx)
params := &mcp.ProgressNotificationParams{}
if err := json.Unmarshal(msg.Params, &params); err != nil {
m.l.Error("Failed to unmarshal params", slog.String("method", msg.Method), slog.String("error", err.Error()))
}
m.metrics.RecordProgress(ctx, params)
case "notifications/message":
case "notifications/tools/list_changed":
case "roots/list":
case "sampling/createMessage":
case "elicitation/create":
default:
knownMethod = false
m.metrics.RecordMethodErrorCount(ctx)
m.metrics.RecordMethodErrorCount(ctx, nil)
m.l.Warn("Unsupported MCP request method from server", slog.String("method", msg.Method))
}
if knownMethod {
m.metrics.RecordMethodCount(ctx, msg.Method)
m.metrics.RecordMethodCount(ctx, msg.Method, nil)
}
default:
m.l.Warn("unexpected message type in MCP response", slog.Any("message", msg))
Expand Down Expand Up @@ -1223,7 +1227,7 @@ func sendToAllBackendsAndAggregateResponsesImpl[responseType any](ctx context.Co
}

// parseParamsAndMaybeStartSpan parses the params from the JSON-RPC request and starts a tracing span if params is non-nil.
func parseParamsAndMaybeStartSpan[paramType mcp.Params](ctx context.Context, m *MCPProxy, req *jsonrpc.Request, p paramType) (tracing.MCPSpan, error) {
func parseParamsAndMaybeStartSpan[paramType mcp.Params](ctx context.Context, m *MCPProxy, req *jsonrpc.Request, p paramType, headers http.Header) (tracing.MCPSpan, error) {
if req.Params == nil {
return nil, nil
}
Expand All @@ -1233,7 +1237,7 @@ func parseParamsAndMaybeStartSpan[paramType mcp.Params](ctx context.Context, m *
return nil, err
}

span := m.tracer.StartSpanAndInjectMeta(ctx, req, p)
span := m.tracer.StartSpanAndInjectMeta(ctx, req, p, headers)
return span, nil
}

Expand Down
Loading