Skip to content

Commit 27fb435

Browse files
committed
refactor: stop rewriting paths in proxy
This change also properly handles redirects from the downstream MCP server and rewrites them to point to the proxy. Signed-off-by: Donnie Adams <donnie@acorn.io>
1 parent 8333411 commit 27fb435

File tree

5 files changed

+50
-32
lines changed

5 files changed

+50
-32
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ You can customize the scopes based on your needs. Common additional scopes inclu
271271
OAUTH_CLIENT_ID: "your-oauth-client-id"
272272
OAUTH_CLIENT_SECRET: "your-oauth-client-secret"
273273
OAUTH_AUTHORIZE_URL: "https://your-oauth-provider.com/oauth/authorize"
274-
MCP_SERVER_URL: "http://localhost:3000/mcp"
274+
MCP_SERVER_URL: "http://localhost:3000"
275275
ENCRYPTION_KEY: "your-base64-encoded-32-byte-key"
276276
ports:
277277
- "8080:8080"
@@ -295,7 +295,7 @@ You can customize the scopes based on your needs. Common additional scopes inclu
295295
OAUTH_CLIENT_ID: "your-oauth-client-id"
296296
OAUTH_CLIENT_SECRET: "your-oauth-client-secret"
297297
OAUTH_AUTHORIZE_URL: "https://your-oauth-provider.com/oauth/authorize"
298-
MCP_SERVER_URL: "http://localhost:3000/mcp"
298+
MCP_SERVER_URL: "http://localhost:3000"
299299
volumes:
300300
- ./data:/app/data # Persist SQLite database
301301
ports:
@@ -342,7 +342,7 @@ You can customize the scopes based on your needs. Common additional scopes inclu
342342

343343
### MCP Proxy
344344

345-
- `ANY /mcp/*` - Proxies requests to MCP server with user context headers
345+
- `ANY /*` - Proxies any request not mentioned above to MCP server with user context headers
346346

347347
## OAuth Flow
348348

docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ services:
1919
# OAUTH_CLIENT_ID: "your-oauth-client-id"
2020
# OAUTH_CLIENT_SECRET: "your-oauth-client-secret"
2121
# OAUTH_AUTHORIZE_URL: "https://your-oauth-provider.com/oauth/authorize"
22-
# MCP_SERVER_URL: "http://localhost:3000/mcp"
22+
# MCP_SERVER_URL: "http://localhost:3000"
2323
# ports:
2424
# - "8080:8080"
2525
# depends_on:

