Skip to content

Commit c36e1f8

Browse files
committed
get the refresh token from the request with a customizable handler function
1 parent 4626899 commit c36e1f8

File tree

3 files changed

+91
-4
lines changed

3 files changed

+91
-4
lines changed

server/handler.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ type (
5151

5252
// ResponseTokenHandler response token handling
5353
ResponseTokenHandler func(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error
54+
55+
// Handler to fetch the refresh token from the request
56+
RefreshTokenResolveHandler func(r *http.Request) (string, error)
5457
)
5558

5659
// ClientFormHandler get client data from form
@@ -71,3 +74,21 @@ func ClientBasicHandler(r *http.Request) (string, string, error) {
7174
}
7275
return username, password, nil
7376
}
77+
78+
func RefreshTokenFormResolveHandler(r *http.Request) (string, error) {
79+
rt := r.FormValue("refresh_token")
80+
if rt == "" {
81+
return "", errors.ErrInvalidRequest
82+
}
83+
84+
return rt, nil
85+
}
86+
87+
func RefreshTokenCookieResolveHandler(r *http.Request) (string, error) {
88+
c, err := r.Cookie("refresh_token")
89+
if err != nil {
90+
return "", errors.ErrInvalidRequest
91+
}
92+
93+
return c.Value, nil
94+
}

server/handler_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package server
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"net/url"
7+
"strings"
8+
"testing"
9+
"time"
10+
11+
"github.com/go-oauth2/oauth2/v4/errors"
12+
. "github.com/smartystreets/goconvey/convey"
13+
)
14+
15+
func TestRefreshTokenFormResolveHandler(t *testing.T) {
16+
Convey("Correct Request", t, func() {
17+
f := url.Values{}
18+
f.Add("refresh_token", "test_token")
19+
20+
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
21+
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
22+
23+
token, err := RefreshTokenFormResolveHandler(r)
24+
So(err, ShouldBeNil)
25+
So(token, ShouldEqual, "test_token")
26+
})
27+
28+
Convey("Missing Refresh Token", t, func() {
29+
r := httptest.NewRequest("POST", "/", nil)
30+
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
31+
32+
token, err := RefreshTokenFormResolveHandler(r)
33+
So(err, ShouldBeError, errors.ErrInvalidRequest)
34+
So(token, ShouldBeEmpty)
35+
})
36+
}
37+
38+
func TestRefreshTokenCookieResolveHandler(t *testing.T) {
39+
Convey("Correct Request", t, func() {
40+
r := httptest.NewRequest(http.MethodPost, "/", nil)
41+
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
42+
r.AddCookie(&http.Cookie{
43+
Name: "refresh_token",
44+
Value: "test_token",
45+
HttpOnly: true,
46+
Path: "/",
47+
Domain: ".example.com",
48+
Expires: time.Now().Add(time.Hour),
49+
})
50+
51+
token, err := RefreshTokenCookieResolveHandler(r)
52+
So(err, ShouldBeNil)
53+
So(token, ShouldEqual, "test_token")
54+
})
55+
56+
Convey("Missing Refresh Token", t, func() {
57+
r := httptest.NewRequest("POST", "/", nil)
58+
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
59+
60+
token, err := RefreshTokenCookieResolveHandler(r)
61+
So(err, ShouldBeError, errors.ErrInvalidRequest)
62+
So(token, ShouldBeEmpty)
63+
})
64+
}

server/server.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@ func NewServer(cfg *Config, manager oauth2.Manager) *Server {
2525
Manager: manager,
2626
}
2727

28-
// default handler
28+
// default handlers
2929
srv.ClientInfoHandler = ClientBasicHandler
30+
srv.RefreshTokenResolveHandler = RefreshTokenFormResolveHandler
3031

3132
srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) {
3233
return "", errors.ErrAccessDenied
@@ -56,6 +57,7 @@ type Server struct {
5657
AccessTokenExpHandler AccessTokenExpHandler
5758
AuthorizeScopeHandler AuthorizeScopeHandler
5859
ResponseTokenHandler ResponseTokenHandler
60+
RefreshTokenResolveHandler RefreshTokenResolveHandler
5961
}
6062

6163
func (s *Server) handleError(w http.ResponseWriter, req *AuthorizeRequest, err error) error {
@@ -367,10 +369,10 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oau
367369
case oauth2.ClientCredentials:
368370
tgr.Scope = r.FormValue("scope")
369371
case oauth2.Refreshing:
370-
tgr.Refresh = r.FormValue("refresh_token")
372+
tgr.Refresh, err = s.RefreshTokenResolveHandler(r)
371373
tgr.Scope = r.FormValue("scope")
372-
if tgr.Refresh == "" {
373-
return "", nil, errors.ErrInvalidRequest
374+
if err != nil {
375+
return "", nil, err
374376
}
375377
}
376378
return gt, tgr, nil

0 commit comments

Comments
 (0)