Skip to content

Commit f3419dd

Browse files
authored
Merge pull request #163 from thegrumpylion/pkce
implement PKCE for AuthorizationCode grant
2 parents 07c72de + 7b9faad commit f3419dd

File tree

12 files changed

+382
-79
lines changed

12 files changed

+382
-79
lines changed

const.go

+40
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
package oauth2
22

3+
import (
4+
"crypto/sha256"
5+
"encoding/base64"
6+
"strings"
7+
)
8+
39
// ResponseType the type of authorization request
410
type ResponseType string
511

@@ -34,3 +40,37 @@ func (gt GrantType) String() string {
3440
}
3541
return ""
3642
}
43+
44+
// CodeChallengeMethod PCKE method
45+
type CodeChallengeMethod string
46+
47+
const (
48+
// CodeChallengePlain PCKE Method
49+
CodeChallengePlain CodeChallengeMethod = "plain"
50+
// CodeChallengeS256 PCKE Method
51+
CodeChallengeS256 CodeChallengeMethod = "S256"
52+
)
53+
54+
func (ccm CodeChallengeMethod) String() string {
55+
if ccm == CodeChallengePlain ||
56+
ccm == CodeChallengeS256 {
57+
return string(ccm)
58+
}
59+
return ""
60+
}
61+
62+
// Validate code challenge
63+
func (ccm CodeChallengeMethod) Validate(cc, ver string) bool {
64+
switch ccm {
65+
case CodeChallengePlain:
66+
return cc == ver
67+
case CodeChallengeS256:
68+
s256 := sha256.Sum256([]byte(ver))
69+
// trim padding
70+
a := strings.TrimRight(base64.URLEncoding.EncodeToString(s256[:]), "=")
71+
b := strings.TrimRight(cc, "=")
72+
return a == b
73+
default:
74+
return false
75+
}
76+
}

const_test.go

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package oauth2_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/go-oauth2/oauth2/v4"
7+
)
8+
9+
func TestValidatePlain(t *testing.T) {
10+
cc := oauth2.CodeChallengePlain
11+
if !cc.Validate("plaintest", "plaintest") {
12+
t.Fatal("not valid")
13+
}
14+
}
15+
16+
func TestValidateS256(t *testing.T) {
17+
cc := oauth2.CodeChallengeS256
18+
if !cc.Validate("W6YWc_4yHwYN-cGDgGmOMHF3l7KDy7VcRjf7q2FVF-o=", "s256test") {
19+
t.Fatal("not valid")
20+
}
21+
}
22+
23+
func TestValidateS256NoPadding(t *testing.T) {
24+
cc := oauth2.CodeChallengeS256
25+
if !cc.Validate("W6YWc_4yHwYN-cGDgGmOMHF3l7KDy7VcRjf7q2FVF-o", "s256test") {
26+
t.Fatal("not valid")
27+
}
28+
}

errors/error.go

+3
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,7 @@ var (
1313
ErrInvalidRefreshToken = errors.New("invalid refresh token")
1414
ErrExpiredAccessToken = errors.New("expired access token")
1515
ErrExpiredRefreshToken = errors.New("expired refresh token")
16+
ErrMissingCodeVerifier = errors.New("missing code verifier")
17+
ErrMissingCodeChallenge = errors.New("missing code challenge")
18+
ErrInvalidCodeChallenge = errors.New("invalid code challenge")
1619
)

errors/response.go