main_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ func TestIntegrationFlow(t *testing.T) {
9999
// Test protected resource metadata
100100
t.Run("ProtectedResourceMetadata", func(t *testing.T) {
101101
w := httptest.NewRecorder()
102-
req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource/mcp", nil)
102+
req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource", nil)
103103
handler.ServeHTTP(w, req)
104104

105105
assert.Equal(t, http.StatusOK, w.Code)

pkg/oauth/validate/validatetoken.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func (p *TokenValidator) WithTokenValidation(next http.HandlerFunc) http.Handler
2828
authHeader := r.Header.Get("Authorization")
2929
if authHeader == "" {
3030
// Return 401 with proper WWW-Authenticate header
31-
resourceMetadataUrl := fmt.Sprintf("%s/.well-known/oauth-protected-resource/mcp", handlerutils.GetBaseURL(r))
31+
resourceMetadataUrl := fmt.Sprintf("%s/.well-known/oauth-protected-resource", handlerutils.GetBaseURL(r))
3232
wwwAuthValue := fmt.Sprintf(`Bearer error="invalid_token", error_description="Missing Authorization header", resource_metadata="%s"`, resourceMetadataUrl)
3333
w.Header().Set("WWW-Authenticate", wwwAuthValue)
3434
handlerutils.JSON(w, http.StatusUnauthorized, map[string]string{
@@ -41,7 +41,7 @@ func (p *TokenValidator) WithTokenValidation(next http.HandlerFunc) http.Handler
4141
// Parse Authorization header
4242
parts := strings.SplitN(authHeader, " ", 2)
4343
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" || parts[1] == "" {
44-
resourceMetadataUrl := fmt.Sprintf("%s/.well-known/oauth-protected-resource/mcp", handlerutils.GetBaseURL(r))
44+
resourceMetadataUrl := fmt.Sprintf("%s/.well-known/oauth-protected-resource", handlerutils.GetBaseURL(r))
4545
wwwAuthValue := fmt.Sprintf(`Bearer error="invalid_token", error_description="Invalid Authorization header format, expected 'Bearer TOKEN'", resource_metadata="%s"`, resourceMetadataUrl)
4646
w.Header().Set("WWW-Authenticate", wwwAuthValue)
4747
handlerutils.JSON(w, http.StatusUnauthorized, map[string]string{
@@ -55,7 +55,7 @@ func (p *TokenValidator) WithTokenValidation(next http.HandlerFunc) http.Handler
5555

5656
tokenInfo, err := p.tokenManager.GetTokenInfo(token)
5757
if err != nil {
58-
resourceMetadataUrl := fmt.Sprintf("%s/.well-known/oauth-protected-resource/mcp", handlerutils.GetBaseURL(r))
58+
resourceMetadataUrl := fmt.Sprintf("%s/.well-known/oauth-protected-resource", handlerutils.GetBaseURL(r))
5959
wwwAuthValue := fmt.Sprintf(`Bearer error="invalid_token", error_description="Invalid or expired token", resource_metadata="%s"`, resourceMetadataUrl)
6060
w.Header().Set("WWW-Authenticate", wwwAuthValue)
6161
handlerutils.JSON(w, http.StatusUnauthorized, map[string]string{
@@ -69,7 +69,7 @@ func (p *TokenValidator) WithTokenValidation(next http.HandlerFunc) http.Handler
6969
if tokenInfo.Props != nil {
7070
decryptedProps, err := encryption.DecryptPropsIfNeeded(p.encryptionKey, tokenInfo.Props)
7171
if err != nil {
72-
resourceMetadataUrl := fmt.Sprintf("%s/.well-known/oauth-protected-resource/mcp", handlerutils.GetBaseURL(r))
72+
resourceMetadataUrl := fmt.Sprintf("%s/.well-known/oauth-protected-resource", handlerutils.GetBaseURL(r))
7373
wwwAuthValue := fmt.Sprintf(`Bearer error="invalid_token", error_description="Failed to decrypt token data", resource_metadata="%s"`, resourceMetadataUrl)
7474
w.Header().Set("WWW-Authenticate", wwwAuthValue)
7575
handlerutils.JSON(w, http.StatusUnauthorized, map[string]string{

pkg/proxy/proxy.go

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ func LoadConfigFromEnv() *types.Config {
6161
}
6262

6363
func NewOAuthProxy(config *types.Config) (*OAuthProxy, error) {
64+
if u, err := url.Parse(config.MCPServerURL); err != nil || u.Scheme != "http" && u.Scheme != "https" {
65+
return nil, fmt.Errorf("invalid MCP server URL: %w", err)
66+
} else if u.Path != "" && u.Path != "/" || u.RawQuery != "" || u.Fragment != "" {
67+
return nil, fmt.Errorf("MCP server URL must not contain a path, query, or fragment")
68+
}
6469
databaseDSN := config.DatabaseDSN
6570

6671
// Log database configuration
@@ -194,22 +199,21 @@ func (p *OAuthProxy) SetupRoutes(mux *http.ServeMux) {
194199
revokeHandler := revoke.NewHandler(p.db)
195200
tokenValidator := validate.NewTokenValidator(p.tokenManager, p.encryptionKey)
196201

197-
mux.HandleFunc("/health", p.withCORS(p.healthHandler))
202+
mux.HandleFunc("GET /health", p.withCORS(p.healthHandler))
198203

199204
// OAuth endpoints
200-
mux.HandleFunc("/authorize", p.withCORS(p.withRateLimit(authorizeHandler)))
201-
mux.HandleFunc("/callback", p.withCORS(p.withRateLimit(callbackHandler)))
202-
mux.HandleFunc("/token", p.withCORS(p.withRateLimit(tokenHandler)))
203-
mux.HandleFunc("/revoke", p.withCORS(p.withRateLimit(revokeHandler)))
204-
mux.HandleFunc("/register", p.withCORS(p.withRateLimit(register.NewHandler(p.db))))
205+
mux.HandleFunc("GET /authorize", p.withCORS(p.withRateLimit(authorizeHandler)))
206+
mux.HandleFunc("GET /callback", p.withCORS(p.withRateLimit(callbackHandler)))
207+
mux.HandleFunc("POST /token", p.withCORS(p.withRateLimit(tokenHandler)))
208+
mux.HandleFunc("POST /revoke", p.withCORS(p.withRateLimit(revokeHandler)))
209+
mux.HandleFunc("POST /register", p.withCORS(p.withRateLimit(register.NewHandler(p.db))))
205210

206211
// Metadata endpoints
207-
mux.HandleFunc("/.well-known/oauth-authorization-server", p.withCORS(p.oauthMetadataHandler))
208-
mux.HandleFunc("/.well-known/oauth-protected-resource/mcp", p.withCORS(p.protectedResourceMetadataHandler))
212+
mux.HandleFunc("GET /.well-known/oauth-authorization-server", p.withCORS(p.oauthMetadataHandler))
213+
mux.HandleFunc("GET /.well-known/oauth-protected-resource", p.withCORS(p.protectedResourceMetadataHandler))
209214

210-
// Protected resource endpoints
211-
mux.HandleFunc("/mcp", p.withCORS(p.withRateLimit(tokenValidator.WithTokenValidation(p.mcpProxyHandler))))
212-
mux.HandleFunc("/mcp/{path...}", p.withCORS(p.withRateLimit(tokenValidator.WithTokenValidation(p.mcpProxyHandler))))
215+
// Protect everything else
216+
mux.HandleFunc("/{path...}", p.withCORS(p.withRateLimit(tokenValidator.WithTokenValidation(p.mcpProxyHandler))))
213217
}
214218

215219
// GetHandler returns an http.Handler for the OAuth proxy
@@ -289,9 +293,10 @@ func (p *OAuthProxy) oauthMetadataHandler(w http.ResponseWriter, r *http.Request
289293
}
290294

291295
func (p *OAuthProxy) protectedResourceMetadataHandler(w http.ResponseWriter, r *http.Request) {
296+
baseURL := handlerutils.GetBaseURL(r)
292297
metadata := types.OAuthProtectedResourceMetadata{
293-
Resource: fmt.Sprintf("%s/mcp", handlerutils.GetBaseURL(r)),
294-
AuthorizationServers: []string{handlerutils.GetBaseURL(r)},
298+
Resource: baseURL,
299+
AuthorizationServers: []string{baseURL},
295300
Scopes: p.metadata.ScopesSupported,
296301
ResourceName: p.resourceName,
297302
ResourceDocumentation: p.metadata.ServiceDocumentation,
@@ -387,15 +392,7 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
387392
}
388393

389394
// Create target URL
390-
var targetURL string
391-
if path == "" {
392-
// If no path is provided, use the MCP server URL directly
393-
targetURL = p.GetMCPServerURL()
394-
} else {
395-
// If path is provided, append it to the MCP server URL
396-
targetURL = p.GetMCPServerURL() + "/" + path
397-
}
398-
395+
targetURL := p.GetMCPServerURL() + "/" + path
399396
// Log the proxy request for debugging
400397
log.Printf("Proxying request: %s %s -> %s", r.Method, r.URL.Path, targetURL)
401398

@@ -404,12 +401,12 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
404401
Director: func(req *http.Request) {
405402
req.Header.Del("Authorization")
406403
req.Header.Set("X-Forwarded-Host", req.Host)
404+
req.Header.Set("X-Forwarded-Proto", req.URL.Scheme)
407405

408406
newURL, _ := url.Parse(targetURL)
409407
req.URL.Scheme = newURL.Scheme
410408
req.URL.Host = newURL.Host
411409
req.Host = newURL.Host
412-
req.URL.Path = newURL.Path
413410

414411
// Add forwarded headers from token props
415412
if tokenInfo.Props != nil {
@@ -427,6 +424,27 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
427424
}
428425
}
429426
},
427+
ModifyResponse: func(resp *http.Response) error {
428+
// Rewrite Location header to use proxy host instead of downstream server host
429+
if location := resp.Header.Get("Location"); location != "" {
430+
if locationURL, err := url.Parse(location); err == nil {
431+
// Get the original request to extract proxy host
432+
proxyHost := resp.Request.Header.Get("X-Forwarded-Host")
433+
if proxyHost != "" {
434+
// Parse downstream server URL to get scheme
435+
downstreamURL, _ := url.Parse(p.GetMCPServerURL())
436+
437+
// Only rewrite if the location points to the downstream server
438+
if locationURL.Host == downstreamURL.Host {
439+
locationURL.Scheme = resp.Request.URL.Scheme
440+
locationURL.Host = proxyHost
441+
resp.Header.Set("Location", locationURL.String())
442+
}
443+
}
444+
}
445+
}
446+
return nil
447+
},
430448
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
431449
log.Printf("Proxy error: %v", err)
432450
rw.WriteHeader(http.StatusBadGateway)

0 commit comments

Comments
 (0)