Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ import (
)

func main() {
proxy, err := proxy.NewOAuthProxy()
// Load configuration from environment variables
config := proxy.LoadConfigFromEnv()

proxy, err := proxy.NewOAuthProxy(config)
if err != nil {
log.Fatalf("Failed to create OAuth proxy: %v", err)
}
Expand Down
9 changes: 6 additions & 3 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ func TestIntegrationFlow(t *testing.T) {
}()

// Create OAuth proxy
oauthProxy, err := proxy.NewOAuthProxy()
config := proxy.LoadConfigFromEnv()
oauthProxy, err := proxy.NewOAuthProxy(config)
if err != nil {
t.Skipf("Skipping test due to database connection error: %v", err)
}
Expand Down Expand Up @@ -164,7 +165,8 @@ func TestOAuthProxyCreation(t *testing.T) {
}()

// Create OAuth proxy
oauthProxy, err := proxy.NewOAuthProxy()
config := proxy.LoadConfigFromEnv()
oauthProxy, err := proxy.NewOAuthProxy(config)
require.NoError(t, err, "Should be able to create OAuth proxy with valid environment")
require.NotNil(t, oauthProxy, "OAuth proxy should not be nil")

Expand Down Expand Up @@ -213,7 +215,8 @@ func TestOAuthProxyStart(t *testing.T) {
}()

// Create OAuth proxy
oauthProxy, err := proxy.NewOAuthProxy()
config := proxy.LoadConfigFromEnv()
oauthProxy, err := proxy.NewOAuthProxy(config)
require.NoError(t, err)
defer func() {
_ = oauthProxy.Close()
Expand Down
45 changes: 45 additions & 0 deletions pkg/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ func (d *Store) setupSchema() error {
&types.Grant{},
&types.AuthorizationCode{},
&types.TokenData{},
&types.StoredAuthRequest{},
)
if err != nil {
return fmt.Errorf("failed to auto-migrate database schema: %w", err)
Expand Down Expand Up @@ -334,6 +335,50 @@ func (d *Store) CleanupExpiredTokens() error {
fmt.Printf("Deleted %d expired grants\n", result.RowsAffected)
}

// Delete expired auth requests
if err := d.CleanupExpiredAuthRequests(); err != nil {
return fmt.Errorf("failed to cleanup expired auth requests: %w", err)
}

return nil
}

// StoreAuthRequest stores an authorization request with a 15-minute TTL
func (d *Store) StoreAuthRequest(key string, data map[string]interface{}) error {
authRequest := &types.StoredAuthRequest{
Key: key,
Data: types.JSON(data),
ExpiresAt: time.Now().Add(15 * time.Minute), // 15-minute TTL
}
return d.db.Create(authRequest).Error
}

// GetAuthRequest retrieves an authorization request by key and checks TTL
func (d *Store) GetAuthRequest(key string) (map[string]interface{}, error) {
var authRequest types.StoredAuthRequest
err := d.db.First(&authRequest, "key = ? AND expires_at > ?", key, time.Now()).Error
if err != nil {
return nil, err
}

// Convert JSON back to map
return map[string]interface{}(authRequest.Data), nil
}

// DeleteAuthRequest deletes an authorization request by key
func (d *Store) DeleteAuthRequest(key string) error {
return d.db.Delete(&types.StoredAuthRequest{}, "key = ?", key).Error
}

// CleanupExpiredAuthRequests removes expired authorization requests
func (d *Store) CleanupExpiredAuthRequests() error {
result := d.db.Where("expires_at < ?", time.Now()).Delete(&types.StoredAuthRequest{})
if result.Error != nil {
return fmt.Errorf("failed to cleanup expired auth requests: %w", result.Error)
}
if result.RowsAffected > 0 {
fmt.Printf("Deleted %d expired auth requests\n", result.RowsAffected)
}
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/encryption/random.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ func GenerateRandomString(length int) string {
if _, err := rand.Read(bytes); err != nil {
panic(fmt.Errorf("failed to generate random string: %w", err))
}
return base64.URLEncoding.EncodeToString(bytes)
return base64.RawStdEncoding.EncodeToString(bytes)
}
41 changes: 26 additions & 15 deletions pkg/oauth/authorize/authorize.go
Original file line number Diff line number Diff line change
@@ -1,34 +1,37 @@
package authorize

import (
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"strings"

"github.com/obot-platform/mcp-oauth-proxy/pkg/encryption"
"github.com/obot-platform/mcp-oauth-proxy/pkg/handlerutils"
"github.com/obot-platform/mcp-oauth-proxy/pkg/providers"
"github.com/obot-platform/mcp-oauth-proxy/pkg/types"
)

type AuthorizationStore interface {
GetClient(clientID string) (*types.ClientInfo, error)
StoreAuthRequest(key string, data map[string]interface{}) error
}

type Handler struct {
db AuthorizationStore
provider providers.Provider
scopesSupported []string
clientID string
clientSecret string
}

func NewHandler(db AuthorizationStore, provider providers.Provider, scopesSupported []string) http.Handler {
func NewHandler(db AuthorizationStore, provider providers.Provider, scopesSupported []string, clientID, clientSecret string) http.Handler {
return &Handler{
db: db,
provider: provider,
scopesSupported: scopesSupported,
clientID: clientID,
clientSecret: clientSecret,
}
}

Expand Down Expand Up @@ -107,37 +110,45 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

// Get the provider's client ID and secret
clientID := os.Getenv("OAUTH_CLIENT_ID")
clientSecret := os.Getenv("OAUTH_CLIENT_SECRET")

// Check if provider is configured
if clientID == "" || clientSecret == "" {
if p.clientID == "" || p.clientSecret == "" {
handlerutils.JSON(w, http.StatusBadRequest, types.OAuthError{
Error: "invalid_request",
ErrorDescription: "OAuth provider not configured",
})
return
}

stateData, err := json.Marshal(authReq)
if err != nil {
// Generate a random state key
stateKey := encryption.GenerateRandomString(32)

// Store the auth request data in the database
authData := map[string]interface{}{
"response_type": authReq.ResponseType,
"client_id": authReq.ClientID,
"redirect_uri": authReq.RedirectURI,
"scope": authReq.Scope,
"state": authReq.State,
"code_challenge": authReq.CodeChallenge,
"code_challenge_method": authReq.CodeChallengeMethod,
}

if err := p.db.StoreAuthRequest(stateKey, authData); err != nil {
handlerutils.JSON(w, http.StatusInternalServerError, types.OAuthError{
Error: "server_error",
ErrorDescription: "Failed to marshal state data",
ErrorDescription: "Failed to store authorization request",
})
return
}

encodedState := base64.URLEncoding.EncodeToString(stateData)
redirectURI := fmt.Sprintf("%s/callback", handlerutils.GetBaseURL(r))

// Generate authorization URL with the provider
authURL := p.provider.GetAuthorizationURL(
clientID,
p.clientID,
redirectURI,
authReq.Scope,
encodedState,
stateKey,
)

// Redirect to the provider's authorization URL
Expand Down
103 changes: 71 additions & 32 deletions pkg/oauth/callback/callback.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
package callback

import (
"encoding/base64"
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"os"
"strings"
"time"

Expand All @@ -20,22 +17,48 @@ import (
type Store interface {
StoreGrant(grant *types.Grant) error
StoreAuthCode(code, grantID, userID string) error
GetAuthRequest(key string) (map[string]interface{}, error)
DeleteAuthRequest(key string) error
}

type Handler struct {
db Store
provider providers.Provider
encryptionKey []byte
clientID string
clientSecret string
}

func NewHandler(db Store, provider providers.Provider, encryptionKey []byte) http.Handler {
func NewHandler(db Store, provider providers.Provider, encryptionKey []byte, clientID, clientSecret string) http.Handler {
return &Handler{
db: db,
provider: provider,
encryptionKey: encryptionKey,
clientID: clientID,
clientSecret: clientSecret,
}
}

// scopeContainsProfileOrEmail checks if the given scopes contain profile or email
func (p *Handler) scopeContainsProfileOrEmail(scopes []string) bool {
for _, scope := range scopes {
if scope == "profile" || scope == "email" {
return true
}
}
return false
}

// getStringFromMap safely extracts a string value from a map[string]interface{}
func getStringFromMap(data map[string]interface{}, key string) string {
if value, ok := data[key]; ok {
if str, ok := value.(string); ok {
return str
}
}
return ""
}

func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Handle OAuth callback from external providers
code := r.URL.Query().Get("code")
Expand All @@ -61,30 +84,39 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

var authReq types.AuthRequest
stateData, err := base64.URLEncoding.DecodeString(state)
// Retrieve auth request data from database using state as key
authData, err := p.db.GetAuthRequest(state)
if err != nil {
handlerutils.JSON(w, http.StatusBadRequest, types.OAuthError{
Error: "invalid_request",
ErrorDescription: "Invalid state parameter",
ErrorDescription: "Invalid or expired state parameter",
})
return
}
if err := json.Unmarshal(stateData, &authReq); err != nil {
handlerutils.JSON(w, http.StatusBadRequest, types.OAuthError{
Error: "invalid_request",
ErrorDescription: "Invalid state parameter",
})
return

// Convert auth data back to AuthRequest struct
authReq := types.AuthRequest{
ResponseType: getStringFromMap(authData, "response_type"),
ClientID: getStringFromMap(authData, "client_id"),
RedirectURI: getStringFromMap(authData, "redirect_uri"),
Scope: getStringFromMap(authData, "scope"),
State: getStringFromMap(authData, "state"),
CodeChallenge: getStringFromMap(authData, "code_challenge"),
CodeChallengeMethod: getStringFromMap(authData, "code_challenge_method"),
}

// Clean up the auth request data after successful retrieval
defer func() {
if err := p.db.DeleteAuthRequest(state); err != nil {
log.Printf("Failed to delete auth request: %v", err)
}
}()

// Get provider credentials
clientID := os.Getenv("OAUTH_CLIENT_ID")
clientSecret := os.Getenv("OAUTH_CLIENT_SECRET")
redirectURI := fmt.Sprintf("%s/callback", handlerutils.GetBaseURL(r))

// Exchange code for tokens
tokenInfo, err := p.provider.ExchangeCodeForToken(r.Context(), code, clientID, clientSecret, redirectURI)
tokenInfo, err := p.provider.ExchangeCodeForToken(r.Context(), code, p.clientID, p.clientSecret, redirectURI)
if err != nil {
log.Printf("Failed to exchange code for token: %v", err)
handlerutils.JSON(w, http.StatusBadRequest, types.OAuthError{
Expand All @@ -94,15 +126,22 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

// Get user info from the provider
userInfo, err := p.provider.GetUserInfo(r.Context(), tokenInfo.AccessToken)
if err != nil {
log.Printf("Failed to get user info: %v", err)
handlerutils.JSON(w, http.StatusBadRequest, types.OAuthError{
Error: "invalid_grant",
ErrorDescription: "Failed to get user information",
})
return
// Check if scope includes profile or email before getting user info
scopes := strings.Fields(authReq.Scope)
needsUserInfo := p.scopeContainsProfileOrEmail(scopes)

userInfo := &providers.UserInfo{}
if needsUserInfo {
// Get user info from the provider
userInfo, err = p.provider.GetUserInfo(r.Context(), tokenInfo.AccessToken)
if err != nil {
log.Printf("Failed to get user info: %v", err)
handlerutils.JSON(w, http.StatusBadRequest, types.OAuthError{
Error: "invalid_grant",
ErrorDescription: "Failed to get user information",
})
return
}
}

// Create a grant for this user
Expand All @@ -111,13 +150,17 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

// Prepare sensitive props data
sensitiveProps := map[string]interface{}{
"email": userInfo.Email,
"name": userInfo.Name,
"access_token": tokenInfo.AccessToken,
"refresh_token": tokenInfo.RefreshToken,
"expires_at": tokenInfo.ExpireAt,
}

// Only add user info if we have it
if needsUserInfo {
sensitiveProps["email"] = userInfo.Email
sensitiveProps["name"] = userInfo.Name
}

// Initialize props map
props := make(map[string]interface{})

Expand All @@ -138,17 +181,13 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
props["algorithm"] = encryptedProps.Algorithm
props["encrypted"] = true

// Add non-sensitive data
props["user_id"] = userInfo.ID

grant := &types.Grant{
ID: grantID,
ClientID: authReq.ClientID,
UserID: userInfo.ID,
Scope: strings.Fields(authReq.Scope),
Scope: scopes,
Metadata: map[string]interface{}{
"provider": p.provider,
"label": userInfo.Name,
},
Props: props,
CreatedAt: now,
Expand Down
Loading
Loading