+39-30
Original file line numberDiff line numberDiff line change
@@ -34,42 +34,51 @@ func (r *Response) SetHeader(key, value string) {
3434

3535
// https://tools.ietf.org/html/rfc6749#section-5.2
3636
var (
37-
ErrInvalidRequest = errors.New("invalid_request")
38-
ErrUnauthorizedClient = errors.New("unauthorized_client")
39-
ErrAccessDenied = errors.New("access_denied")
40-
ErrUnsupportedResponseType = errors.New("unsupported_response_type")
41-
ErrInvalidScope = errors.New("invalid_scope")
42-
ErrServerError = errors.New("server_error")
43-
ErrTemporarilyUnavailable = errors.New("temporarily_unavailable")
44-
ErrInvalidClient = errors.New("invalid_client")
45-
ErrInvalidGrant = errors.New("invalid_grant")
46-
ErrUnsupportedGrantType = errors.New("unsupported_grant_type")
37+
ErrInvalidRequest = errors.New("invalid_request")
38+
ErrUnauthorizedClient = errors.New("unauthorized_client")
39+
ErrAccessDenied = errors.New("access_denied")
40+
ErrUnsupportedResponseType = errors.New("unsupported_response_type")
41+
ErrInvalidScope = errors.New("invalid_scope")
42+
ErrServerError = errors.New("server_error")
43+
ErrTemporarilyUnavailable = errors.New("temporarily_unavailable")
44+
ErrInvalidClient = errors.New("invalid_client")
45+
ErrInvalidGrant = errors.New("invalid_grant")
46+
ErrUnsupportedGrantType = errors.New("unsupported_grant_type")
47+
ErrCodeChallengeRquired = errors.New("invalid_request")
48+
ErrUnsupportedCodeChallengeMethod = errors.New("invalid_request")
49+
ErrInvalidCodeChallengeLen = errors.New("invalid_request")
4750
)
4851

4952
// Descriptions error description
5053
var Descriptions = map[error]string{
51-
ErrInvalidRequest: "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed",
52-
ErrUnauthorizedClient: "The client is not authorized to request an authorization code using this method",
53-
ErrAccessDenied: "The resource owner or authorization server denied the request",
54-
ErrUnsupportedResponseType: "The authorization server does not support obtaining an authorization code using this method",
55-
ErrInvalidScope: "The requested scope is invalid, unknown, or malformed",
56-
ErrServerError: "The authorization server encountered an unexpected condition that prevented it from fulfilling the request",
57-
ErrTemporarilyUnavailable: "The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server",
58-
ErrInvalidClient: "Client authentication failed",
59-
ErrInvalidGrant: "The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client",
60-
ErrUnsupportedGrantType: "The authorization grant type is not supported by the authorization server",
54+
ErrInvalidRequest: "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed",
55+
ErrUnauthorizedClient: "The client is not authorized to request an authorization code using this method",
56+
ErrAccessDenied: "The resource owner or authorization server denied the request",
57+
ErrUnsupportedResponseType: "The authorization server does not support obtaining an authorization code using this method",
58+
ErrInvalidScope: "The requested scope is invalid, unknown, or malformed",
59+
ErrServerError: "The authorization server encountered an unexpected condition that prevented it from fulfilling the request",
60+
ErrTemporarilyUnavailable: "The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server",
61+
ErrInvalidClient: "Client authentication failed",
62+
ErrInvalidGrant: "The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client",
63+
ErrUnsupportedGrantType: "The authorization grant type is not supported by the authorization server",
64+
ErrCodeChallengeRquired: "PKCE is required. code_challenge is missing",
65+
ErrUnsupportedCodeChallengeMethod: "Selected code_challenge_method not supported",
66+
ErrInvalidCodeChallengeLen: "Code challenge length must be between 43 and 128 charachters long",
6167
}
6268

6369
// StatusCodes response error HTTP status code
6470
var StatusCodes = map[error]int{
65-
ErrInvalidRequest: 400,
66-
ErrUnauthorizedClient: 401,
67-
ErrAccessDenied: 403,
68-
ErrUnsupportedResponseType: 401,
69-
ErrInvalidScope: 400,
70-
ErrServerError: 500,
71-
ErrTemporarilyUnavailable: 503,
72-
ErrInvalidClient: 401,
73-
ErrInvalidGrant: 401,
74-
ErrUnsupportedGrantType: 401,
71+
ErrInvalidRequest: 400,
72+
ErrUnauthorizedClient: 401,
73+
ErrAccessDenied: 403,
74+
ErrUnsupportedResponseType: 401,
75+
ErrInvalidScope: 400,
76+
ErrServerError: 500,
77+
ErrTemporarilyUnavailable: 503,
78+
ErrInvalidClient: 401,
79+
ErrInvalidGrant: 401,
80+
ErrUnsupportedGrantType: 401,
81+
ErrCodeChallengeRquired: 400,
82+
ErrUnsupportedCodeChallengeMethod: 400,
83+
ErrInvalidCodeChallengeLen: 400,
7584
}

example/client/client.go

+11-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package main
22

33
import (
44
"context"
5+
"crypto/sha256"
6+
"encoding/base64"
57
"encoding/json"
68
"fmt"
79
"io"
@@ -33,7 +35,9 @@ var (
3335

3436
func main() {
3537
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
36-
u := config.AuthCodeURL("xyz")
38+
u := config.AuthCodeURL("xyz",
39+
oauth2.SetAuthURLParam("code_challenge", genCodeChallengeS256("s256example")),
40+
oauth2.SetAuthURLParam("code_challenge_method", "S256"))
3741
http.Redirect(w, r, u, http.StatusFound)
3842
})
3943

@@ -49,7 +53,7 @@ func main() {
4953
http.Error(w, "Code not found", http.StatusBadRequest)
5054
return
5155
}
52-
token, err := config.Exchange(context.Background(), code)
56+
token, err := config.Exchange(context.Background(), code, oauth2.SetAuthURLParam("code_verifier", "s256example"))
5357
if err != nil {
5458
http.Error(w, err.Error(), http.StatusInternalServerError)
5559
return
@@ -130,3 +134,8 @@ func main() {
130134
log.Println("Client is running at 9094 port.Please open http://localhost:9094")
131135
log.Fatal(http.ListenAndServe(":9094", nil))
132136
}
137+
138+
func genCodeChallengeS256(s string) string {
139+
s256 := sha256.Sum256([]byte(s))
140+
return base64.URLEncoding.EncodeToString(s256[:])
141+
}

manage.go

+12-9
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,18 @@ import (
88

99
// TokenGenerateRequest provide to generate the token request parameters
1010
type TokenGenerateRequest struct {
11-
ClientID string
12-
ClientSecret string
13-
UserID string
14-
RedirectURI string
15-
Scope string
16-
Code string
17-
Refresh string
18-
AccessTokenExp time.Duration
19-
Request *http.Request
11+
ClientID string
12+
ClientSecret string
13+
UserID string
14+
RedirectURI string
15+
Scope string
16+
Code string
17+
CodeChallenge string
18+
CodeChallengeMethod CodeChallengeMethod
19+
Refresh string
20+
CodeVerifier string
21+
AccessTokenExp time.Duration
22+
Request *http.Request
2023
}
2124

2225
// Manager authorization management interface

manage/manager.go

+29
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType,
176176
if exp := tgr.AccessTokenExp; exp > 0 {
177177
ti.SetAccessExpiresIn(exp)
178178
}
179+
if tgr.CodeChallenge != "" {
180+
ti.SetCodeChallenge(tgr.CodeChallenge)
181+
ti.SetCodeChallengeMethod(tgr.CodeChallengeMethod)
182+
}
179183

180184
tv, err := m.authorizeGenerate.Token(ctx, td)
181185
if err != nil {
@@ -251,6 +255,28 @@ func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.To
251255
return ti, nil
252256
}
253257

258+
func (m *Manager) validateCodeChallenge(ti oauth2.TokenInfo, ver string) error {
259+
cc := ti.GetCodeChallenge()
260+
// early return
261+
if cc == "" && ver == "" {
262+
return nil
263+
}
264+
if cc == "" {
265+
return errors.ErrMissingCodeVerifier
266+
}
267+
if ver == "" {
268+
return errors.ErrMissingCodeVerifier
269+
}
270+
ccm := ti.GetCodeChallengeMethod()
271+
if ccm.String() == "" {
272+
ccm = oauth2.CodeChallengePlain
273+
}
274+
if !ccm.Validate(cc, ver) {
275+
return errors.ErrInvalidCodeChallenge
276+
}
277+
return nil
278+
}
279+
254280
// GenerateAccessToken generate the access token
255281
func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
256282
cli, err := m.GetClient(ctx, tgr.ClientID)
@@ -275,6 +301,9 @@ func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType,
275301
if err != nil {
276302
return nil, err
277303
}
304+
if err := m.validateCodeChallenge(ti, tgr.CodeVerifier); err != nil {
305+
return nil, err
306+
}
278307
tgr.UserID = ti.GetUserID()
279308
tgr.Scope = ti.GetScope()
280309
if exp := ti.GetAccessExpiresIn(); exp > 0 {

model.go

+4
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ type (
3737
SetCodeCreateAt(time.Time)
3838
GetCodeExpiresIn() time.Duration
3939
SetCodeExpiresIn(time.Duration)
40+
GetCodeChallenge() string
41+
SetCodeChallenge(string)
42+
GetCodeChallengeMethod() CodeChallengeMethod
43+
SetCodeChallengeMethod(CodeChallengeMethod)
4044

4145
GetAccess() string
4246
SetAccess(string)

models/token.go

+35-13
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,21 @@ func NewToken() *Token {
1313

1414
// Token token model
1515
type Token struct {
16-
ClientID string `bson:"ClientID"`
17-
UserID string `bson:"UserID"`
18-
RedirectURI string `bson:"RedirectURI"`
19-
Scope string `bson:"Scope"`
20-
Code string `bson:"Code"`
21-
CodeCreateAt time.Time `bson:"CodeCreateAt"`
22-
CodeExpiresIn time.Duration `bson:"CodeExpiresIn"`
23-
Access string `bson:"Access"`
24-
AccessCreateAt time.Time `bson:"AccessCreateAt"`
25-
AccessExpiresIn time.Duration `bson:"AccessExpiresIn"`
26-
Refresh string `bson:"Refresh"`
27-
RefreshCreateAt time.Time `bson:"RefreshCreateAt"`
28-
RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"`
16+
ClientID string `bson:"ClientID"`
17+
UserID string `bson:"UserID"`
18+
RedirectURI string `bson:"RedirectURI"`
19+
Scope string `bson:"Scope"`
20+
Code string `bson:"Code"`
21+
CodeChallenge string `bson:"CodeChallenge"`
22+
CodeChallengeMethod string `bson:"CodeChallengeMethod"`
23+
CodeCreateAt time.Time `bson:"CodeCreateAt"`
24+
CodeExpiresIn time.Duration `bson:"CodeExpiresIn"`
25+
Access string `bson:"Access"`
26+
AccessCreateAt time.Time `bson:"AccessCreateAt"`
27+
AccessExpiresIn time.Duration `bson:"AccessExpiresIn"`
28+
Refresh string `bson:"Refresh"`
29+
RefreshCreateAt time.Time `bson:"RefreshCreateAt"`
30+
RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"`
2931
}
3032

3133
// New create to token model instance
@@ -103,6 +105,26 @@ func (t *Token) SetCodeExpiresIn(exp time.Duration) {
103105
t.CodeExpiresIn = exp
104106
}
105107

108+
// GetCodeChallenge challenge code
109+
func (t *Token) GetCodeChallenge() string {
110+
return t.CodeChallenge
111+
}
112+
113+
// SetCodeChallenge challenge code
114+
func (t *Token) SetCodeChallenge(code string) {
115+
t.CodeChallenge = code
116+
}
117+
118+
// GetCodeChallengeMethod challenge method
119+
func (t *Token) GetCodeChallengeMethod() oauth2.CodeChallengeMethod {
120+
return oauth2.CodeChallengeMethod(t.CodeChallengeMethod)
121+
}
122+
123+
// SetCodeChallengeMethod challenge method
124+
func (t *Token) SetCodeChallengeMethod(method oauth2.CodeChallengeMethod) {
125+
t.CodeChallengeMethod = string(method)
126+
}
127+
106128
// GetAccess access Token
107129
func (t *Token) GetAccess() string {
108130
return t.Access

0 commit comments

Comments
 (0)