Skip to content

Commit 4943c8c

Browse files
authored
Merge pull request #2 from obot-platform/fix-oauth-refresh
Fix: properly store renewed access token into grant table
2 parents 363c91e + ae8ac67 commit 4943c8c

File tree

3 files changed

+141
-4
lines changed

3 files changed

+141
-4
lines changed

database/database.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,52 @@ func (d *Database) StoreGrant(grant *Grant) error {
499499
return err
500500
}
501501

502+
// UpdateGrant updates an existing grant's properties
503+
func (d *Database) UpdateGrant(grant *Grant) error {
504+
var query string
505+
if d.dbType == "postgres" {
506+
query = `
507+
UPDATE grants
508+
SET scope = $1, metadata = $2, props = $3, expires_at = $4
509+
WHERE id = $5 AND user_id = $6
510+
`
511+
} else {
512+
query = `
513+
UPDATE grants
514+
SET scope = ?, metadata = ?, props = ?, expires_at = ?
515+
WHERE id = ? AND user_id = ?
516+
`
517+
}
518+
519+
scope, _ := json.Marshal(grant.Scope)
520+
metadata, _ := json.Marshal(grant.Metadata)
521+
props, _ := json.Marshal(grant.Props)
522+
523+
result, err := d.db.Exec(query,
524+
scope,
525+
metadata,
526+
props,
527+
grant.ExpiresAt,
528+
grant.ID,
529+
grant.UserID,
530+
)
531+
if err != nil {
532+
return err
533+
}
534+
535+
// Check if any rows were affected
536+
rowsAffected, err := result.RowsAffected()
537+
if err != nil {
538+
return err
539+
}
540+
541+
if rowsAffected == 0 {
542+
return fmt.Errorf("grant not found: id=%s, user_id=%s", grant.ID, grant.UserID)
543+
}
544+
545+
return nil
546+
}
547+
502548
// GetGrant retrieves a grant by ID and user ID
503549
func (d *Database) GetGrant(grantID, userID string) (*Grant, error) {
504550
var query string

main.go

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,85 @@ func (p *OAuthProxy) decryptPropsIfNeeded(props map[string]interface{}) (map[str
327327
return result, nil
328328
}
329329

330+
// updateGrant updates a grant with new token information
331+
func (p *OAuthProxy) updateGrant(grantID, userID string, oldTokenInfo *tokens.TokenInfo, newTokenInfo *providers.TokenInfo) error {
332+
// Get the existing grant
333+
grant, err := p.db.GetGrant(grantID, userID)
334+
if err != nil {
335+
return fmt.Errorf("failed to get grant: %w", err)
336+
}
337+
338+
// Prepare sensitive props data
339+
sensitiveProps := map[string]interface{}{
340+
"access_token": newTokenInfo.AccessToken,
341+
"refresh_token": newTokenInfo.RefreshToken,
342+
"expires_at": newTokenInfo.ExpireAt,
343+
}
344+
345+
// Add existing user info if available
346+
if grant.Props != nil {
347+
if email, ok := grant.Props["email"].(string); ok {
348+
sensitiveProps["email"] = email
349+
}
350+
if name, ok := grant.Props["name"].(string); ok {
351+
sensitiveProps["name"] = name
352+
}
353+
if userID, ok := grant.Props["user_id"].(string); ok {
354+
sensitiveProps["user_id"] = userID
355+
}
356+
}
357+
358+
// use old refresh token in case new one is not provided
359+
if sensitiveProps["refresh_token"] == "" {
360+
sensitiveProps["refresh_token"] = oldTokenInfo.Props["refresh_token"]
361+
}
362+
363+
// Initialize props map
364+
props := make(map[string]interface{})
365+
366+
// Check if encryption is enabled
367+
if p.encryptionKey != "" {
368+
// Decode the encryption key from base64
369+
encryptionKey, err := base64.StdEncoding.DecodeString(p.encryptionKey)
370+
if err != nil {
371+
return fmt.Errorf("failed to decode encryption key: %w", err)
372+
}
373+
374+
// Validate key length (must be 32 bytes for AES-256)
375+
if len(encryptionKey) != 32 {
376+
return fmt.Errorf("invalid encryption key length: %d bytes (expected 32)", len(encryptionKey))
377+
}
378+
379+
// Encrypt the sensitive props data
380+
encryptedProps, err := encryptData(sensitiveProps, encryptionKey)
381+
if err != nil {
382+
return fmt.Errorf("failed to encrypt props data: %w", err)
383+
}
384+
385+
// Store encrypted data
386+
props["encrypted_data"] = encryptedProps.Data
387+
props["iv"] = encryptedProps.IV
388+
props["algorithm"] = encryptedProps.Algorithm
389+
props["encrypted"] = true
390+
} else {
391+
// Store data in plain text if no encryption key is provided
392+
for key, value := range sensitiveProps {
393+
props[key] = value
394+
}
395+
props["encrypted"] = false
396+
}
397+
398+
// Update the grant with new props
399+
grant.Props = props
400+
401+
// Update the grant in the database
402+
if err := p.db.UpdateGrant(grant); err != nil {
403+
return fmt.Errorf("failed to update grant: %w", err)
404+
}
405+
406+
return nil
407+
}
408+
330409
// databaseAdapter adapts the database to the tokens.Database interface
331410
type databaseAdapter struct {
332411
db *database.Database
@@ -984,12 +1063,19 @@ func (p *OAuthProxy) mcpProxyHandler(c *gin.Context) {
9841063
return
9851064
}
9861065

987-
// Update the token info with the new access token
988-
tokenInfo.Props["access_token"] = newTokenInfo.AccessToken
989-
if newTokenInfo.RefreshToken != "" {
990-
tokenInfo.Props["refresh_token"] = newTokenInfo.RefreshToken
1066+
// Update the grant with new token information
1067+
if err := p.updateGrant(tokenInfo.GrantID, tokenInfo.UserID, tokenInfo, newTokenInfo); err != nil {
1068+
log.Printf("Failed to update grant: %v", err)
1069+
c.JSON(http.StatusInternalServerError, gin.H{
1070+
"error": "server_error",
1071+
"error_description": "Failed to update grant with new token",
1072+
})
1073+
return
9911074
}
9921075

1076+
// Update the token info with the new access token for the current request
1077+
tokenInfo.Props["access_token"] = newTokenInfo.AccessToken
1078+
9931079
log.Printf("Successfully refreshed access token")
9941080
}
9951081
}
@@ -1042,6 +1128,7 @@ func (p *OAuthProxy) mcpProxyHandler(c *gin.Context) {
10421128
},
10431129
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
10441130
log.Printf("Proxy error: %v", err)
1131+
c.Abort()
10451132
},
10461133
}
10471134

