Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion internal/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ type ServerOption func(*serverConfig)

// serverConfig holds the server configuration
type serverConfig struct {
middlewares []func(http.Handler) http.Handler
middlewares []func(http.Handler) http.Handler
authInfoHandler http.Handler
}

// WithMiddlewares adds middleware to the server
Expand All @@ -30,6 +31,13 @@ func WithMiddlewares(mw ...func(http.Handler) http.Handler) ServerOption {
}
}

// WithAuthInfoHandler sets the auth info handler to be mounted at /.well-known/oauth-protected-resource
func WithAuthInfoHandler(handler http.Handler) ServerOption {
return func(cfg *serverConfig) {
cfg.authInfoHandler = handler
}
}

// NewServer creates and configures the HTTP router with the given service and options
func NewServer(svc service.RegistryService, opts ...ServerOption) *chi.Mux {
// Initialize configuration with defaults
Expand All @@ -55,6 +63,11 @@ func NewServer(svc service.RegistryService, opts ...ServerOption) *chi.Mux {
// Mount OpenAPI endpoint
r.Get("/openapi.json", openAPIHandler)

// Mount auth info handler at well-known endpoint (if configured)
if cfg.authInfoHandler != nil {
r.Handle("/.well-known/oauth-protected-resource", cfg.authInfoHandler)
}

// Mount MCP Registry API v0 compatible routes
r.Mount("/registry", v01.Router(svc))
r.Mount("/extension/v0", extensionv0.Router(svc))
Expand Down
27 changes: 26 additions & 1 deletion internal/app/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/stacklok/toolhive/pkg/logger"

"github.com/stacklok/toolhive-registry-server/internal/api"
"github.com/stacklok/toolhive-registry-server/internal/auth"
"github.com/stacklok/toolhive-registry-server/internal/config"
"github.com/stacklok/toolhive-registry-server/internal/service"
database "github.com/stacklok/toolhive-registry-server/internal/service/db"
Expand All @@ -36,6 +37,9 @@ const (
defaultIdleTimeout = 60 * time.Second
)

// defaultPublicPaths are paths that never require authentication
var defaultPublicPaths = []string{"/health", "/docs", "/swagger", "/.well-known"}

// RegistryAppOptions is a function that configures the registry app builder
type RegistryAppOptions func(*registryAppConfig) error

Expand Down Expand Up @@ -63,6 +67,10 @@ type registryAppConfig struct {
dataDir string
registryFile string
statusFile string

// Auth components
authMiddleware func(http.Handler) http.Handler
authInfoHandler http.Handler
}

func baseConfig(opts ...RegistryAppOptions) (*registryAppConfig, error) {
Expand Down Expand Up @@ -109,6 +117,15 @@ func NewRegistryApp(
return nil, fmt.Errorf("failed to build service components: %w", err)
}

// Build auth middleware (if not injected)
if cfg.authMiddleware == nil {
var authErr error
cfg.authMiddleware, cfg.authInfoHandler, authErr = auth.NewAuthMiddleware(ctx, cfg.config.Auth, auth.DefaultValidatorFactory)
if authErr != nil {
return nil, fmt.Errorf("failed to build auth middleware: %w", authErr)
}
}

// Build HTTP server
httpServer, err := buildHTTPServer(ctx, cfg, registryService)
if err != nil {
Expand Down Expand Up @@ -405,8 +422,16 @@ func buildHTTPServer(
}
}

// Create auth middleware that bypasses public paths
publicPaths := defaultPublicPaths
if b.config != nil && b.config.Auth != nil && len(b.config.Auth.PublicPaths) > 0 {
publicPaths = append(publicPaths, b.config.Auth.PublicPaths...)
}
authMw := auth.WrapWithPublicPaths(b.authMiddleware, publicPaths)
b.middlewares = append(b.middlewares, authMw)

// Create router with middlewares
router := api.NewServer(svc, api.WithMiddlewares(b.middlewares...))
router := api.NewServer(svc, api.WithMiddlewares(b.middlewares...), api.WithAuthInfoHandler(b.authInfoHandler))

// Create HTTP server
server := &http.Server{
Expand Down
8 changes: 6 additions & 2 deletions internal/app/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,9 @@ func TestBuildHTTPServer(t *testing.T) {
mockSvc := mocks.NewMockRegistryService(ctrl)
tt.setupMock(mockSvc)

// Set auth middleware in config for tests
tt.config.authMiddleware = func(next http.Handler) http.Handler { return next }
tt.config.authInfoHandler = nil
server, err := buildHTTPServer(ctx, tt.config, mockSvc)

require.NoError(t, err)
Expand All @@ -323,11 +326,12 @@ func TestBuildHTTPServer(t *testing.T) {
assert.NotNil(t, server.Handler)

// Verify middlewares were set
// Note: auth middleware is always appended, so counts are +1
if tt.expectDefaults {
assert.NotNil(t, tt.config.middlewares)
assert.Greater(t, len(tt.config.middlewares), 0, "default middlewares should be set")
} else {
assert.Equal(t, 1, len(tt.config.middlewares), "custom middlewares should be preserved")
assert.Equal(t, 2, len(tt.config.middlewares), "custom middlewares should be preserved plus auth middleware")
}
})
}
Expand Down Expand Up @@ -420,7 +424,7 @@ func TestBuildServiceComponents(t *testing.T) {
},
},
{
name: "success with pre-set registryProvider and deploymentProvider",
name: "success with pre-set registryProvider",
config: &registryAppConfig{
config: createValidTestConfig(),
},
Expand Down
21 changes: 21 additions & 0 deletions internal/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,24 @@ func (m *MultiProviderMiddleware) writeError(w http.ResponseWriter, status int,
logger.Errorf("auth: failed to encode error response: %v", err)
}
}

// WrapWithPublicPaths wraps an auth middleware to bypass authentication for public paths.
// It checks each request path against the provided list of public paths using IsPublicPath.
// Requests to public paths are passed directly to the next handler without authentication,
// while all other requests go through the provided auth middleware.
func WrapWithPublicPaths(
authMw func(http.Handler) http.Handler,
publicPaths []string,
) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
// Pre-wrap the handler once during initialization, not per-request
authWrappedNext := authMw(next)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !IsPublicPath(r.URL.Path, publicPaths) {
authWrappedNext.ServeHTTP(w, r)
} else {
next.ServeHTTP(w, r)
}
})
}
}
46 changes: 46 additions & 0 deletions internal/auth/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,49 @@ func TestMultiProviderMiddleware_WWWAuthenticate(t *testing.T) {
})
}
}

func TestWrapWithPublicPaths(t *testing.T) {
t.Parallel()

tests := []struct {
name string
path string
publicPaths []string
expectAuthCall bool
}{
// Public paths bypass auth
{"exact public path bypasses auth", "/health", []string{"/health"}, false},
{"sub-path of public bypasses auth", "/health/check", []string{"/health"}, false},
{"well-known bypasses auth", "/.well-known/oauth", []string{"/.well-known"}, false},

// Protected paths require auth
{"protected path requires auth", "/v0/servers", []string{"/health"}, true},
{"similar prefix still requires auth", "/healthcheck", []string{"/health"}, true},
{"empty public paths requires auth", "/health", []string{}, true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

authCalled := false
mockAuthMw := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authCalled = true
next.ServeHTTP(w, r)
})
}

mw := WrapWithPublicPaths(mockAuthMw, tt.publicPaths)
handler := mw(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))

req, _ := http.NewRequest(http.MethodGet, tt.path, nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

assert.Equal(t, tt.expectAuthCall, authCalled)
})
}
}
Loading