Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions internal/server/http_helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package server

import (
"bytes"
"context"
"io"
"log"
"net/http"

"github.com/githubnext/gh-aw-mcpg/internal/auth"
"github.com/githubnext/gh-aw-mcpg/internal/logger"
"github.com/githubnext/gh-aw-mcpg/internal/mcp"
)

// extractAndValidateSession extracts the session ID from the Authorization header
// and logs connection details. Returns empty string if validation fails.
func extractAndValidateSession(r *http.Request) string {
authHeader := r.Header.Get("Authorization")
sessionID := auth.ExtractSessionID(authHeader)

if sessionID == "" {
logger.LogError("client", "Rejected MCP client connection: no Authorization header, remote=%s, path=%s", r.RemoteAddr, r.URL.Path)
log.Printf("[%s] %s %s - REJECTED: No Authorization header", r.RemoteAddr, r.Method, r.URL.Path)
return ""
}

return sessionID
}

// logHTTPRequestBody logs the request body for debugging purposes.
// It reads the body, logs it, and restores it so it can be read again.
// The backendID parameter is optional and can be empty for unified mode.
func logHTTPRequestBody(r *http.Request, sessionID, backendID string) {
if r.Method != "POST" || r.Body == nil {
return
}

bodyBytes, err := io.ReadAll(r.Body)
if err != nil || len(bodyBytes) == 0 {
return
}

// Log with backend context if provided (routed mode)
if backendID != "" {
logger.LogDebug("client", "MCP client request body, backend=%s, body=%s", backendID, string(bodyBytes))
} else {
logger.LogDebug("client", "MCP request body, session=%s, body=%s", sessionID, string(bodyBytes))
}
log.Printf("Request body: %s", string(bodyBytes))

// Restore body for subsequent reads
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}

// injectSessionContext stores the session ID and optional backend ID into the request context.
// If backendID is empty, only session ID is injected (unified mode).
// Returns the modified request with updated context.
func injectSessionContext(r *http.Request, sessionID, backendID string) *http.Request {
ctx := context.WithValue(r.Context(), SessionIDContextKey, sessionID)

if backendID != "" {
ctx = context.WithValue(ctx, mcp.ContextKey("backend-id"), backendID)
}

return r.WithContext(ctx)
}
225 changes: 225 additions & 0 deletions internal/server/http_helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
package server

import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/githubnext/gh-aw-mcpg/internal/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestExtractAndValidateSession(t *testing.T) {
tests := []struct {
name string
authHeader string
expectedID string
shouldBeEmpty bool
}{
{
name: "Valid plain API key",
authHeader: "test-session-123",
expectedID: "test-session-123",
shouldBeEmpty: false,
},
{
name: "Valid Bearer token",
authHeader: "Bearer my-token-456",
expectedID: "my-token-456",
shouldBeEmpty: false,
},
{
name: "Empty Authorization header",
authHeader: "",
expectedID: "",
shouldBeEmpty: true,
},
{
name: "Whitespace only header",
authHeader: " ",
expectedID: " ",
shouldBeEmpty: false,
},
{
name: "Long session ID",
authHeader: "very-long-session-id-with-many-characters-1234567890",
expectedID: "very-long-session-id-with-many-characters-1234567890",
shouldBeEmpty: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "/mcp", nil)
if tt.authHeader != "" {
req.Header.Set("Authorization", tt.authHeader)
}

sessionID := extractAndValidateSession(req)

if tt.shouldBeEmpty {
assert.Empty(t, sessionID, "Expected empty session ID")
} else {
assert.Equal(t, tt.expectedID, sessionID, "Session ID mismatch")
}
})
}
}

func TestLogHTTPRequestBody(t *testing.T) {
tests := []struct {
name string
method string
body string
sessionID string
backendID string
shouldLog bool
}{
{
name: "POST request with body and backend",
method: "POST",
body: `{"method":"initialize"}`,
sessionID: "session-123",
backendID: "backend-1",
shouldLog: true,
},
{
name: "POST request with body without backend",
method: "POST",
body: `{"method":"tools/call"}`,
sessionID: "session-456",
backendID: "",
shouldLog: true,
},
{
name: "GET request (no body logging)",
method: "GET",
body: "",
sessionID: "session-789",
backendID: "backend-2",
shouldLog: false,
},
{
name: "POST request with empty body",
method: "POST",
body: "",
sessionID: "session-abc",
backendID: "backend-3",
shouldLog: false,
},
{
name: "POST request with nil body",
method: "POST",
body: "",
sessionID: "session-def",
backendID: "",
shouldLog: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var req *http.Request
if tt.body != "" {
req = httptest.NewRequest(tt.method, "/mcp", bytes.NewBufferString(tt.body))
} else if tt.method == "POST" {
req = httptest.NewRequest(tt.method, "/mcp", nil)
} else {
req = httptest.NewRequest(tt.method, "/mcp", nil)
}

// Call the function
logHTTPRequestBody(req, tt.sessionID, tt.backendID)

// Verify body can still be read after logging
if tt.body != "" {
bodyBytes, err := io.ReadAll(req.Body)
require.NoError(t, err, "Should be able to read body after logging")
assert.Equal(t, tt.body, string(bodyBytes), "Body content should be preserved")
}
})
}
}