tokens/jwt.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ type Database interface {
2323

2424
type TokenClaims struct {
2525
UserID string `json:"user_id"`
26+
GrantID string `json:"grant_id"`
2627
Props map[string]interface{} `json:"props,omitempty"`
2728
ExpiresAt time.Time `json:"expires_at"`
2829
}
@@ -95,6 +96,7 @@ func (tm *TokenManager) ValidateAccessToken(tokenString string) (*TokenClaims, e
9596
// Create TokenClaims with the grant's props
9697
claims := &TokenClaims{
9798
UserID: userID,
99+
GrantID: grantID,
98100
Props: grant.Props,
99101
ExpiresAt: tokenData.ExpiresAt,
100102
}
@@ -111,6 +113,7 @@ func (tm *TokenManager) GetTokenInfo(tokenString string) (*TokenInfo, error) {
111113

112114
return &TokenInfo{
113115
UserID: claims.UserID,
116+
GrantID: claims.GrantID,
114117
Props: claims.Props,
115118
ExpiresAt: claims.ExpiresAt,
116119
}, nil
@@ -119,6 +122,7 @@ func (tm *TokenManager) GetTokenInfo(tokenString string) (*TokenInfo, error) {
119122
// TokenInfo represents token information
120123
type TokenInfo struct {
121124
UserID string
125+
GrantID string
122126
Props map[string]interface{}
123127
ExpiresAt time.Time
124128
}

0 commit comments

Comments
 (0)