Skip to content

Commit 50227bf

Browse files
committed
replace custom protected resource metadata handler with our own
1 parent 9b5c2fb commit 50227bf

File tree

3 files changed

+55
-86
lines changed

3 files changed

+55
-86
lines changed

pkg/http/oauth/oauth.go

Lines changed: 26 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,16 @@
33
package oauth
44

55
import (
6-
"bytes"
7-
_ "embed"
86
"fmt"
9-
"html"
107
"net/http"
118
"net/url"
129
"strings"
13-
"text/template"
10+
11+
"github.com/modelcontextprotocol/go-sdk/auth"
1412

1513
"github.com/github/github-mcp-server/pkg/http/headers"
1614
"github.com/go-chi/chi/v5"
15+
"github.com/modelcontextprotocol/go-sdk/oauthex"
1716
)
1817

1918
const (
@@ -24,9 +23,6 @@ const (
2423
DefaultAuthorizationServer = "https://github.com/login/oauth"
2524
)
2625

27-
//go:embed protected_resource.json.tmpl
28-
var protectedResourceTemplate []byte
29-
3026
// SupportedScopes lists all OAuth scopes that may be required by MCP tools.
3127
var SupportedScopes = []string{
3228
"repo",
@@ -66,8 +62,7 @@ type ProtectedResourceData struct {
6662

6763
// AuthHandler handles OAuth-related HTTP endpoints.
6864
type AuthHandler struct {
69-
cfg *Config
70-
protectedResourceTemplate *template.Template
65+
cfg *Config
7166
}
7267

7368
// NewAuthHandler creates a new OAuth auth handler.
@@ -81,21 +76,16 @@ func NewAuthHandler(cfg *Config) (*AuthHandler, error) {
8176
cfg.AuthorizationServer = DefaultAuthorizationServer
8277
}
8378

84-
tmpl, err := template.New("protected-resource").Parse(string(protectedResourceTemplate))
85-
if err != nil {
86-
return nil, fmt.Errorf("failed to parse protected resource template: %w", err)
87-
}
88-
8979
return &AuthHandler{
90-
cfg: cfg,
91-
protectedResourceTemplate: tmpl,
80+
cfg: cfg,
9281
}, nil
9382
}
9483

9584
// routePatterns defines the route patterns for OAuth protected resource metadata.
9685
var routePatterns = []string{
9786
"", // Root: /.well-known/oauth-protected-resource
9887
"/readonly", // Read-only mode
88+
"/insiders", // Insiders mode
9989
"/x/{toolset}",
10090
"/x/{toolset}/readonly",
10191
}
@@ -105,12 +95,30 @@ func (h *AuthHandler) RegisterRoutes(r chi.Router) {
10595
for _, pattern := range routePatterns {
10696
for _, route := range h.routesForPattern(pattern) {
10797
path := OAuthProtectedResourcePrefix + route
108-
r.Get(path, h.handleProtectedResource)
109-
r.Options(path, h.handleProtectedResource) // CORS support
98+
99+
// Build metadata for this specific resource path
100+
metadata := h.buildMetadata(route)
101+
r.Handle(path, auth.ProtectedResourceMetadataHandler(metadata))
110102
}
111103
}
112104
}
113105

106+
func (h *AuthHandler) buildMetadata(resourcePath string) *oauthex.ProtectedResourceMetadata {
107+
baseURL := strings.TrimSuffix(h.cfg.BaseURL, "/")
108+
resourceURL := baseURL
109+
if resourcePath != "" && resourcePath != "/" {
110+
resourceURL = baseURL + resourcePath
111+
}
112+
113+
return &oauthex.ProtectedResourceMetadata{
114+
Resource: resourceURL,
115+
AuthorizationServers: []string{h.cfg.AuthorizationServer},
116+
ResourceName: "GitHub MCP Server",
117+
ScopesSupported: SupportedScopes,
118+
BearerMethodsSupported: []string{"header"},
119+
}
120+
}
121+
114122
// routesForPattern generates route variants for a given pattern.
115123
// GitHub strips the /mcp prefix before forwarding, so we register both variants:
116124
// - With /mcp prefix: for direct access or when GitHub doesn't strip
@@ -124,37 +132,6 @@ func (h *AuthHandler) routesForPattern(pattern string) []string {
124132
}
125133
}
126134

127-
// handleProtectedResource handles requests for OAuth protected resource metadata.
128-
func (h *AuthHandler) handleProtectedResource(w http.ResponseWriter, r *http.Request) {
129-
// Extract the resource path from the URL
130-
resourcePath := strings.TrimPrefix(r.URL.Path, OAuthProtectedResourcePrefix)
131-
if resourcePath == "" || resourcePath == "/" {
132-
resourcePath = "/"
133-
} else {
134-
resourcePath = strings.TrimPrefix(resourcePath, "/")
135-
}
136-
137-
data, err := h.GetProtectedResourceData(r, html.EscapeString(resourcePath))
138-
if err != nil {
139-
http.Error(w, err.Error(), http.StatusBadRequest)
140-
return
141-
}
142-
143-
var buf bytes.Buffer
144-
if err := h.protectedResourceTemplate.Execute(&buf, data); err != nil {
145-
http.Error(w, "Internal server error", http.StatusInternalServerError)
146-
return
147-
}
148-
149-
// Set CORS headers
150-
w.Header().Set("Access-Control-Allow-Origin", "*")
151-
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
152-
w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
153-
w.Header().Set("Content-Type", "application/json")
154-
w.WriteHeader(http.StatusOK)
155-
_, _ = w.Write(buf.Bytes())
156-
}
157-
158135
// GetEffectiveResourcePath returns the resource path for OAuth protected resource URLs.
159136
// It checks for the X-GitHub-Original-Path header set by GitHub, which contains
160137
// the exact path the client requested before the /mcp prefix was stripped.

pkg/http/oauth/oauth_test.go

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ func TestNewAuthHandler(t *testing.T) {
6262
require.NotNil(t, handler)
6363

6464
assert.Equal(t, tc.expectedAuthServer, handler.cfg.AuthorizationServer)
65-
assert.NotNil(t, handler.protectedResourceTemplate)
6665
})
6766
}
6867
}
@@ -444,8 +443,10 @@ func TestHandleProtectedResource(t *testing.T) {
444443
validateResponse func(t *testing.T, body map[string]any)
445444
}{
446445
{
447-
name: "GET request returns protected resource metadata",
448-
cfg: &Config{},
446+
name: "GET request returns protected resource metadata",
447+
cfg: &Config{
448+
BaseURL: "https://api.example.com",
449+
},
449450
path: OAuthProtectedResourcePrefix,
450451
host: "api.example.com",
451452
method: http.MethodGet,
@@ -454,7 +455,7 @@ func TestHandleProtectedResource(t *testing.T) {
454455
validateResponse: func(t *testing.T, body map[string]any) {
455456
t.Helper()
456457
assert.Equal(t, "GitHub MCP Server", body["resource_name"])
457-
assert.Contains(t, body["resource"], "api.example.com")
458+
assert.Equal(t, "https://api.example.com", body["resource"])
458459

459460
authServers, ok := body["authorization_servers"].([]any)
460461
require.True(t, ok)
@@ -463,40 +464,47 @@ func TestHandleProtectedResource(t *testing.T) {
463464
},
464465
},
465466
{
466-
name: "OPTIONS request for CORS",
467-
cfg: &Config{},
467+
name: "OPTIONS request for CORS preflight",
468+
cfg: &Config{
469+
BaseURL: "https://api.example.com",
470+
},
468471
path: OAuthProtectedResourcePrefix,
469472
host: "api.example.com",
470473
method: http.MethodOptions,
471-
expectedStatusCode: http.StatusOK,
474+
expectedStatusCode: http.StatusNoContent,
472475
},
473476
{
474-
name: "path with /mcp suffix",
475-
cfg: &Config{},
477+
name: "path with /mcp suffix",
478+
cfg: &Config{
479+
BaseURL: "https://api.example.com",
480+
},
476481
path: OAuthProtectedResourcePrefix + "/mcp",
477482
host: "api.example.com",
478483
method: http.MethodGet,
479484
expectedStatusCode: http.StatusOK,
480485
validateResponse: func(t *testing.T, body map[string]any) {
481486
t.Helper()
482-
assert.Contains(t, body["resource"], "/mcp")
487+
assert.Equal(t, "https://api.example.com/mcp", body["resource"])
483488
},
484489
},
485490
{
486-
name: "path with /readonly suffix",
487-
cfg: &Config{},
491+
name: "path with /readonly suffix",
492+
cfg: &Config{
493+
BaseURL: "https://api.example.com",
494+
},
488495
path: OAuthProtectedResourcePrefix + "/readonly",
489496
host: "api.example.com",
490497
method: http.MethodGet,
491498
expectedStatusCode: http.StatusOK,
492499
validateResponse: func(t *testing.T, body map[string]any) {
493500
t.Helper()
494-
assert.Contains(t, body["resource"], "/readonly")
501+
assert.Equal(t, "https://api.example.com/readonly", body["resource"])
495502
},
496503
},
497504
{
498505
name: "custom authorization server in response",
499506
cfg: &Config{
507+
BaseURL: "https://api.example.com",
500508
AuthorizationServer: "https://custom.auth.example.com/oauth",
501509
},
502510
path: OAuthProtectedResourcePrefix,
@@ -559,7 +567,9 @@ func TestHandleProtectedResource(t *testing.T) {
559567
func TestRegisterRoutes(t *testing.T) {
560568
t.Parallel()
561569

562-
handler, err := NewAuthHandler(&Config{})
570+
handler, err := NewAuthHandler(&Config{
571+
BaseURL: "https://api.example.com",
572+
})
563573
require.NoError(t, err)
564574

565575
router := chi.NewRouter()
@@ -588,12 +598,12 @@ func TestRegisterRoutes(t *testing.T) {
588598
router.ServeHTTP(rec, req)
589599
assert.Equal(t, http.StatusOK, rec.Code, "GET %s should return 200", route)
590600

591-
// Test OPTIONS (CORS)
601+
// Test OPTIONS (CORS preflight)
592602
req = httptest.NewRequest(http.MethodOptions, route, nil)
593603
req.Host = "api.example.com"
594604
rec = httptest.NewRecorder()
595605
router.ServeHTTP(rec, req)
596-
assert.Equal(t, http.StatusOK, rec.Code, "OPTIONS %s should return 200", route)
606+
assert.Equal(t, http.StatusNoContent, rec.Code, "OPTIONS %s should return 204", route)
597607
})
598608
}
599609
}
@@ -623,7 +633,9 @@ func TestSupportedScopes(t *testing.T) {
623633
func TestProtectedResourceResponseFormat(t *testing.T) {
624634
t.Parallel()
625635

626-
handler, err := NewAuthHandler(&Config{})
636+
handler, err := NewAuthHandler(&Config{
637+
BaseURL: "https://api.example.com",
638+
})
627639
require.NoError(t, err)
628640

629641
router := chi.NewRouter()

pkg/http/oauth/protected_resource.json.tmpl

Lines changed: 0 additions & 20 deletions
This file was deleted.

0 commit comments

Comments
 (0)