Skip to content

Commit 840b41e

Browse files
committed
implement better resource path handling for OAuth server
1 parent cfea762 commit 840b41e

File tree

6 files changed

+193
-202
lines changed

6 files changed

+193
-202
lines changed

cmd/github-mcp-server/main.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ var (
102102
Host: viper.GetString("host"),
103103
Port: viper.GetInt("port"),
104104
BaseURL: viper.GetString("base-url"),
105+
ResourcePath: viper.GetString("resource-path"),
105106
ExportTranslations: viper.GetBool("export-translations"),
106107
EnableCommandLogging: viper.GetBool("enable-command-logging"),
107108
LogFilePath: viper.GetString("log-file"),
@@ -137,6 +138,7 @@ func init() {
137138
rootCmd.PersistentFlags().Duration("repo-access-cache-ttl", 5*time.Minute, "Override the repo access cache TTL (e.g. 1m, 0s to disable)")
138139
rootCmd.PersistentFlags().Int("port", 8082, "HTTP server port")
139140
rootCmd.PersistentFlags().String("base-url", "", "Base URL where this server is publicly accessible (for OAuth resource metadata)")
141+
rootCmd.PersistentFlags().String("resource-path", "", "Externally visible base path for the HTTP server (for OAuth resource metadata)")
140142

141143
// Bind flag to viper
142144
_ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets"))
@@ -154,6 +156,7 @@ func init() {
154156
_ = viper.BindPFlag("repo-access-cache-ttl", rootCmd.PersistentFlags().Lookup("repo-access-cache-ttl"))
155157
_ = viper.BindPFlag("port", rootCmd.PersistentFlags().Lookup("port"))
156158
_ = viper.BindPFlag("base-url", rootCmd.PersistentFlags().Lookup("base-url"))
159+
_ = viper.BindPFlag("resource-path", rootCmd.PersistentFlags().Lookup("resource-path"))
157160

158161
// Add subcommands
159162
rootCmd.AddCommand(stdioCmd)

pkg/http/handler.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,16 @@ func NewHTTPMcpHandler(
9393
// RegisterRoutes registers the routes for the MCP server
9494
// URL-based values take precedence over header-based values
9595
func (h *Handler) RegisterRoutes(r chi.Router) {
96-
r.Use(middleware.WithRequestConfig)
96+
mcpRouter := chi.NewRouter()
97+
mcpRouter.Use(middleware.WithRequestConfig)
9798

98-
r.Mount("/", h)
99+
mcpRouter.Mount("/", h)
99100
// Mount readonly and toolset routes
100-
r.With(withToolset).Mount("/x/{toolset}", h)
101-
r.With(withReadonly, withToolset).Mount("/x/{toolset}/readonly", h)
102-
r.With(withReadonly).Mount("/readonly", h)
101+
mcpRouter.With(withToolset).Mount("/x/{toolset}", h)
102+
mcpRouter.With(withReadonly, withToolset).Mount("/x/{toolset}/readonly", h)
103+
mcpRouter.With(withReadonly).Mount("/readonly", h)
104+
105+
r.Mount("/", mcpRouter)
103106
}
104107

105108
// withReadonly is middleware that sets readonly mode in the request context

pkg/http/middleware/token.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handl
6767
// sendAuthChallenge sends a 401 Unauthorized response with WWW-Authenticate header
6868
// containing the OAuth protected resource metadata URL as per RFC 6750 and MCP spec.
6969
func sendAuthChallenge(w http.ResponseWriter, r *http.Request, oauthCfg *oauth.Config) {
70-
resourceMetadataURL := oauth.BuildResourceMetadataURL(r, oauthCfg, "mcp")
70+
resourcePath := oauth.ResolveResourcePath(r, oauthCfg)
71+
resourceMetadataURL := oauth.BuildResourceMetadataURL(r, oauthCfg, resourcePath)
7172
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata=%q`, resourceMetadataURL))
7273
http.Error(w, "Unauthorized", http.StatusUnauthorized)
7374
}

pkg/http/oauth/oauth.go

Lines changed: 124 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33
package oauth
44

55
import (
6+
"encoding/json"
67
"fmt"
78
"net/http"
8-
"net/url"
99
"strings"
1010

1111
"github.com/github/github-mcp-server/pkg/http/headers"
1212
"github.com/go-chi/chi/v5"
13-
"github.com/modelcontextprotocol/go-sdk/auth"
1413
"github.com/modelcontextprotocol/go-sdk/oauthex"
1514
)
1615

@@ -48,17 +47,12 @@ type Config struct {
4847
// Defaults to GitHub's OAuth server if not specified.
4948
AuthorizationServer string
5049

51-
// ResourcePath is the resource path suffix (e.g., "/mcp").
52-
// If empty, defaults to "/"
50+
// ResourcePath is the externally visible base path for the MCP server (e.g., "/mcp").
51+
// This is used to restore the original path when a proxy strips a base path before forwarding.
52+
// If empty, requests are treated as already using the external path.
5353
ResourcePath string
5454
}
5555

56-
// ProtectedResourceData contains the data needed to build an OAuth protected resource response.
57-
type ProtectedResourceData struct {
58-
ResourceURL string
59-
AuthorizationServer string
60-
}
61-
6256
// AuthHandler handles OAuth-related HTTP endpoints.
6357
type AuthHandler struct {
6458
cfg *Config
@@ -94,119 +88,165 @@ func (h *AuthHandler) RegisterRoutes(r chi.Router) {
9488
for _, pattern := range routePatterns {
9589
for _, route := range h.routesForPattern(pattern) {
9690
path := OAuthProtectedResourcePrefix + route
97-
98-
// Build metadata for this specific resource path
99-
metadata := h.buildMetadata(route)
100-
r.Handle(path, auth.ProtectedResourceMetadataHandler(metadata))
91+
r.Handle(path, h.metadataHandler())
10192
}
10293
}
10394
}
10495

105-
func (h *AuthHandler) buildMetadata(resourcePath string) *oauthex.ProtectedResourceMetadata {
106-
baseURL := strings.TrimSuffix(h.cfg.BaseURL, "/")
107-
resourceURL := baseURL
108-
if resourcePath != "" && resourcePath != "/" {
109-
resourceURL = baseURL + resourcePath
110-
}
96+
func (h *AuthHandler) metadataHandler() http.Handler {
97+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
98+
// CORS headers for browser-based clients
99+
w.Header().Set("Access-Control-Allow-Origin", "*")
100+
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
101+
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
111102

112-
return &oauthex.ProtectedResourceMetadata{
113-
Resource: resourceURL,
114-
AuthorizationServers: []string{h.cfg.AuthorizationServer},
115-
ResourceName: "GitHub MCP Server",
116-
ScopesSupported: SupportedScopes,
117-
BearerMethodsSupported: []string{"header"},
118-
}
103+
if r.Method == http.MethodOptions {
104+
w.WriteHeader(http.StatusNoContent)
105+
return
106+
}
107+
if r.Method != http.MethodGet {
108+
w.WriteHeader(http.StatusMethodNotAllowed)
109+
return
110+
}
111+
112+
resourcePath := resolveResourcePath(
113+
strings.TrimPrefix(r.URL.Path, OAuthProtectedResourcePrefix),
114+
h.cfg.ResourcePath,
115+
)
116+
resourceURL := h.buildResourceURL(r, resourcePath)
117+
118+
w.Header().Set("Content-Type", "application/json")
119+
_ = json.NewEncoder(w).Encode(&oauthex.ProtectedResourceMetadata{
120+
Resource: resourceURL,
121+
AuthorizationServers: []string{h.cfg.AuthorizationServer},
122+
ResourceName: "GitHub MCP Server",
123+
ScopesSupported: SupportedScopes,
124+
BearerMethodsSupported: []string{"header"},
125+
})
126+
})
119127
}
120128

121129
// routesForPattern generates route variants for a given pattern.
122130
// GitHub strips the /mcp prefix before forwarding, so we register both variants:
123131
// - With /mcp prefix: for direct access or when GitHub doesn't strip
124132
// - Without /mcp prefix: for when GitHub has stripped the prefix
125133
func (h *AuthHandler) routesForPattern(pattern string) []string {
126-
return []string{
127-
pattern,
128-
"/mcp" + pattern,
129-
pattern + "/",
130-
"/mcp" + pattern + "/",
134+
basePaths := []string{""}
135+
if basePath := normalizeBasePath(h.cfg.ResourcePath); basePath != "" {
136+
basePaths = append(basePaths, basePath)
137+
} else {
138+
basePaths = append(basePaths, "/mcp")
131139
}
140+
141+
routes := make([]string, 0, len(basePaths)*2)
142+
for _, basePath := range basePaths {
143+
routes = append(routes, joinRoute(basePath, pattern))
144+
routes = append(routes, joinRoute(basePath, pattern)+"/")
145+
}
146+
147+
return routes
148+
}
149+
150+
// resolveResourcePath returns the externally visible resource path,
151+
// restoring the configured base path when proxies strip it before forwarding.
152+
func resolveResourcePath(path, basePath string) string {
153+
if path == "" {
154+
path = "/"
155+
}
156+
base := normalizeBasePath(basePath)
157+
if base == "" {
158+
return path
159+
}
160+
if path == "/" {
161+
return base
162+
}
163+
if path == base || strings.HasPrefix(path, base+"/") {
164+
return path
165+
}
166+
return base + path
132167
}
133168

134-
// GetEffectiveResourcePath returns the resource path for OAuth protected resource URLs.
135-
// Since proxies may strip the /mcp prefix before forwarding requests, this function
136-
// restores the prefix for the external-facing URL.
137-
func GetEffectiveResourcePath(r *http.Request) string {
138-
if r.URL.Path == "/" {
139-
return "/mcp"
169+
// ResolveResourcePath returns the externally visible resource path for a request.
170+
// Exported for use by middleware.
171+
func ResolveResourcePath(r *http.Request, cfg *Config) string {
172+
basePath := ""
173+
if cfg != nil {
174+
basePath = cfg.ResourcePath
140175
}
141-
return "/mcp" + r.URL.Path
176+
return resolveResourcePath(r.URL.Path, basePath)
142177
}
143178

144-
// GetProtectedResourceData builds the OAuth protected resource data for a request.
145-
func (h *AuthHandler) GetProtectedResourceData(r *http.Request, resourcePath string) (*ProtectedResourceData, error) {
179+
// buildResourceURL constructs the full resource URL for OAuth metadata.
180+
func (h *AuthHandler) buildResourceURL(r *http.Request, resourcePath string) string {
146181
host, scheme := GetEffectiveHostAndScheme(r, h.cfg)
147-
148-
// Build the base URL
149182
baseURL := fmt.Sprintf("%s://%s", scheme, host)
150183
if h.cfg.BaseURL != "" {
151184
baseURL = strings.TrimSuffix(h.cfg.BaseURL, "/")
152185
}
153-
154-
// Build the resource URL using url.JoinPath for proper path handling
155-
var resourceURL string
156-
var err error
157-
if resourcePath == "/" {
158-
resourceURL = baseURL + "/"
159-
} else {
160-
resourceURL, err = url.JoinPath(baseURL, resourcePath)
161-
if err != nil {
162-
return nil, fmt.Errorf("failed to build resource URL: %w", err)
163-
}
186+
if resourcePath == "" {
187+
resourcePath = "/"
164188
}
165-
166-
return &ProtectedResourceData{
167-
ResourceURL: resourceURL,
168-
AuthorizationServer: h.cfg.AuthorizationServer,
169-
}, nil
189+
if !strings.HasPrefix(resourcePath, "/") {
190+
resourcePath = "/" + resourcePath
191+
}
192+
return baseURL + resourcePath
170193
}
171194

172195
// GetEffectiveHostAndScheme returns the effective host and scheme for a request.
173-
// It checks X-Forwarded-Host and X-Forwarded-Proto headers first (set by proxies),
174-
// then falls back to the request's Host and TLS state.
175-
func GetEffectiveHostAndScheme(r *http.Request, cfg *Config) (host, scheme string) { //nolint:revive // parameters are required by http.oauth.BuildResourceMetadataURL signature
176-
// Check for forwarded headers first (typically set by reverse proxies)
177-
if forwardedHost := r.Header.Get(headers.ForwardedHostHeader); forwardedHost != "" {
178-
host = forwardedHost
196+
func GetEffectiveHostAndScheme(r *http.Request, cfg *Config) (host, scheme string) { //nolint:revive
197+
if fh := r.Header.Get(headers.ForwardedHostHeader); fh != "" {
198+
host = fh
179199
} else {
180200
host = r.Host
181201
}
182-
183-
// Determine scheme
184-
switch {
185-
case r.Header.Get(headers.ForwardedProtoHeader) != "":
186-
scheme = strings.ToLower(r.Header.Get(headers.ForwardedProtoHeader))
187-
case r.TLS != nil:
188-
scheme = "https"
189-
default:
190-
// Default to HTTPS in production scenarios
191-
scheme = "https"
202+
if host == "" {
203+
host = "localhost"
192204
}
193-
194-
return host, scheme
205+
if fp := r.Header.Get(headers.ForwardedProtoHeader); fp != "" {
206+
scheme = strings.ToLower(fp)
207+
} else {
208+
scheme = "https" // Default to HTTPS
209+
}
210+
return
195211
}
196212

197213
// BuildResourceMetadataURL constructs the full URL to the OAuth protected resource metadata endpoint.
198214
func BuildResourceMetadataURL(r *http.Request, cfg *Config, resourcePath string) string {
199215
host, scheme := GetEffectiveHostAndScheme(r, cfg)
200-
216+
suffix := ""
217+
if resourcePath != "" && resourcePath != "/" {
218+
if !strings.HasPrefix(resourcePath, "/") {
219+
suffix = "/" + resourcePath
220+
} else {
221+
suffix = resourcePath
222+
}
223+
}
201224
if cfg != nil && cfg.BaseURL != "" {
202-
baseURL := strings.TrimSuffix(cfg.BaseURL, "/")
203-
return baseURL + OAuthProtectedResourcePrefix + "/" + strings.TrimPrefix(resourcePath, "/")
225+
return strings.TrimSuffix(cfg.BaseURL, "/") + OAuthProtectedResourcePrefix + suffix
204226
}
227+
return fmt.Sprintf("%s://%s%s%s", scheme, host, OAuthProtectedResourcePrefix, suffix)
228+
}
205229

206-
path := OAuthProtectedResourcePrefix
207-
if resourcePath != "" && resourcePath != "/" {
208-
path = path + "/" + strings.TrimPrefix(resourcePath, "/")
230+
func normalizeBasePath(path string) string {
231+
trimmed := strings.TrimSpace(path)
232+
if trimmed == "" || trimmed == "/" {
233+
return ""
209234
}
235+
if !strings.HasPrefix(trimmed, "/") {
236+
trimmed = "/" + trimmed
237+
}
238+
return strings.TrimSuffix(trimmed, "/")
239+
}
210240

211-
return fmt.Sprintf("%s://%s%s", scheme, host, path)
241+
func joinRoute(basePath, pattern string) string {
242+
if basePath == "" {
243+
return pattern
244+
}
245+
if pattern == "" {
246+
return basePath
247+
}
248+
if strings.HasSuffix(basePath, "/") {
249+
return strings.TrimSuffix(basePath, "/") + pattern
250+
}
251+
return basePath + pattern
212252
}

0 commit comments

Comments
 (0)