func TestInjectSessionContext(t *testing.T) {
tests := []struct {
name string
sessionID string
backendID string
expectBackendID bool
}{
{
name: "Inject session and backend ID (routed mode)",
sessionID: "session-123",
backendID: "github",
expectBackendID: true,
},
{
name: "Inject session ID only (unified mode)",
sessionID: "session-456",
backendID: "",
expectBackendID: false,
},
{
name: "Long session ID with backend",
sessionID: "very-long-session-id-1234567890",
backendID: "slack",
expectBackendID: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "/mcp", nil)

// Inject context
modifiedReq := injectSessionContext(req, tt.sessionID, tt.backendID)

// Verify session ID is in context
sessionIDFromCtx := modifiedReq.Context().Value(SessionIDContextKey)
require.NotNil(t, sessionIDFromCtx, "Session ID should be in context")
assert.Equal(t, tt.sessionID, sessionIDFromCtx, "Session ID mismatch")

// Verify backend ID if expected
if tt.expectBackendID {
backendIDFromCtx := modifiedReq.Context().Value(mcp.ContextKey("backend-id"))
require.NotNil(t, backendIDFromCtx, "Backend ID should be in context")
assert.Equal(t, tt.backendID, backendIDFromCtx, "Backend ID mismatch")
} else {
backendIDFromCtx := modifiedReq.Context().Value(mcp.ContextKey("backend-id"))
assert.Nil(t, backendIDFromCtx, "Backend ID should not be in context for unified mode")
}

// Verify original request is not modified
originalSessionID := req.Context().Value(SessionIDContextKey)
assert.Nil(t, originalSessionID, "Original request context should not be modified")
})
}
}

// testContextKey is a custom type for context keys to avoid collisions
type testContextKey string

func TestInjectSessionContext_PreservesExistingContext(t *testing.T) {
// Create a request with existing context values
req := httptest.NewRequest("POST", "/mcp", nil)
ctx := context.WithValue(req.Context(), testContextKey("existing-key"), "existing-value")
req = req.WithContext(ctx)

// Inject session context
modifiedReq := injectSessionContext(req, "session-123", "backend-1")

// Verify both values are present
sessionID := modifiedReq.Context().Value(SessionIDContextKey)
assert.Equal(t, "session-123", sessionID, "Session ID should be present")

backendID := modifiedReq.Context().Value(mcp.ContextKey("backend-id"))
assert.Equal(t, "backend-1", backendID, "Backend ID should be present")

existingValue := modifiedReq.Context().Value(testContextKey("existing-key"))
assert.Equal(t, "existing-value", existingValue, "Existing context value should be preserved")
}
26 changes: 4 additions & 22 deletions internal/server/routed.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
package server

import (
"bytes"
"context"
"fmt"
"io"
"log"
"net/http"
"sync"

"github.com/githubnext/gh-aw-mcpg/internal/auth"
"github.com/githubnext/gh-aw-mcpg/internal/logger"
"github.com/githubnext/gh-aw-mcpg/internal/mcp"
sdk "github.com/modelcontextprotocol/go-sdk/mcp"
)

Expand Down Expand Up @@ -110,14 +106,9 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap

// Create StreamableHTTP handler for this route
routeHandler := sdk.NewStreamableHTTPHandler(func(r *http.Request) *sdk.Server {
// Extract session ID from Authorization header
authHeader := r.Header.Get("Authorization")
sessionID := auth.ExtractSessionID(authHeader)

// Reject requests without Authorization header
// Extract and validate session ID from Authorization header
sessionID := extractAndValidateSession(r)
if sessionID == "" {
logger.LogError("client", "Rejected MCP client connection: no Authorization header, remote=%s, path=%s", r.RemoteAddr, r.URL.Path)
log.Printf("[%s] %s %s - REJECTED: No Authorization header", r.RemoteAddr, r.Method, r.URL.Path)
return nil
}

Expand All @@ -129,19 +120,10 @@ func CreateHTTPServerForRoutedMode(addr string, unifiedServer *UnifiedServer, ap
log.Printf("Authorization (Session ID): %s", sessionID)

// Log request body for debugging
if r.Method == "POST" && r.Body != nil {
bodyBytes, err := io.ReadAll(r.Body)
if err == nil && len(bodyBytes) > 0 {
logger.LogDebug("client", "MCP client request body, backend=%s, body=%s", backendID, string(bodyBytes))
log.Printf("Request body: %s", string(bodyBytes))
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}
}
logHTTPRequestBody(r, sessionID, backendID)

// Store session ID and backend ID in request context
ctx := context.WithValue(r.Context(), SessionIDContextKey, sessionID)
ctx = context.WithValue(ctx, mcp.ContextKey("backend-id"), backendID)
*r = *r.WithContext(ctx)
*r = *injectSessionContext(r, sessionID, backendID)
log.Printf("✓ Injected session ID and backend ID into context")
log.Printf("===================================\n")

Expand Down
Loading