Skip to content

Commit a56a5b5

Browse files
mattdhollowayCopilotomgitsads
authored
OAuth metadata implementation (#1862)
* initial oauth metadata implementation * add nolint for GetEffectiveHostAndScheme * remove CAPI reference * remove nonsensical example URL * anonymize * add oauth tests * replace custom protected resource metadata handler with our own * remove unused header * Update pkg/http/oauth/oauth.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * pass oauth config to mcp handler for token extraction * chore: retrigger ci * align types with base branch * update more types * initial oauth metadata implementation * add nolint for GetEffectiveHostAndScheme * remove CAPI reference * remove nonsensical example URL * anonymize * add oauth tests * replace custom protected resource metadata handler with our own * Update pkg/http/oauth/oauth.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * chore: retrigger ci * update more types * remove CAPI specific header * restore mcp path specific logic * implement better resource path handling for OAuth server * return auth handler to lib version * rename to base-path flag * switch to chi group * make viper commands http only * Default to http, but check for TLS in GetEffectiveHostAndScheme --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Adam Holt <me@adamholt.co.uk>
1 parent a4d6b20 commit a56a5b5

File tree

8 files changed

+939
-9
lines changed

8 files changed

+939
-9
lines changed

cmd/github-mcp-server/main.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ var (
101101
Version: version,
102102
Host: viper.GetString("host"),
103103
Port: viper.GetInt("port"),
104+
BaseURL: viper.GetString("base-url"),
105+
ResourcePath: viper.GetString("base-path"),
104106
ExportTranslations: viper.GetBool("export-translations"),
105107
EnableCommandLogging: viper.GetBool("enable-command-logging"),
106108
LogFilePath: viper.GetString("log-file"),
@@ -134,7 +136,11 @@ func init() {
134136
rootCmd.PersistentFlags().Bool("lockdown-mode", false, "Enable lockdown mode")
135137
rootCmd.PersistentFlags().Bool("insiders", false, "Enable insiders features")
136138
rootCmd.PersistentFlags().Duration("repo-access-cache-ttl", 5*time.Minute, "Override the repo access cache TTL (e.g. 1m, 0s to disable)")
137-
rootCmd.PersistentFlags().Int("port", 8082, "HTTP server port")
139+
140+
// HTTP-specific flags
141+
httpCmd.Flags().Int("port", 8082, "HTTP server port")
142+
httpCmd.Flags().String("base-url", "", "Base URL where this server is publicly accessible (for OAuth resource metadata)")
143+
httpCmd.Flags().String("base-path", "", "Externally visible base path for the HTTP server (for OAuth resource metadata)")
138144

139145
// Bind flag to viper
140146
_ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets"))
@@ -150,7 +156,9 @@ func init() {
150156
_ = viper.BindPFlag("lockdown-mode", rootCmd.PersistentFlags().Lookup("lockdown-mode"))
151157
_ = viper.BindPFlag("insiders", rootCmd.PersistentFlags().Lookup("insiders"))
152158
_ = viper.BindPFlag("repo-access-cache-ttl", rootCmd.PersistentFlags().Lookup("repo-access-cache-ttl"))
153-
_ = viper.BindPFlag("port", rootCmd.PersistentFlags().Lookup("port"))
159+
_ = viper.BindPFlag("port", httpCmd.Flags().Lookup("port"))
160+
_ = viper.BindPFlag("base-url", httpCmd.Flags().Lookup("base-url"))
161+
_ = viper.BindPFlag("base-path", httpCmd.Flags().Lookup("base-path"))
154162

155163
// Add subcommands
156164
rootCmd.AddCommand(stdioCmd)

pkg/http/handler.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
ghcontext "github.com/github/github-mcp-server/pkg/context"
99
"github.com/github/github-mcp-server/pkg/github"
1010
"github.com/github/github-mcp-server/pkg/http/middleware"
11+
"github.com/github/github-mcp-server/pkg/http/oauth"
1112
"github.com/github/github-mcp-server/pkg/inventory"
1213
"github.com/github/github-mcp-server/pkg/translations"
1314
"github.com/go-chi/chi/v5"
@@ -25,11 +26,13 @@ type Handler struct {
2526
t translations.TranslationHelperFunc
2627
githubMcpServerFactory GitHubMCPServerFactoryFunc
2728
inventoryFactoryFunc InventoryFactoryFunc
29+
oauthCfg *oauth.Config
2830
}
2931

3032
type HandlerOptions struct {
3133
GitHubMcpServerFactory GitHubMCPServerFactoryFunc
3234
InventoryFactory InventoryFactoryFunc
35+
OAuthConfig *oauth.Config
3336
FeatureChecker inventory.FeatureFlagChecker
3437
}
3538

@@ -47,6 +50,12 @@ func WithInventoryFactory(f InventoryFactoryFunc) HandlerOption {
4750
}
4851
}
4952

53+
func WithOAuthConfig(cfg *oauth.Config) HandlerOption {
54+
return func(o *HandlerOptions) {
55+
o.OAuthConfig = cfg
56+
}
57+
}
58+
5059
func WithFeatureChecker(checker inventory.FeatureFlagChecker) HandlerOption {
5160
return func(o *HandlerOptions) {
5261
o.FeatureChecker = checker
@@ -83,14 +92,20 @@ func NewHTTPMcpHandler(
8392
t: t,
8493
githubMcpServerFactory: githubMcpServerFactory,
8594
inventoryFactoryFunc: inventoryFactory,
95+
oauthCfg: opts.OAuthConfig,
8696
}
8797
}
8898

99+
func (h *Handler) RegisterMiddleware(r chi.Router) {
100+
r.Use(
101+
middleware.ExtractUserToken(h.oauthCfg),
102+
middleware.WithRequestConfig,
103+
)
104+
}
105+
89106
// RegisterRoutes registers the routes for the MCP server
90107
// URL-based values take precedence over header-based values
91108
func (h *Handler) RegisterRoutes(r chi.Router) {
92-
r.Use(middleware.WithRequestConfig)
93-
94109
// Base routes
95110
r.Mount("/", h)
96111
r.With(withReadonly).Mount("/readonly", h)
@@ -154,7 +169,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
154169
Stateless: true,
155170
})
156171

157-
middleware.ExtractUserToken()(mcpHandler).ServeHTTP(w, r)
172+
mcpHandler.ServeHTTP(w, r)
158173
}
159174

160175
func DefaultGitHubMCPServerFactory(r *http.Request, deps github.ToolDependencies, inventory *inventory.Inventory, cfg *github.MCPServerConfig) (*mcp.Server, error) {

pkg/http/handler_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
ghcontext "github.com/github/github-mcp-server/pkg/context"
1212
"github.com/github/github-mcp-server/pkg/github"
1313
"github.com/github/github-mcp-server/pkg/http/headers"
14+
"github.com/github/github-mcp-server/pkg/http/middleware"
1415
"github.com/github/github-mcp-server/pkg/inventory"
1516
"github.com/github/github-mcp-server/pkg/translations"
1617
"github.com/go-chi/chi/v5"
@@ -294,6 +295,7 @@ func TestHTTPHandlerRoutes(t *testing.T) {
294295

295296
// Create router and register routes
296297
r := chi.NewRouter()
298+
r.Use(middleware.WithRequestConfig)
297299
handler.RegisterRoutes(r)
298300

299301
// Create request

pkg/http/headers/headers.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ const (
2121
// RealIPHeader is a standard HTTP Header used to indicate the real IP address of the client.
2222
RealIPHeader = "X-Real-IP"
2323

24+
// ForwardedHostHeader is a standard HTTP Header for preserving the original Host header when proxying.
25+
ForwardedHostHeader = "X-Forwarded-Host"
26+
// ForwardedProtoHeader is a standard HTTP Header for preserving the original protocol when proxying.
27+
ForwardedProtoHeader = "X-Forwarded-Proto"
28+
2429
// RequestHmacHeader is used to authenticate requests to the Raw API.
2530
RequestHmacHeader = "Request-Hmac"
2631

pkg/http/middleware/token.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
ghcontext "github.com/github/github-mcp-server/pkg/context"
1111
httpheaders "github.com/github/github-mcp-server/pkg/http/headers"
1212
"github.com/github/github-mcp-server/pkg/http/mark"
13+
"github.com/github/github-mcp-server/pkg/http/oauth"
1314
)
1415

1516
type authType int
@@ -39,14 +40,14 @@ var supportedThirdPartyTokenPrefixes = []string{
3940
// were 40 characters long and only contained the characters a-f and 0-9.
4041
var oldPatternRegexp = regexp.MustCompile(`\A[a-f0-9]{40}\z`)
4142

42-
func ExtractUserToken() func(next http.Handler) http.Handler {
43+
func ExtractUserToken(oauthCfg *oauth.Config) func(next http.Handler) http.Handler {
4344
return func(next http.Handler) http.Handler {
4445
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
4546
_, token, err := parseAuthorizationHeader(r)
4647
if err != nil {
4748
// For missing Authorization header, return 401 with WWW-Authenticate header per MCP spec
4849
if errors.Is(err, errMissingAuthorizationHeader) {
49-
// sendAuthChallenge(w, r, cfg, obsv)
50+
sendAuthChallenge(w, r, oauthCfg)
5051
return
5152
}
5253
// For other auth errors (bad format, unsupported), return 400
@@ -62,6 +63,16 @@ func ExtractUserToken() func(next http.Handler) http.Handler {
6263
})
6364
}
6465
}
66+
67+
// sendAuthChallenge sends a 401 Unauthorized response with WWW-Authenticate header
68+
// containing the OAuth protected resource metadata URL as per RFC 6750 and MCP spec.
69+
func sendAuthChallenge(w http.ResponseWriter, r *http.Request, oauthCfg *oauth.Config) {
70+
resourcePath := oauth.ResolveResourcePath(r, oauthCfg)
71+
resourceMetadataURL := oauth.BuildResourceMetadataURL(r, oauthCfg, resourcePath)
72+
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer resource_metadata=%q`, resourceMetadataURL))
73+
http.Error(w, "Unauthorized", http.StatusUnauthorized)
74+
}
75+
6576
func parseAuthorizationHeader(req *http.Request) (authType authType, token string, _ error) {
6677
authHeader := req.Header.Get(httpheaders.AuthorizationHeader)
6778
if authHeader == "" {

0 commit comments

Comments
 (0)