From 2ca298077e2d7a21222893350382d4c281c19c71 Mon Sep 17 00:00:00 2001 From: sgal Date: Sun, 29 Jan 2023 17:18:25 +0100 Subject: [PATCH 1/9] feat: add token hooks for all grant types --- driver/config/provider.go | 29 ++- driver/registry_base.go | 3 + ...call_refresh_token_hook_if_configured.json | 7 +- ...call_refresh_token_hook_if_configured.json | 7 +- ...call_refresh_token_hook_if_configured.json | 7 +- ...call_refresh_token_hook_if_configured.json | 7 +- ...call_refresh_token_hook_if_configured.json | 7 +- ...call_refresh_token_hook_if_configured.json | 7 +- oauth2/hook.go | 221 +++++++++++------- oauth2/oauth2_auth_code_test.go | 13 +- spec/config.json | 18 ++ 11 files changed, 231 insertions(+), 95 deletions(-) diff --git a/driver/config/provider.go b/driver/config/provider.go index e2d393a3f6b..a2ae14f0422 100644 --- a/driver/config/provider.go +++ b/driver/config/provider.go @@ -92,7 +92,10 @@ const ( KeyOAuth2GrantJWTIDOptional = "oauth2.grant.jwt.jti_optional" KeyOAuth2GrantJWTIssuedDateOptional = "oauth2.grant.jwt.iat_optional" KeyOAuth2GrantJWTMaxDuration = "oauth2.grant.jwt.max_ttl" - KeyRefreshTokenHookURL = "oauth2.refresh_token_hook" // #nosec G101 + KeyRefreshTokenHookURL = "oauth2.refresh_token_hook" // #nosec G101 + KeyAuthorizationCodeHookURL = "oauth2.authorization_code_hook" // #nosec G101 + KeyClientCredentialsHookURL = "oauth2.client_credentials_hook" // #nosec G101 + KeyJWTBearerHookURL = "oauth2.jwt_bearer_hook" // #nosec G101 KeyDevelopmentMode = "dev" ) @@ -420,6 +423,22 @@ func (p *DefaultProvider) AccessTokenStrategy(ctx context.Context, additionalSou return s } +func (p *DefaultProvider) AuthorizationCodeHookURL(ctx context.Context) *url.URL { + if len(p.getProvider(ctx).String(KeyAuthorizationCodeHookURL)) == 0 { + return nil + } + + return p.getProvider(ctx).RequestURIF(KeyAuthorizationCodeHookURL, nil) +} + +func (p *DefaultProvider) ClientCredentialsHookURL(ctx context.Context) *url.URL { + if len(p.getProvider(ctx).String(KeyClientCredentialsHookURL)) == 0 { + return nil + } + + return p.getProvider(ctx).RequestURIF(KeyClientCredentialsHookURL, nil) +} + func (p *DefaultProvider) TokenRefreshHookURL(ctx context.Context) *url.URL { if len(p.getProvider(ctx).String(KeyRefreshTokenHookURL)) == 0 { return nil @@ -428,6 +447,14 @@ func (p *DefaultProvider) TokenRefreshHookURL(ctx context.Context) *url.URL { return p.getProvider(ctx).RequestURIF(KeyRefreshTokenHookURL, nil) } +func (p *DefaultProvider) JWTBearerRefreshHookURL(ctx context.Context) *url.URL { + if len(p.getProvider(ctx).String(KeyJWTBearerHookURL)) == 0 { + return nil + } + + return p.getProvider(ctx).RequestURIF(KeyJWTBearerHookURL, nil) +} + func (p *DefaultProvider) DbIgnoreUnknownTableColumns() bool { return p.p.Bool(KeyDBIgnoreUnknownTableColumns) } diff --git a/driver/registry_base.go b/driver/registry_base.go index 84afce8455e..b78ce21759c 100644 --- a/driver/registry_base.go +++ b/driver/registry_base.go @@ -506,6 +506,9 @@ func (m *RegistryBase) AccessRequestHooks() []oauth2.AccessRequestHook { if m.arhs == nil { m.arhs = []oauth2.AccessRequestHook{ oauth2.RefreshTokenHook(m), + oauth2.AuthorizationCodeHook(m), + oauth2.ClientCredentialsHook(m), + oauth2.JWTBearerHook(m), } } return m.arhs diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json index 66fbfb5af98..f5891092948 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json @@ -41,7 +41,12 @@ "granted_audience": [], "grant_types": [ "refresh_token" - ] + ], + "payload": { + "grant_type": [ + "refresh_token" + ] + } }, "client_id": "app-client", "granted_scopes": [ diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json index 66fbfb5af98..f5891092948 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json @@ -41,7 +41,12 @@ "granted_audience": [], "grant_types": [ "refresh_token" - ] + ], + "payload": { + "grant_type": [ + "refresh_token" + ] + } }, "client_id": "app-client", "granted_scopes": [ diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json index 66fbfb5af98..f5891092948 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json @@ -41,7 +41,12 @@ "granted_audience": [], "grant_types": [ "refresh_token" - ] + ], + "payload": { + "grant_type": [ + "refresh_token" + ] + } }, "client_id": "app-client", "granted_scopes": [ diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json index 66fbfb5af98..f5891092948 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json @@ -41,7 +41,12 @@ "granted_audience": [], "grant_types": [ "refresh_token" - ] + ], + "payload": { + "grant_type": [ + "refresh_token" + ] + } }, "client_id": "app-client", "granted_scopes": [ diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json index 66fbfb5af98..f5891092948 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json @@ -41,7 +41,12 @@ "granted_audience": [], "grant_types": [ "refresh_token" - ] + ], + "payload": { + "grant_type": [ + "refresh_token" + ] + } }, "client_id": "app-client", "granted_scopes": [ diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json index 66fbfb5af98..f5891092948 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json @@ -41,7 +41,12 @@ "granted_audience": [], "grant_types": [ "refresh_token" - ] + ], + "payload": { + "grant_type": [ + "refresh_token" + ] + } }, "client_id": "app-client", "granted_scopes": [ diff --git a/oauth2/hook.go b/oauth2/hook.go index e46add43b6c..c8d90f2da8e 100644 --- a/oauth2/hook.go +++ b/oauth2/hook.go @@ -7,7 +7,9 @@ import ( "bytes" "context" "encoding/json" + "fmt" "net/http" + "net/url" "github.com/hashicorp/go-retryablehttp" @@ -34,12 +36,14 @@ type Requester struct { GrantedAudience []string `json:"granted_audience"` // GrantTypes is the requests grant types. GrantTypes []string `json:"grant_types"` + // RequestBody is the requests payload. + Payload map[string][]string `json:"payload"` } -// RefreshTokenHookRequest is the request body sent to the refresh token hook. +// TokenHookRequest is the request body sent to token hooks. // // swagger:ignore -type RefreshTokenHookRequest struct { +type TokenHookRequest struct { // Subject is the identifier of the authenticated end-user. Subject string `json:"subject"` // Session is the request's session.. @@ -54,10 +58,10 @@ type RefreshTokenHookRequest struct { GrantedAudience []string `json:"granted_audience"` } -// RefreshTokenHookResponse is the response body received from the refresh token hook. +// TokenHookResponse is the response body received from token hooks. // // swagger:ignore -type RefreshTokenHookResponse struct { +type TokenHookResponse struct { // Session is the session data returned by the hook. Session consent.AcceptOAuth2ConsentRequestSession `json:"session"` } @@ -66,103 +70,154 @@ type RefreshTokenHookResponse struct { func RefreshTokenHook(reg interface { config.Provider x.HTTPClientProvider -}) AccessRequestHook { +}, +) AccessRequestHook { return func(ctx context.Context, requester fosite.AccessRequester) error { hookURL := reg.Config().TokenRefreshHookURL(ctx) if hookURL == nil { return nil } + return callHook(ctx, reg, requester, "refresh_token", hookURL) + } +} - if !requester.GetGrantTypes().ExactOne("refresh_token") { +// AuthorizationCodeHook is an AccessRequestHook called for `authorization_code` grant type. +func AuthorizationCodeHook(reg interface { + config.Provider + x.HTTPClientProvider +}, +) AccessRequestHook { + return func(ctx context.Context, requester fosite.AccessRequester) error { + hookURL := reg.Config().AuthorizationCodeHookURL(ctx) + if hookURL == nil { return nil } + return callHook(ctx, reg, requester, "authorization_code", hookURL) + } +} - session, ok := requester.GetSession().(*Session) - if !ok { +// ClientCredentialsHook is an AccessRequestHook called for `client_credentials` grant type. +func ClientCredentialsHook(reg interface { + config.Provider + x.HTTPClientProvider +}, +) AccessRequestHook { + return func(ctx context.Context, requester fosite.AccessRequester) error { + hookURL := reg.Config().ClientCredentialsHookURL(ctx) + if hookURL == nil { return nil } + return callHook(ctx, reg, requester, "client_credentials", hookURL) + } +} - requesterInfo := Requester{ - ClientID: requester.GetClient().GetID(), - GrantedScopes: requester.GetGrantedScopes(), - GrantedAudience: requester.GetGrantedAudience(), - GrantTypes: requester.GetGrantTypes(), +// JWTBearerHook is an AccessRequestHook called for `urn:ietf:params:oauth:grant-type:jwt-bearer` grant type. +func JWTBearerHook(reg interface { + config.Provider + x.HTTPClientProvider +}, +) AccessRequestHook { + return func(ctx context.Context, requester fosite.AccessRequester) error { + hookURL := reg.Config().JWTBearerRefreshHookURL(ctx) + if hookURL == nil { + return nil } + return callHook(ctx, reg, requester, "urn:ietf:params:oauth:grant-type:jwt-bearer", hookURL) + } +} - reqBody := RefreshTokenHookRequest{ - Session: session, - Requester: requesterInfo, - Subject: session.GetSubject(), - ClientID: requester.GetClient().GetID(), - GrantedScopes: requester.GetGrantedScopes(), - GrantedAudience: requester.GetGrantedAudience(), - } - reqBodyBytes, err := json.Marshal(&reqBody) - if err != nil { - return errorsx.WithStack( - fosite.ErrServerError. - WithWrap(err). - WithDescription("An error occurred while encoding the refresh token hook."). - WithDebugf("Unable to encode the refresh token hook body: %s", err), - ) - } +func callHook(ctx context.Context, reg x.HTTPClientProvider, requester fosite.AccessRequester, hookType string, hookURL *url.URL) error { - req, err := retryablehttp.NewRequestWithContext(ctx, http.MethodPost, hookURL.String(), bytes.NewReader(reqBodyBytes)) - if err != nil { - return errorsx.WithStack( - fosite.ErrServerError. - WithWrap(err). - WithDescription("An error occurred while preparing the refresh token hook."). - WithDebugf("Unable to prepare the HTTP Request: %s", err), - ) - } - req.Header.Set("Content-Type", "application/json; charset=UTF-8") - - resp, err := reg.HTTPClient(ctx).Do(req) - if err != nil { - return errorsx.WithStack( - fosite.ErrServerError. - WithWrap(err). - WithDescription("An error occurred while executing the refresh token hook."). - WithDebugf("Unable to execute HTTP Request: %s", err), - ) - } - defer resp.Body.Close() + if !requester.GetGrantTypes().ExactOne(hookType) { + return nil + } - switch resp.StatusCode { - case http.StatusOK: - // Token refresh permitted with new session data - case http.StatusNoContent: - // Token refresh is permitted without overriding session data - return nil - case http.StatusForbidden: - return errorsx.WithStack( - fosite.ErrAccessDenied. - WithDescription("The refresh token hook target responded with an error."). - WithDebugf("Refresh token hook responded with HTTP status code: %s", resp.Status), - ) - default: - return errorsx.WithStack( - fosite.ErrServerError. - WithDescription("The refresh token hook target responded with an error."). - WithDebugf("Refresh token hook responded with HTTP status code: %s", resp.Status), - ) - } + session, ok := requester.GetSession().(*Session) + if !ok { + return nil + } - var respBody RefreshTokenHookResponse - if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - return errorsx.WithStack( - fosite.ErrServerError. - WithWrap(err). - WithDescription("The refresh token hook target responded with an error."). - WithDebugf("Response from refresh token hook could not be decoded: %s", err), - ) - } + requesterInfo := Requester{ + ClientID: requester.GetClient().GetID(), + GrantedScopes: requester.GetGrantedScopes(), + GrantedAudience: requester.GetGrantedAudience(), + GrantTypes: requester.GetGrantTypes(), + Payload: requester.GetRequestForm(), + } + + reqBody := TokenHookRequest{ + Session: session, + Requester: requesterInfo, + Subject: session.GetSubject(), + ClientID: requester.GetClient().GetID(), + GrantedScopes: requester.GetGrantedScopes(), + GrantedAudience: requester.GetGrantedAudience(), + } + reqBodyBytes, err := json.Marshal(&reqBody) + if err != nil { + return errorsx.WithStack( + fosite.ErrServerError. + WithWrap(err). + WithDescription(fmt.Sprintf("An error occurred while encoding the %s hook.", hookType)). + WithDebugf("Unable to encode the %v hook body: %s", hookType, err), + ) + } + + req, err := retryablehttp.NewRequestWithContext(ctx, http.MethodPost, hookURL.String(), bytes.NewReader(reqBodyBytes)) + if err != nil { + return errorsx.WithStack( + fosite.ErrServerError. + WithWrap(err). + WithDescription(fmt.Sprintf("An error occurred while preparing the %s hook.", hookType)). + WithDebugf("Unable to prepare the HTTP Request: %s", err), + ) + } + req.Header.Set("Content-Type", "application/json; charset=UTF-8") + + resp, err := reg.HTTPClient(ctx).Do(req) + if err != nil { + return errorsx.WithStack( + fosite.ErrServerError. + WithWrap(err). + WithDescription(fmt.Sprintf("An error occurred while executing the %s hook.", hookType)). + WithDebugf("Unable to execute HTTP Request: %s", err), + ) + } + defer resp.Body.Close() - // Overwrite existing session data (extra claims). - session.Extra = respBody.Session.AccessToken - idTokenClaims := session.IDTokenClaims() - idTokenClaims.Extra = respBody.Session.IDToken + switch resp.StatusCode { + case http.StatusOK: + // Token permitted with new session data + case http.StatusNoContent: + // Token is permitted without overriding session data return nil + case http.StatusForbidden: + return errorsx.WithStack( + fosite.ErrAccessDenied. + WithDescription(fmt.Sprintf("The %s hook target responded with an error.", hookType)). + WithDebugf(fmt.Sprintf("%s hook responded with HTTP status code: %s", hookType, resp.Status)), + ) + default: + return errorsx.WithStack( + fosite.ErrServerError. + WithDescription(fmt.Sprintf("The %s hook target responded with an error.", hookType)). + WithDebugf(fmt.Sprintf("%s hook responded with HTTP status code: %s", hookType, resp.Status)), + ) } + + var respBody TokenHookResponse + if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { + return errorsx.WithStack( + fosite.ErrServerError. + WithWrap(err). + WithDescription(fmt.Sprintf("The %s hook target responded with an error.", hookType)). + WithDebugf(fmt.Sprintf("Response from %s hook could not be decoded: %s", hookType, err)), + ) + } + + // Overwrite existing session data (extra claims). + session.Extra = respBody.Session.AccessToken + idTokenClaims := session.IDTokenClaims() + idTokenClaims.Extra = respBody.Session.IDToken + return nil } diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index 1a6bbe4275d..236e6741912 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -1024,8 +1024,9 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { expectedGrantedScopes := []string{"openid", "offline", "hydra.*"} expectedSubject := "foo" + expectedPayload := map[string][]string(map[string][]string{"grant_type": {"refresh_token"}, "refresh_token": {refreshedToken.RefreshToken}}) - var hookReq hydraoauth2.RefreshTokenHookRequest + var hookReq hydraoauth2.TokenHookRequest require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) require.Equal(t, hookReq.Subject, expectedSubject) require.ElementsMatch(t, hookReq.GrantedScopes, expectedGrantedScopes) @@ -1038,6 +1039,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { require.NotEmpty(t, hookReq.Requester) require.Equal(t, hookReq.Requester.ClientID, oauthConfig.ClientID) require.ElementsMatch(t, hookReq.Requester.GrantedScopes, expectedGrantedScopes) + require.Equal(t, hookReq.Requester.Payload, expectedPayload) except := []string{ "session.kid", @@ -1047,6 +1049,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { "session.id_token.id_token_claims.exp", "session.id_token.id_token_claims.rat", "session.id_token.id_token_claims.auth_time", + "requester.payload.refresh_token", } snapshotx.SnapshotTExcept(t, hookReq, except) @@ -1054,7 +1057,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { "hooked": true, } - hookResp := hydraoauth2.RefreshTokenHookResponse{ + hookResp := hydraoauth2.TokenHookResponse{ Session: consent.AcceptOAuth2ConsentRequestSession{ AccessToken: claims, IDToken: claims, @@ -1131,7 +1134,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { var errBody fosite.RFC6749ErrorJson require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) require.Equal(t, fosite.ErrServerError.Error(), errBody.Name) - require.Equal(t, "An error occurred while executing the refresh token hook.", errBody.Description) + require.Equal(t, "An error occurred while executing the refresh_token hook.", errBody.Description) }) t.Run("should fail token refresh with `access_denied` if hook denied the request", func(t *testing.T) { @@ -1150,7 +1153,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { var errBody fosite.RFC6749ErrorJson require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) require.Equal(t, fosite.ErrAccessDenied.Error(), errBody.Name) - require.Equal(t, "The refresh token hook target responded with an error. Make sure that the request you are making is valid. Maybe the credential or request parameters you are using are limited in scope or otherwise restricted.", errBody.Description) + require.Equal(t, "The refresh_token hook target responded with an error. Make sure that the request you are making is valid. Maybe the credential or request parameters you are using are limited in scope or otherwise restricted.", errBody.Description) }) t.Run("should fail token refresh with `server_error` if hook response is malformed", func(t *testing.T) { @@ -1169,7 +1172,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { var errBody fosite.RFC6749ErrorJson require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) require.Equal(t, fosite.ErrServerError.Error(), errBody.Name) - require.Equal(t, "The refresh token hook target responded with an error.", errBody.Description) + require.Equal(t, "The refresh_token hook target responded with an error.", errBody.Description) }) t.Run("refreshing old token should no longer work", func(t *testing.T) { diff --git a/spec/config.json b/spec/config.json index f2cf7e46299..d0db9412fa7 100644 --- a/spec/config.json +++ b/spec/config.json @@ -975,6 +975,24 @@ "description": "Sets the refresh token hook endpoint. If set it will be called during token refresh to receive updated token claims.", "format": "uri", "examples": ["https://my-example.app/token-refresh-hook"] + }, + "authorization_code_hook": { + "type": "string", + "description": "Sets the authorization code hook endpoint. If set it will be called while providing token to customize claims.", + "format": "uri", + "examples": ["https://my-example.app/token-authorization-code-hook"] + }, + "client_credentials_hook": { + "type": "string", + "description": "Sets the client credentials hook endpoint. If set it will be called while providing token to customize claims.", + "format": "uri", + "examples": ["https://my-example.app/token-client-credentials-hook"] + }, + "jwt_bearer_hook": { + "type": "string", + "description": "Sets the jwt bearer hook endpoint. If set it will be called while providing token to customize claims.", + "format": "uri", + "examples": ["https://my-example.app/token-jwt-bearer-hook"] } } }, From 1e419f326004dd219cbcbacf1b1292fd365f4273 Mon Sep 17 00:00:00 2001 From: sgal Date: Sun, 29 Jan 2023 18:12:10 +0100 Subject: [PATCH 2/9] Add tests for webhook in client_credentials grant_type --- oauth2/oauth2_client_credentials_test.go | 157 +++++++++++++++++++++-- 1 file changed, 149 insertions(+), 8 deletions(-) diff --git a/oauth2/oauth2_client_credentials_test.go b/oauth2/oauth2_client_credentials_test.go index b151f267679..90bdbce36e3 100644 --- a/oauth2/oauth2_client_credentials_test.go +++ b/oauth2/oauth2_client_credentials_test.go @@ -7,6 +7,8 @@ import ( "context" "encoding/json" "math" + "net/http" + "net/http/httptest" "net/url" "strings" "testing" @@ -20,7 +22,9 @@ import ( goauth2 "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" + "github.com/ory/hydra/v2/consent" "github.com/ory/hydra/v2/internal/testhelpers" + hydraoauth2 "github.com/ory/hydra/v2/oauth2" "github.com/ory/x/contextx" hc "github.com/ory/hydra/v2/client" @@ -75,7 +79,7 @@ func TestClientCredentials(t *testing.T) { return string(out) } - var inspectToken = func(t *testing.T, token *goauth2.Token, cl *hc.Client, conf clientcredentials.Config, strategy string, expectedExp time.Time) { + var inspectToken = func(t *testing.T, token *goauth2.Token, cl *hc.Client, conf clientcredentials.Config, strategy string, expectedExp time.Time, checkExtraClaims bool) { introspection := testhelpers.IntrospectToken(t, &goauth2.Config{ClientID: cl.GetID(), ClientSecret: conf.ClientSecret}, token.AccessToken, admin) check := func(res gjson.Result) { @@ -87,6 +91,10 @@ func TestClientCredentials(t *testing.T) { requirex.EqualTime(t, expectedExp, time.Unix(res.Get("exp").Int(), 0), time.Second) assert.EqualValues(t, encodeOr(t, conf.EndpointParams["audience"], "[]"), res.Get("aud").Raw, "%s", res.Raw) + + if checkExtraClaims { + require.True(t, res.Get("ext.hooked").Bool()) + } } check(introspection) @@ -108,10 +116,10 @@ func TestClientCredentials(t *testing.T) { check(jwtClaims) } - var getAndInspectToken = func(t *testing.T, cl *hc.Client, conf clientcredentials.Config, strategy string, expectedExp time.Time) { + var getAndInspectToken = func(t *testing.T, cl *hc.Client, conf clientcredentials.Config, strategy string, expectedExp time.Time, checkExtraClaims bool) { token, err := getToken(t, conf) require.NoError(t, err) - inspectToken(t, token, cl, conf, strategy, expectedExp) + inspectToken(t, token, cl, conf, strategy, expectedExp, checkExtraClaims) } t.Run("case=should fail because audience is not allowed", func(t *testing.T) { @@ -134,7 +142,7 @@ func TestClientCredentials(t *testing.T) { reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) cl, conf := newClient(t) - getAndInspectToken(t, cl, conf, strategy, time.Now().Add(reg.Config().GetAccessTokenLifespan(ctx))) + getAndInspectToken(t, cl, conf, strategy, time.Now().Add(reg.Config().GetAccessTokenLifespan(ctx)), false) } } @@ -149,7 +157,7 @@ func TestClientCredentials(t *testing.T) { cl, conf := newClient(t) conf.EndpointParams = url.Values{} - getAndInspectToken(t, cl, conf, strategy, time.Now().Add(reg.Config().GetAccessTokenLifespan(ctx))) + getAndInspectToken(t, cl, conf, strategy, time.Now().Add(reg.Config().GetAccessTokenLifespan(ctx)), false) } } @@ -164,7 +172,7 @@ func TestClientCredentials(t *testing.T) { cl, conf := newClient(t) conf.Scopes = []string{} - getAndInspectToken(t, cl, conf, strategy, time.Now().Add(reg.Config().GetAccessTokenLifespan(ctx))) + getAndInspectToken(t, cl, conf, strategy, time.Now().Add(reg.Config().GetAccessTokenLifespan(ctx)), false) } } @@ -188,7 +196,7 @@ func TestClientCredentials(t *testing.T) { // We reset this so that introspectToken is going to check for the default scope. conf.Scopes = defaultScope - inspectToken(t, token, cl, conf, strategy, time.Now().Add(reg.Config().GetAccessTokenLifespan(ctx))) + inspectToken(t, token, cl, conf, strategy, time.Now().Add(reg.Config().GetAccessTokenLifespan(ctx)), false) } } @@ -211,7 +219,7 @@ func TestClientCredentials(t *testing.T) { Audience: []string{"https://api.ory.sh/"}, }) testhelpers.UpdateClientTokenLifespans(t, &goauth2.Config{ClientID: cl.GetID(), ClientSecret: conf.ClientSecret}, cl.GetID(), testhelpers.TestLifespans, admin) - getAndInspectToken(t, cl, conf, strategy, time.Now().Add(testhelpers.TestLifespans.ClientCredentialsGrantAccessTokenLifespan.Duration)) + getAndInspectToken(t, cl, conf, strategy, time.Now().Add(testhelpers.TestLifespans.ClientCredentialsGrantAccessTokenLifespan.Duration), false) } } @@ -241,4 +249,137 @@ func TestClientCredentials(t *testing.T) { t.Run("strategy=opaque", run("opaque")) t.Run("strategy=jwt", run("jwt")) }) + + t.Run("should call token hook if configured", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + scope := "foobar" + audience := []string{"https://api.ory.sh/"} + + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8") + + expectedGrantedScopes := []string{"foobar"} + expectedGrantedAudience := []string{"https://api.ory.sh/"} + expectedPayload := map[string][]string(map[string][]string{"audience": audience, "grant_type": {"client_credentials"}, "scope": {scope}}) + + var hookReq hydraoauth2.TokenHookRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) + require.ElementsMatch(t, hookReq.GrantedScopes, expectedGrantedScopes) + require.ElementsMatch(t, hookReq.GrantedAudience, expectedGrantedAudience) + require.NotEmpty(t, hookReq.Session) + require.Equal(t, hookReq.Session.Extra, map[string]interface{}{}) + require.NotEmpty(t, hookReq.Requester) + require.ElementsMatch(t, hookReq.Requester.GrantedScopes, expectedGrantedScopes) + require.Equal(t, hookReq.Requester.Payload, expectedPayload) + + claims := map[string]interface{}{ + "hooked": true, + } + + hookResp := hydraoauth2.TokenHookResponse{ + Session: consent.AcceptOAuth2ConsentRequestSession{ + AccessToken: claims, + IDToken: claims, + }, + } + + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(&hookResp)) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyClientCredentialsHookURL, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyClientCredentialsHookURL, nil) + + secret := uuid.New().String() + cl, conf := newCustomClient(t, &hc.Client{ + Secret: secret, + RedirectURIs: []string{public.URL + "/callback"}, + ResponseTypes: []string{"token"}, + GrantTypes: []string{"client_credentials"}, + Scope: scope, + Audience: audience, + }) + getAndInspectToken(t, cl, conf, strategy, time.Now().Add(reg.Config().GetAccessTokenLifespan(ctx)), true) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("should fail token if hook fails", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyClientCredentialsHookURL, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyClientCredentialsHookURL, nil) + + _, conf := newClient(t) + + _, err := getToken(t, conf) + require.Error(t, err) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("should fail token if hook denied the request", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyClientCredentialsHookURL, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyClientCredentialsHookURL, nil) + + _, conf := newClient(t) + + _, err := getToken(t, conf) + require.Error(t, err) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("should fail token if hook response is malformed", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyClientCredentialsHookURL, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyClientCredentialsHookURL, nil) + + _, conf := newClient(t) + + _, err := getToken(t, conf) + require.Error(t, err) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) } From 05606f9b8f0c4894cddf7dd29cac45c50bd5ba74 Mon Sep 17 00:00:00 2001 From: sgal Date: Sun, 29 Jan 2023 18:35:33 +0100 Subject: [PATCH 3/9] Add tests for webhook in jwt_bearer grant_type --- oauth2/oauth2_jwt_bearer_test.go | 186 ++++++++++++++++++++++++++++++- 1 file changed, 183 insertions(+), 3 deletions(-) diff --git a/oauth2/oauth2_jwt_bearer_test.go b/oauth2/oauth2_jwt_bearer_test.go index d77e8d5a9ff..f3006b28ff0 100644 --- a/oauth2/oauth2_jwt_bearer_test.go +++ b/oauth2/oauth2_jwt_bearer_test.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net/http" + "net/http/httptest" "net/url" "strings" "testing" @@ -19,7 +20,9 @@ import ( "gopkg.in/square/go-jose.v2" "github.com/ory/fosite/token/jwt" + "github.com/ory/hydra/v2/consent" "github.com/ory/hydra/v2/jwk" + hydraoauth2 "github.com/ory/hydra/v2/oauth2" "github.com/ory/hydra/v2/oauth2/trust" "github.com/stretchr/testify/assert" @@ -65,7 +68,7 @@ func TestJWTBearer(t *testing.T) { return conf.Token(context.Background()) } - var inspectToken = func(t *testing.T, token *goauth2.Token, cl *hc.Client, strategy string, grant trust.Grant) { + var inspectToken = func(t *testing.T, token *goauth2.Token, cl *hc.Client, strategy string, grant trust.Grant, checkExtraClaims bool) { introspection := testhelpers.IntrospectToken(t, &goauth2.Config{ClientID: cl.GetID(), ClientSecret: cl.Secret}, token.AccessToken, admin) check := func(res gjson.Result) { @@ -77,6 +80,10 @@ func TestJWTBearer(t *testing.T) { assert.True(t, res.Get("exp").Int() >= res.Get("iat").Int()+int64(reg.Config().GetAccessTokenLifespan(ctx).Seconds()), "%s", res.Raw) assert.EqualValues(t, fmt.Sprintf(`["%s"]`, reg.Config().OAuth2TokenURL(ctx).String()), res.Get("aud").Raw, "%s", res.Raw) + + if checkExtraClaims { + require.True(t, res.Get("ext.hooked").Bool()) + } } check(introspection) @@ -248,7 +255,7 @@ func TestJWTBearer(t *testing.T) { result, err := getToken(t, conf) require.NoError(t, err) - inspectToken(t, result, client, strategy, trustGrant) + inspectToken(t, result, client, strategy, trustGrant, false) } } @@ -288,7 +295,180 @@ func TestJWTBearer(t *testing.T) { require.NoError(t, json.Unmarshal(body, &result)) assert.NotEmpty(t, result.AccessToken, "%s", body) - inspectToken(t, &result, client, strategy, trustGrant) + inspectToken(t, &result, client, strategy, trustGrant, false) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("should call token hook if configured", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + audience := reg.Config().OAuth2TokenURL(ctx).String() + grantType := "urn:ietf:params:oauth:grant-type:jwt-bearer" + + token, _, err := signer.Generate(ctx, jwt.MapClaims{ + "jti": uuid.NewString(), + "iss": trustGrant.Issuer, + "sub": trustGrant.Subject, + "aud": audience, + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Add(-time.Minute).Unix(), + }, &jwt.Headers{Extra: map[string]interface{}{"kid": kid}}) + require.NoError(t, err) + + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8") + + expectedGrantedScopes := []string{client.Scope} + expectedGrantedAudience := []string{audience} + expectedPayload := map[string][]string(map[string][]string{"grant_type": {grantType}, "assertion": {token}, "scope": {client.Scope}}) + + var hookReq hydraoauth2.TokenHookRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) + require.ElementsMatch(t, hookReq.GrantedScopes, expectedGrantedScopes) + require.ElementsMatch(t, hookReq.GrantedAudience, expectedGrantedAudience) + require.NotEmpty(t, hookReq.Session) + require.Equal(t, hookReq.Session.Extra, map[string]interface{}{}) + require.NotEmpty(t, hookReq.Requester) + require.ElementsMatch(t, hookReq.Requester.GrantedScopes, expectedGrantedScopes) + require.Equal(t, hookReq.Requester.Payload, expectedPayload) + + claims := map[string]interface{}{ + "hooked": true, + } + + hookResp := hydraoauth2.TokenHookResponse{ + Session: consent.AcceptOAuth2ConsentRequestSession{ + AccessToken: claims, + IDToken: claims, + }, + } + + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(&hookResp)) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, nil) + + conf := newConf(client) + conf.EndpointParams = url.Values{"grant_type": {grantType}, "assertion": {token}} + + result, err := getToken(t, conf) + require.NoError(t, err) + + inspectToken(t, result, client, strategy, trustGrant, true) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("should fail token if hook fails", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, nil) + + token, _, err := signer.Generate(ctx, jwt.MapClaims{ + "jti": uuid.NewString(), + "iss": trustGrant.Issuer, + "sub": trustGrant.Subject, + "aud": reg.Config().OAuth2TokenURL(ctx).String(), + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Add(-time.Minute).Unix(), + }, &jwt.Headers{Extra: map[string]interface{}{"kid": kid}}) + require.NoError(t, err) + + conf := newConf(client) + conf.EndpointParams = url.Values{"grant_type": {"urn:ietf:params:oauth:grant-type:jwt-bearer"}, "assertion": {token}} + + _, tokenError := getToken(t, conf) + require.Error(t, tokenError) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("should fail token if hook denied the request", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, nil) + + token, _, err := signer.Generate(ctx, jwt.MapClaims{ + "jti": uuid.NewString(), + "iss": trustGrant.Issuer, + "sub": trustGrant.Subject, + "aud": reg.Config().OAuth2TokenURL(ctx).String(), + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Add(-time.Minute).Unix(), + }, &jwt.Headers{Extra: map[string]interface{}{"kid": kid}}) + require.NoError(t, err) + + conf := newConf(client) + conf.EndpointParams = url.Values{"grant_type": {"urn:ietf:params:oauth:grant-type:jwt-bearer"}, "assertion": {token}} + + _, tokenError := getToken(t, conf) + require.Error(t, tokenError) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("should fail token if hook response is malformed", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, nil) + + token, _, err := signer.Generate(ctx, jwt.MapClaims{ + "jti": uuid.NewString(), + "iss": trustGrant.Issuer, + "sub": trustGrant.Subject, + "aud": reg.Config().OAuth2TokenURL(ctx).String(), + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Add(-time.Minute).Unix(), + }, &jwt.Headers{Extra: map[string]interface{}{"kid": kid}}) + require.NoError(t, err) + + conf := newConf(client) + conf.EndpointParams = url.Values{"grant_type": {"urn:ietf:params:oauth:grant-type:jwt-bearer"}, "assertion": {token}} + + _, tokenError := getToken(t, conf) + require.Error(t, tokenError) } } From 923557c3ecc292b12cfae1a72166b37efe6b506f Mon Sep 17 00:00:00 2001 From: sgal Date: Sun, 29 Jan 2023 20:50:08 +0100 Subject: [PATCH 4/9] Add tests for hook in authorization_code grant_type --- ...call_refresh_token_hook_if_configured.json | 6 +- ...call_refresh_token_hook_if_configured.json | 6 +- ...call_refresh_token_hook_if_configured.json | 6 +- ...call_refresh_token_hook_if_configured.json | 6 +- ...call_refresh_token_hook_if_configured.json | 6 +- ...call_refresh_token_hook_if_configured.json | 6 +- oauth2/hook.go | 9 +- oauth2/oauth2_auth_code_test.go | 199 +++++++++++++++++- 8 files changed, 207 insertions(+), 37 deletions(-) diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json index f5891092948..d9b61fd7865 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json @@ -42,11 +42,7 @@ "grant_types": [ "refresh_token" ], - "payload": { - "grant_type": [ - "refresh_token" - ] - } + "payload": {} }, "client_id": "app-client", "granted_scopes": [ diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json index f5891092948..d9b61fd7865 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json @@ -42,11 +42,7 @@ "grant_types": [ "refresh_token" ], - "payload": { - "grant_type": [ - "refresh_token" - ] - } + "payload": {} }, "client_id": "app-client", "granted_scopes": [ diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json index f5891092948..d9b61fd7865 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json @@ -42,11 +42,7 @@ "grant_types": [ "refresh_token" ], - "payload": { - "grant_type": [ - "refresh_token" - ] - } + "payload": {} }, "client_id": "app-client", "granted_scopes": [ diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json index f5891092948..d9b61fd7865 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json @@ -42,11 +42,7 @@ "grant_types": [ "refresh_token" ], - "payload": { - "grant_type": [ - "refresh_token" - ] - } + "payload": {} }, "client_id": "app-client", "granted_scopes": [ diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json index f5891092948..d9b61fd7865 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json @@ -42,11 +42,7 @@ "grant_types": [ "refresh_token" ], - "payload": { - "grant_type": [ - "refresh_token" - ] - } + "payload": {} }, "client_id": "app-client", "granted_scopes": [ diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json index f5891092948..d9b61fd7865 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json @@ -42,11 +42,7 @@ "grant_types": [ "refresh_token" ], - "payload": { - "grant_type": [ - "refresh_token" - ] - } + "payload": {} }, "client_id": "app-client", "granted_scopes": [ diff --git a/oauth2/hook.go b/oauth2/hook.go index c8d90f2da8e..21ded063594 100644 --- a/oauth2/hook.go +++ b/oauth2/hook.go @@ -127,7 +127,6 @@ func JWTBearerHook(reg interface { } func callHook(ctx context.Context, reg x.HTTPClientProvider, requester fosite.AccessRequester, hookType string, hookURL *url.URL) error { - if !requester.GetGrantTypes().ExactOne(hookType) { return nil } @@ -137,12 +136,18 @@ func callHook(ctx context.Context, reg x.HTTPClientProvider, requester fosite.Ac return nil } + payload := map[string][]string{} + + if requester.GetGrantTypes().ExactOne("urn:ietf:params:oauth:grant-type:jwt-bearer") || requester.GetGrantTypes().ExactOne("client_credentials") { + payload = requester.GetRequestForm() + } + requesterInfo := Requester{ ClientID: requester.GetClient().GetID(), GrantedScopes: requester.GetGrantedScopes(), GrantedAudience: requester.GetGrantedAudience(), GrantTypes: requester.GetGrantTypes(), - Payload: requester.GetRequestForm(), + Payload: payload, } reqBody := TokenHookRequest{ diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index 236e6741912..c748d10c299 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -235,7 +235,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { assert.True(t, i.Get("active").Bool(), "%s", i) assert.EqualValues(t, conf.ClientID, i.Get("client_id").String(), "%s", i) assert.EqualValues(t, expectedSubject, i.Get("sub").String(), "%s", i) - assert.EqualValues(t, `{"foo":"bar"}`, i.Get("ext").Raw, "%s", i) + assert.EqualValues(t, `bar`, i.Get("ext.foo").String(), "%s", i) return i } @@ -260,7 +260,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { assert.True(t, time.Now().After(time.Unix(i.Get("nbf").Int(), 0)), "%s", i) assert.True(t, time.Now().Before(time.Unix(i.Get("exp").Int(), 0)), "%s", i) requirex.EqualTime(t, expectedExp, time.Unix(i.Get("exp").Int(), 0), time.Second) - assert.EqualValues(t, `{"foo":"bar"}`, i.Get("ext").Raw, "%s", i) + assert.EqualValues(t, `bar`, i.Get("ext.foo").String(), "%s", i) assert.EqualValues(t, `["hydra","offline","openid"]`, i.Get("scp").Raw, "%s", i) return i } @@ -682,6 +682,197 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { assert.Empty(t, uiClaims.Get(f).Raw, "%s: %s", f, uiClaims) } }) + + t.Run("case=add ext claims from hook if configured", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8") + + var hookReq hydraoauth2.TokenHookRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) + require.ElementsMatch(t, hookReq.GrantedAudience, []string{}) + require.NotEmpty(t, hookReq.Session) + require.Equal(t, hookReq.Session.Extra, map[string]interface{}{"foo": "bar"}) + require.NotEmpty(t, hookReq.Requester) + require.Equal(t, hookReq.Requester.Payload, map[string][]string{}) + + claims := map[string]interface{}{ + "hooked": true, + } + + hookResp := hydraoauth2.TokenHookResponse{ + Session: consent.AcceptOAuth2ConsentRequestSession{ + AccessToken: claims, + IDToken: claims, + }, + } + + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(&hookResp)) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyAuthorizationCodeHookURL, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyAuthorizationCodeHookURL, nil) + + expectAud := "https://api.ory.sh/" + c, conf := newOAuth2Client(t, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + assert.False(t, r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + return nil + }), + acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + assert.False(t, *r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + })) + + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + + token, err := conf.Exchange(context.Background(), code) + require.NoError(t, err) + + assertJWTAccessToken(t, strategy, conf, token, subject, time.Now().Add(reg.Config().GetAccessTokenLifespan(ctx))) + + // NOTE: using introspect to cover both jwt and opaque strategies + accessTokenClaims := introspectAccessToken(t, conf, token, subject) + require.True(t, accessTokenClaims.Get("ext.hooked").Bool()) + + idTokenClaims := assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) + require.True(t, idTokenClaims.Get("hooked").Bool()) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("case=fail token exchange if hook fails", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyAuthorizationCodeHookURL, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyAuthorizationCodeHookURL, nil) + + expectAud := "https://api.ory.sh/" + c, conf := newOAuth2Client(t, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + assert.False(t, r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + return nil + }), + acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + assert.False(t, *r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + })) + + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + + _, err := conf.Exchange(context.Background(), code) + require.Error(t, err) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("case=fail token exchange if hook denies the request", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyAuthorizationCodeHookURL, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyAuthorizationCodeHookURL, nil) + + expectAud := "https://api.ory.sh/" + c, conf := newOAuth2Client(t, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + assert.False(t, r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + return nil + }), + acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + assert.False(t, *r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + })) + + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + + _, err := conf.Exchange(context.Background(), code) + require.Error(t, err) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("case=fail token exchange if hook response is malformed", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyAuthorizationCodeHookURL, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyAuthorizationCodeHookURL, nil) + + expectAud := "https://api.ory.sh/" + c, conf := newOAuth2Client(t, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + assert.False(t, r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + return nil + }), + acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + assert.False(t, *r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + })) + + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + + _, err := conf.Exchange(context.Background(), code) + require.Error(t, err) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) } // TestAuthCodeWithMockStrategy runs the authorization_code flow against various ConsentStrategy scenarios. @@ -1024,7 +1215,6 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { expectedGrantedScopes := []string{"openid", "offline", "hydra.*"} expectedSubject := "foo" - expectedPayload := map[string][]string(map[string][]string{"grant_type": {"refresh_token"}, "refresh_token": {refreshedToken.RefreshToken}}) var hookReq hydraoauth2.TokenHookRequest require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) @@ -1039,7 +1229,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { require.NotEmpty(t, hookReq.Requester) require.Equal(t, hookReq.Requester.ClientID, oauthConfig.ClientID) require.ElementsMatch(t, hookReq.Requester.GrantedScopes, expectedGrantedScopes) - require.Equal(t, hookReq.Requester.Payload, expectedPayload) + require.Equal(t, hookReq.Requester.Payload, map[string][]string{}) except := []string{ "session.kid", @@ -1049,7 +1239,6 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { "session.id_token.id_token_claims.exp", "session.id_token.id_token_claims.rat", "session.id_token.id_token_claims.auth_time", - "requester.payload.refresh_token", } snapshotx.SnapshotTExcept(t, hookReq, except) From 4d5fb1421c3078203ccb800b6e628ee4b0c3fb04 Mon Sep 17 00:00:00 2001 From: sgal Date: Mon, 30 Jan 2023 08:11:57 +0100 Subject: [PATCH 5/9] Fix openapi description fro payload field --- oauth2/hook.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oauth2/hook.go b/oauth2/hook.go index 21ded063594..9cef02f430b 100644 --- a/oauth2/hook.go +++ b/oauth2/hook.go @@ -36,7 +36,7 @@ type Requester struct { GrantedAudience []string `json:"granted_audience"` // GrantTypes is the requests grant types. GrantTypes []string `json:"grant_types"` - // RequestBody is the requests payload. + // Payload is the requests payload. Payload map[string][]string `json:"payload"` } From fb5bd4b2d81877617789ea962c77af1cd8650d72 Mon Sep 17 00:00:00 2001 From: sgal Date: Tue, 7 Mar 2023 14:55:49 +0100 Subject: [PATCH 6/9] fix: address review comments --- driver/config/provider.go | 16 ---------------- oauth2/hook.go | 26 +++++++++++++------------- 2 files changed, 13 insertions(+), 29 deletions(-) diff --git a/driver/config/provider.go b/driver/config/provider.go index a2ae14f0422..ee1acd5ae0f 100644 --- a/driver/config/provider.go +++ b/driver/config/provider.go @@ -424,34 +424,18 @@ func (p *DefaultProvider) AccessTokenStrategy(ctx context.Context, additionalSou } func (p *DefaultProvider) AuthorizationCodeHookURL(ctx context.Context) *url.URL { - if len(p.getProvider(ctx).String(KeyAuthorizationCodeHookURL)) == 0 { - return nil - } - return p.getProvider(ctx).RequestURIF(KeyAuthorizationCodeHookURL, nil) } func (p *DefaultProvider) ClientCredentialsHookURL(ctx context.Context) *url.URL { - if len(p.getProvider(ctx).String(KeyClientCredentialsHookURL)) == 0 { - return nil - } - return p.getProvider(ctx).RequestURIF(KeyClientCredentialsHookURL, nil) } func (p *DefaultProvider) TokenRefreshHookURL(ctx context.Context) *url.URL { - if len(p.getProvider(ctx).String(KeyRefreshTokenHookURL)) == 0 { - return nil - } - return p.getProvider(ctx).RequestURIF(KeyRefreshTokenHookURL, nil) } func (p *DefaultProvider) JWTBearerRefreshHookURL(ctx context.Context) *url.URL { - if len(p.getProvider(ctx).String(KeyJWTBearerHookURL)) == 0 { - return nil - } - return p.getProvider(ctx).RequestURIF(KeyJWTBearerHookURL, nil) } diff --git a/oauth2/hook.go b/oauth2/hook.go index 9cef02f430b..57cff80b4b9 100644 --- a/oauth2/hook.go +++ b/oauth2/hook.go @@ -126,8 +126,8 @@ func JWTBearerHook(reg interface { } } -func callHook(ctx context.Context, reg x.HTTPClientProvider, requester fosite.AccessRequester, hookType string, hookURL *url.URL) error { - if !requester.GetGrantTypes().ExactOne(hookType) { +func callHook(ctx context.Context, reg x.HTTPClientProvider, requester fosite.AccessRequester, grantType string, hookURL *url.URL) error { + if !requester.GetGrantTypes().ExactOne(grantType) { return nil } @@ -138,7 +138,7 @@ func callHook(ctx context.Context, reg x.HTTPClientProvider, requester fosite.Ac payload := map[string][]string{} - if requester.GetGrantTypes().ExactOne("urn:ietf:params:oauth:grant-type:jwt-bearer") || requester.GetGrantTypes().ExactOne("client_credentials") { + if grantType == "urn:ietf:params:oauth:grant-type:jwt-bearer" || grantType == "client_credentials" { payload = requester.GetRequestForm() } @@ -163,8 +163,8 @@ func callHook(ctx context.Context, reg x.HTTPClientProvider, requester fosite.Ac return errorsx.WithStack( fosite.ErrServerError. WithWrap(err). - WithDescription(fmt.Sprintf("An error occurred while encoding the %s hook.", hookType)). - WithDebugf("Unable to encode the %v hook body: %s", hookType, err), + WithDescription(fmt.Sprintf("An error occurred while encoding the %s hook.", grantType)). + WithDebugf("Unable to encode the %v hook body: %s", grantType, err), ) } @@ -173,7 +173,7 @@ func callHook(ctx context.Context, reg x.HTTPClientProvider, requester fosite.Ac return errorsx.WithStack( fosite.ErrServerError. WithWrap(err). - WithDescription(fmt.Sprintf("An error occurred while preparing the %s hook.", hookType)). + WithDescription(fmt.Sprintf("An error occurred while preparing the %s hook.", grantType)). WithDebugf("Unable to prepare the HTTP Request: %s", err), ) } @@ -184,7 +184,7 @@ func callHook(ctx context.Context, reg x.HTTPClientProvider, requester fosite.Ac return errorsx.WithStack( fosite.ErrServerError. WithWrap(err). - WithDescription(fmt.Sprintf("An error occurred while executing the %s hook.", hookType)). + WithDescription(fmt.Sprintf("An error occurred while executing the %s hook.", grantType)). WithDebugf("Unable to execute HTTP Request: %s", err), ) } @@ -199,14 +199,14 @@ func callHook(ctx context.Context, reg x.HTTPClientProvider, requester fosite.Ac case http.StatusForbidden: return errorsx.WithStack( fosite.ErrAccessDenied. - WithDescription(fmt.Sprintf("The %s hook target responded with an error.", hookType)). - WithDebugf(fmt.Sprintf("%s hook responded with HTTP status code: %s", hookType, resp.Status)), + WithDescription(fmt.Sprintf("The %s hook target responded with an error.", grantType)). + WithDebugf(fmt.Sprintf("%s hook responded with HTTP status code: %s", grantType, resp.Status)), ) default: return errorsx.WithStack( fosite.ErrServerError. - WithDescription(fmt.Sprintf("The %s hook target responded with an error.", hookType)). - WithDebugf(fmt.Sprintf("%s hook responded with HTTP status code: %s", hookType, resp.Status)), + WithDescription(fmt.Sprintf("The %s hook target responded with an error.", grantType)). + WithDebugf(fmt.Sprintf("%s hook responded with HTTP status code: %s", grantType, resp.Status)), ) } @@ -215,8 +215,8 @@ func callHook(ctx context.Context, reg x.HTTPClientProvider, requester fosite.Ac return errorsx.WithStack( fosite.ErrServerError. WithWrap(err). - WithDescription(fmt.Sprintf("The %s hook target responded with an error.", hookType)). - WithDebugf(fmt.Sprintf("Response from %s hook could not be decoded: %s", hookType, err)), + WithDescription(fmt.Sprintf("The %s hook target responded with an error.", grantType)). + WithDebugf(fmt.Sprintf("Response from %s hook could not be decoded: %s", grantType, err)), ) } From 272b1a756f84f566b9fb7fd08adc2c60d55a4cb9 Mon Sep 17 00:00:00 2001 From: sgal Date: Thu, 23 Mar 2023 21:28:18 +0100 Subject: [PATCH 7/9] fix: strip sensitive values from the hook payload --- ...call_refresh_token_hook_if_configured.json | 7 -- ...call_refresh_token_hook_if_configured.json | 7 -- ...call_refresh_token_hook_if_configured.json | 7 -- ...call_refresh_token_hook_if_configured.json | 7 -- ...call_refresh_token_hook_if_configured.json | 7 -- ...call_refresh_token_hook_if_configured.json | 7 -- oauth2/hook.go | 24 +++--- oauth2/oauth2_auth_code_test.go | 2 - oauth2/oauth2_client_credentials_test.go | 4 +- oauth2/oauth2_jwt_bearer_test.go | 81 ++++++++++++++++++- 10 files changed, 91 insertions(+), 62 deletions(-) diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json index d9b61fd7865..e50624522ac 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json @@ -32,13 +32,6 @@ "allowed_top_level_claims": [] }, "requester": { - "client_id": "app-client", - "granted_scopes": [ - "offline", - "openid", - "hydra.*" - ], - "granted_audience": [], "grant_types": [ "refresh_token" ], diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json index d9b61fd7865..e50624522ac 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json @@ -32,13 +32,6 @@ "allowed_top_level_claims": [] }, "requester": { - "client_id": "app-client", - "granted_scopes": [ - "offline", - "openid", - "hydra.*" - ], - "granted_audience": [], "grant_types": [ "refresh_token" ], diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json index d9b61fd7865..e50624522ac 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json @@ -32,13 +32,6 @@ "allowed_top_level_claims": [] }, "requester": { - "client_id": "app-client", - "granted_scopes": [ - "offline", - "openid", - "hydra.*" - ], - "granted_audience": [], "grant_types": [ "refresh_token" ], diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json index d9b61fd7865..e50624522ac 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json @@ -32,13 +32,6 @@ "allowed_top_level_claims": [] }, "requester": { - "client_id": "app-client", - "granted_scopes": [ - "offline", - "openid", - "hydra.*" - ], - "granted_audience": [], "grant_types": [ "refresh_token" ], diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json index d9b61fd7865..e50624522ac 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json @@ -32,13 +32,6 @@ "allowed_top_level_claims": [] }, "requester": { - "client_id": "app-client", - "granted_scopes": [ - "offline", - "openid", - "hydra.*" - ], - "granted_audience": [], "grant_types": [ "refresh_token" ], diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json index d9b61fd7865..e50624522ac 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json @@ -32,13 +32,6 @@ "allowed_top_level_claims": [] }, "requester": { - "client_id": "app-client", - "granted_scopes": [ - "offline", - "openid", - "hydra.*" - ], - "granted_audience": [], "grant_types": [ "refresh_token" ], diff --git a/oauth2/hook.go b/oauth2/hook.go index 57cff80b4b9..69e47e58941 100644 --- a/oauth2/hook.go +++ b/oauth2/hook.go @@ -28,12 +28,6 @@ type AccessRequestHook func(ctx context.Context, requester fosite.AccessRequeste // // swagger:ignore type Requester struct { - // ClientID is the identifier of the OAuth 2.0 client. - ClientID string `json:"client_id"` - // GrantedScopes is the list of scopes granted to the OAuth 2.0 client. - GrantedScopes []string `json:"granted_scopes"` - // GrantedAudience is the list of audiences granted to the OAuth 2.0 client. - GrantedAudience []string `json:"granted_audience"` // GrantTypes is the requests grant types. GrantTypes []string `json:"grant_types"` // Payload is the requests payload. @@ -126,6 +120,13 @@ func JWTBearerHook(reg interface { } } +func getSafePayload(requester fosite.AccessRequester) url.Values { + payload := requester.GetRequestForm() + payload.Del("client_secret") + + return payload +} + func callHook(ctx context.Context, reg x.HTTPClientProvider, requester fosite.AccessRequester, grantType string, hookURL *url.URL) error { if !requester.GetGrantTypes().ExactOne(grantType) { return nil @@ -138,16 +139,13 @@ func callHook(ctx context.Context, reg x.HTTPClientProvider, requester fosite.Ac payload := map[string][]string{} - if grantType == "urn:ietf:params:oauth:grant-type:jwt-bearer" || grantType == "client_credentials" { - payload = requester.GetRequestForm() + if grantType == "urn:ietf:params:oauth:grant-type:jwt-bearer" { + payload = getSafePayload(requester) } requesterInfo := Requester{ - ClientID: requester.GetClient().GetID(), - GrantedScopes: requester.GetGrantedScopes(), - GrantedAudience: requester.GetGrantedAudience(), - GrantTypes: requester.GetGrantTypes(), - Payload: payload, + GrantTypes: requester.GetGrantTypes(), + Payload: payload, } reqBody := TokenHookRequest{ diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index c748d10c299..85008bb0c5c 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -1227,8 +1227,6 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { require.Equal(t, hookReq.Session.ClientID, oauthConfig.ClientID) require.Equal(t, hookReq.Session.Extra, map[string]interface{}{}) require.NotEmpty(t, hookReq.Requester) - require.Equal(t, hookReq.Requester.ClientID, oauthConfig.ClientID) - require.ElementsMatch(t, hookReq.Requester.GrantedScopes, expectedGrantedScopes) require.Equal(t, hookReq.Requester.Payload, map[string][]string{}) except := []string{ diff --git a/oauth2/oauth2_client_credentials_test.go b/oauth2/oauth2_client_credentials_test.go index 90bdbce36e3..9c62306e15d 100644 --- a/oauth2/oauth2_client_credentials_test.go +++ b/oauth2/oauth2_client_credentials_test.go @@ -261,7 +261,6 @@ func TestClientCredentials(t *testing.T) { expectedGrantedScopes := []string{"foobar"} expectedGrantedAudience := []string{"https://api.ory.sh/"} - expectedPayload := map[string][]string(map[string][]string{"audience": audience, "grant_type": {"client_credentials"}, "scope": {scope}}) var hookReq hydraoauth2.TokenHookRequest require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) @@ -270,8 +269,7 @@ func TestClientCredentials(t *testing.T) { require.NotEmpty(t, hookReq.Session) require.Equal(t, hookReq.Session.Extra, map[string]interface{}{}) require.NotEmpty(t, hookReq.Requester) - require.ElementsMatch(t, hookReq.Requester.GrantedScopes, expectedGrantedScopes) - require.Equal(t, hookReq.Requester.Payload, expectedPayload) + require.Equal(t, hookReq.Requester.Payload, map[string][]string{}) claims := map[string]interface{}{ "hooked": true, diff --git a/oauth2/oauth2_jwt_bearer_test.go b/oauth2/oauth2_jwt_bearer_test.go index f3006b28ff0..79b2e682308 100644 --- a/oauth2/oauth2_jwt_bearer_test.go +++ b/oauth2/oauth2_jwt_bearer_test.go @@ -64,7 +64,9 @@ func TestJWTBearer(t *testing.T) { } var getToken = func(t *testing.T, conf *clientcredentials.Config) (*goauth2.Token, error) { - conf.AuthStyle = goauth2.AuthStyleInHeader + if conf.AuthStyle == goauth2.AuthStyleAutoDetect { + conf.AuthStyle = goauth2.AuthStyleInHeader + } return conf.Token(context.Background()) } @@ -333,7 +335,6 @@ func TestJWTBearer(t *testing.T) { require.NotEmpty(t, hookReq.Session) require.Equal(t, hookReq.Session.Extra, map[string]interface{}{}) require.NotEmpty(t, hookReq.Requester) - require.ElementsMatch(t, hookReq.Requester.GrantedScopes, expectedGrantedScopes) require.Equal(t, hookReq.Requester.Payload, expectedPayload) claims := map[string]interface{}{ @@ -371,6 +372,82 @@ func TestJWTBearer(t *testing.T) { t.Run("strategy=jwt", run("jwt")) }) + t.Run("should call token hook if configured and omit client_secret from payload", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + audience := reg.Config().OAuth2TokenURL(ctx).String() + grantType := "urn:ietf:params:oauth:grant-type:jwt-bearer" + + token, _, err := signer.Generate(ctx, jwt.MapClaims{ + "jti": uuid.NewString(), + "iss": trustGrant.Issuer, + "sub": trustGrant.Subject, + "aud": audience, + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Add(-time.Minute).Unix(), + }, &jwt.Headers{Extra: map[string]interface{}{"kid": kid}}) + require.NoError(t, err) + + client := &hc.Client{ + Secret: secret, + GrantTypes: []string{"urn:ietf:params:oauth:grant-type:jwt-bearer"}, + Scope: "offline_access", + TokenEndpointAuthMethod: "client_secret_post", + } + require.NoError(t, reg.ClientManager().CreateClient(ctx, client)) + + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8") + + expectedGrantedScopes := []string{client.Scope} + expectedGrantedAudience := []string{audience} + expectedPayload := map[string][]string(map[string][]string{"client_id": {client.GetID()}, "grant_type": {grantType}, "assertion": {token}, "scope": {client.Scope}}) + + var hookReq hydraoauth2.TokenHookRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) + require.ElementsMatch(t, hookReq.GrantedScopes, expectedGrantedScopes) + require.ElementsMatch(t, hookReq.GrantedAudience, expectedGrantedAudience) + require.NotEmpty(t, hookReq.Session) + require.Equal(t, hookReq.Session.Extra, map[string]interface{}{}) + require.NotEmpty(t, hookReq.Requester) + require.Equal(t, hookReq.Requester.Payload, expectedPayload) + + claims := map[string]interface{}{ + "hooked": true, + } + + hookResp := hydraoauth2.TokenHookResponse{ + Session: consent.AcceptOAuth2ConsentRequestSession{ + AccessToken: claims, + IDToken: claims, + }, + } + + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(&hookResp)) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, nil) + + conf := newConf(client) + conf.AuthStyle = goauth2.AuthStyleInParams + conf.EndpointParams = url.Values{"grant_type": {grantType}, "assertion": {token}} + + result, err := getToken(t, conf) + require.NoError(t, err) + + inspectToken(t, result, client, strategy, trustGrant, true) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + t.Run("should fail token if hook fails", func(t *testing.T) { run := func(strategy string) func(t *testing.T) { return func(t *testing.T) { From 8a37a90dda6b97c1a671e98d1de4cc146e86beba Mon Sep 17 00:00:00 2001 From: sgal Date: Fri, 24 Mar 2023 16:29:42 +0100 Subject: [PATCH 8/9] fix:make a generic token hook config that is run for all grant types --- driver/config/provider.go | 18 +- driver/registry_base.go | 4 +- ...token_hook_if_configured-hook=legacy.json} | 10 +- ...esh_token_hook_if_configured-hook=new.json | 48 ++ ...token_hook_if_configured-hook=legacy.json} | 10 +- ...esh_token_hook_if_configured-hook=new.json | 48 ++ ...token_hook_if_configured-hook=legacy.json} | 10 +- ...esh_token_hook_if_configured-hook=new.json | 48 ++ ...token_hook_if_configured-hook=legacy.json} | 10 +- ...esh_token_hook_if_configured-hook=new.json | 48 ++ ..._token_hook_if_configured-hook=legacy.json | 53 ++ ...esh_token_hook_if_configured-hook=new.json | 48 ++ ...call_refresh_token_hook_if_configured.json | 47 -- ..._token_hook_if_configured-hook=legacy.json | 53 ++ ...esh_token_hook_if_configured-hook=new.json | 48 ++ ...call_refresh_token_hook_if_configured.json | 47 -- oauth2/hook.go | 226 -------- oauth2/oauth2_auth_code_test.go | 545 +++++++----------- oauth2/oauth2_client_credentials_test.go | 24 +- oauth2/oauth2_jwt_bearer_test.go | 40 +- oauth2/refresh_hook.go | 102 ++++ oauth2/token_hook.go | 166 ++++++ spec/config.json | 18 +- 23 files changed, 946 insertions(+), 725 deletions(-) rename oauth2/.snapshots/{TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json => TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=legacy.json} (85%) create mode 100644 oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json rename oauth2/.snapshots/{TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json => TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=legacy.json} (85%) create mode 100644 oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json rename oauth2/.snapshots/{TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json => TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=legacy.json} (85%) create mode 100644 oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json rename oauth2/.snapshots/{TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json => TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=legacy.json} (85%) create mode 100644 oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json create mode 100644 oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=legacy.json create mode 100644 oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json delete mode 100644 oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json create mode 100644 oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=legacy.json create mode 100644 oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json delete mode 100644 oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json delete mode 100644 oauth2/hook.go create mode 100644 oauth2/refresh_hook.go create mode 100644 oauth2/token_hook.go diff --git a/driver/config/provider.go b/driver/config/provider.go index ee1acd5ae0f..31fe948cc9b 100644 --- a/driver/config/provider.go +++ b/driver/config/provider.go @@ -92,10 +92,8 @@ const ( KeyOAuth2GrantJWTIDOptional = "oauth2.grant.jwt.jti_optional" KeyOAuth2GrantJWTIssuedDateOptional = "oauth2.grant.jwt.iat_optional" KeyOAuth2GrantJWTMaxDuration = "oauth2.grant.jwt.max_ttl" - KeyRefreshTokenHookURL = "oauth2.refresh_token_hook" // #nosec G101 - KeyAuthorizationCodeHookURL = "oauth2.authorization_code_hook" // #nosec G101 - KeyClientCredentialsHookURL = "oauth2.client_credentials_hook" // #nosec G101 - KeyJWTBearerHookURL = "oauth2.jwt_bearer_hook" // #nosec G101 + KeyRefreshTokenHookURL = "oauth2.refresh_token_hook" // #nosec G101 + KeyTokenHookURL = "oauth2.token_hook" // #nosec G101 KeyDevelopmentMode = "dev" ) @@ -423,22 +421,14 @@ func (p *DefaultProvider) AccessTokenStrategy(ctx context.Context, additionalSou return s } -func (p *DefaultProvider) AuthorizationCodeHookURL(ctx context.Context) *url.URL { - return p.getProvider(ctx).RequestURIF(KeyAuthorizationCodeHookURL, nil) -} - -func (p *DefaultProvider) ClientCredentialsHookURL(ctx context.Context) *url.URL { - return p.getProvider(ctx).RequestURIF(KeyClientCredentialsHookURL, nil) +func (p *DefaultProvider) TokenHookURL(ctx context.Context) *url.URL { + return p.getProvider(ctx).RequestURIF(KeyTokenHookURL, nil) } func (p *DefaultProvider) TokenRefreshHookURL(ctx context.Context) *url.URL { return p.getProvider(ctx).RequestURIF(KeyRefreshTokenHookURL, nil) } -func (p *DefaultProvider) JWTBearerRefreshHookURL(ctx context.Context) *url.URL { - return p.getProvider(ctx).RequestURIF(KeyJWTBearerHookURL, nil) -} - func (p *DefaultProvider) DbIgnoreUnknownTableColumns() bool { return p.p.Bool(KeyDBIgnoreUnknownTableColumns) } diff --git a/driver/registry_base.go b/driver/registry_base.go index b78ce21759c..d2d458427a8 100644 --- a/driver/registry_base.go +++ b/driver/registry_base.go @@ -506,9 +506,7 @@ func (m *RegistryBase) AccessRequestHooks() []oauth2.AccessRequestHook { if m.arhs == nil { m.arhs = []oauth2.AccessRequestHook{ oauth2.RefreshTokenHook(m), - oauth2.AuthorizationCodeHook(m), - oauth2.ClientCredentialsHook(m), - oauth2.JWTBearerHook(m), + oauth2.TokenHook(m), } } return m.arhs diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=legacy.json similarity index 85% rename from oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json rename to oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=legacy.json index e50624522ac..66fbfb5af98 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=legacy.json @@ -32,10 +32,16 @@ "allowed_top_level_claims": [] }, "requester": { + "client_id": "app-client", + "granted_scopes": [ + "offline", + "openid", + "hydra.*" + ], + "granted_audience": [], "grant_types": [ "refresh_token" - ], - "payload": {} + ] }, "client_id": "app-client", "granted_scopes": [ diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json new file mode 100644 index 00000000000..28176f90826 --- /dev/null +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json @@ -0,0 +1,48 @@ +{ + "session": { + "id_token": { + "id_token_claims": { + "jti": "", + "iss": "http://localhost:4444/", + "sub": "foo", + "aud": [ + "app-client" + ], + "nonce": "", + "at_hash": "", + "acr": "1", + "amr": null, + "c_hash": "", + "ext": { + "hooked": "legacy" + } + }, + "headers": { + "extra": { + } + }, + "username": "", + "subject": "foo" + }, + "extra": { + "hooked": "legacy" + }, + "client_id": "app-client", + "consent_challenge": "", + "exclude_not_before_claim": false, + "allowed_top_level_claims": [] + }, + "request": { + "client_id": "app-client", + "granted_scopes": [ + "offline", + "openid", + "hydra.*" + ], + "granted_audience": [], + "grant_types": [ + "refresh_token" + ], + "payload": {} + } +} diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=legacy.json similarity index 85% rename from oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json rename to oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=legacy.json index e50624522ac..66fbfb5af98 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=legacy.json @@ -32,10 +32,16 @@ "allowed_top_level_claims": [] }, "requester": { + "client_id": "app-client", + "granted_scopes": [ + "offline", + "openid", + "hydra.*" + ], + "granted_audience": [], "grant_types": [ "refresh_token" - ], - "payload": {} + ] }, "client_id": "app-client", "granted_scopes": [ diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json new file mode 100644 index 00000000000..28176f90826 --- /dev/null +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json @@ -0,0 +1,48 @@ +{ + "session": { + "id_token": { + "id_token_claims": { + "jti": "", + "iss": "http://localhost:4444/", + "sub": "foo", + "aud": [ + "app-client" + ], + "nonce": "", + "at_hash": "", + "acr": "1", + "amr": null, + "c_hash": "", + "ext": { + "hooked": "legacy" + } + }, + "headers": { + "extra": { + } + }, + "username": "", + "subject": "foo" + }, + "extra": { + "hooked": "legacy" + }, + "client_id": "app-client", + "consent_challenge": "", + "exclude_not_before_claim": false, + "allowed_top_level_claims": [] + }, + "request": { + "client_id": "app-client", + "granted_scopes": [ + "offline", + "openid", + "hydra.*" + ], + "granted_audience": [], + "grant_types": [ + "refresh_token" + ], + "payload": {} + } +} diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=legacy.json similarity index 85% rename from oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json rename to oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=legacy.json index e50624522ac..66fbfb5af98 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=legacy.json @@ -32,10 +32,16 @@ "allowed_top_level_claims": [] }, "requester": { + "client_id": "app-client", + "granted_scopes": [ + "offline", + "openid", + "hydra.*" + ], + "granted_audience": [], "grant_types": [ "refresh_token" - ], - "payload": {} + ] }, "client_id": "app-client", "granted_scopes": [ diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json new file mode 100644 index 00000000000..28176f90826 --- /dev/null +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=jwt-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json @@ -0,0 +1,48 @@ +{ + "session": { + "id_token": { + "id_token_claims": { + "jti": "", + "iss": "http://localhost:4444/", + "sub": "foo", + "aud": [ + "app-client" + ], + "nonce": "", + "at_hash": "", + "acr": "1", + "amr": null, + "c_hash": "", + "ext": { + "hooked": "legacy" + } + }, + "headers": { + "extra": { + } + }, + "username": "", + "subject": "foo" + }, + "extra": { + "hooked": "legacy" + }, + "client_id": "app-client", + "consent_challenge": "", + "exclude_not_before_claim": false, + "allowed_top_level_claims": [] + }, + "request": { + "client_id": "app-client", + "granted_scopes": [ + "offline", + "openid", + "hydra.*" + ], + "granted_audience": [], + "grant_types": [ + "refresh_token" + ], + "payload": {} + } +} diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=legacy.json similarity index 85% rename from oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json rename to oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=legacy.json index e50624522ac..66fbfb5af98 100644 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured.json +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=legacy.json @@ -32,10 +32,16 @@ "allowed_top_level_claims": [] }, "requester": { + "client_id": "app-client", + "granted_scopes": [ + "offline", + "openid", + "hydra.*" + ], + "granted_audience": [], "grant_types": [ "refresh_token" - ], - "payload": {} + ] }, "client_id": "app-client", "granted_scopes": [ diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json new file mode 100644 index 00000000000..28176f90826 --- /dev/null +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=0-description=should_pass_request_if_strategy_passes-should_call_refresh_token_hook_if_configured-hook=new.json @@ -0,0 +1,48 @@ +{ + "session": { + "id_token": { + "id_token_claims": { + "jti": "", + "iss": "http://localhost:4444/", + "sub": "foo", + "aud": [ + "app-client" + ], + "nonce": "", + "at_hash": "", + "acr": "1", + "amr": null, + "c_hash": "", + "ext": { + "hooked": "legacy" + } + }, + "headers": { + "extra": { + } + }, + "username": "", + "subject": "foo" + }, + "extra": { + "hooked": "legacy" + }, + "client_id": "app-client", + "consent_challenge": "", + "exclude_not_before_claim": false, + "allowed_top_level_claims": [] + }, + "request": { + "client_id": "app-client", + "granted_scopes": [ + "offline", + "openid", + "hydra.*" + ], + "granted_audience": [], + "grant_types": [ + "refresh_token" + ], + "payload": {} + } +} diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=legacy.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=legacy.json new file mode 100644 index 00000000000..66fbfb5af98 --- /dev/null +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=legacy.json @@ -0,0 +1,53 @@ +{ + "subject": "foo", + "session": { + "id_token": { + "id_token_claims": { + "jti": "", + "iss": "http://localhost:4444/", + "sub": "foo", + "aud": [ + "app-client" + ], + "nonce": "", + "at_hash": "", + "acr": "1", + "amr": null, + "c_hash": "", + "ext": { + "sid": "" + } + }, + "headers": { + "extra": { + } + }, + "username": "", + "subject": "foo" + }, + "extra": {}, + "client_id": "app-client", + "consent_challenge": "", + "exclude_not_before_claim": false, + "allowed_top_level_claims": [] + }, + "requester": { + "client_id": "app-client", + "granted_scopes": [ + "offline", + "openid", + "hydra.*" + ], + "granted_audience": [], + "grant_types": [ + "refresh_token" + ] + }, + "client_id": "app-client", + "granted_scopes": [ + "offline", + "openid", + "hydra.*" + ], + "granted_audience": [] +} diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json new file mode 100644 index 00000000000..28176f90826 --- /dev/null +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured-hook=new.json @@ -0,0 +1,48 @@ +{ + "session": { + "id_token": { + "id_token_claims": { + "jti": "", + "iss": "http://localhost:4444/", + "sub": "foo", + "aud": [ + "app-client" + ], + "nonce": "", + "at_hash": "", + "acr": "1", + "amr": null, + "c_hash": "", + "ext": { + "hooked": "legacy" + } + }, + "headers": { + "extra": { + } + }, + "username": "", + "subject": "foo" + }, + "extra": { + "hooked": "legacy" + }, + "client_id": "app-client", + "consent_challenge": "", + "exclude_not_before_claim": false, + "allowed_top_level_claims": [] + }, + "request": { + "client_id": "app-client", + "granted_scopes": [ + "offline", + "openid", + "hydra.*" + ], + "granted_audience": [], + "grant_types": [ + "refresh_token" + ], + "payload": {} + } +} diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json deleted file mode 100644 index e50624522ac..00000000000 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=2-description=should_pass_because_prompt=none_and_max_age_is_less_than_auth_time-should_call_refresh_token_hook_if_configured.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "subject": "foo", - "session": { - "id_token": { - "id_token_claims": { - "jti": "", - "iss": "http://localhost:4444/", - "sub": "foo", - "aud": [ - "app-client" - ], - "nonce": "", - "at_hash": "", - "acr": "1", - "amr": null, - "c_hash": "", - "ext": { - "sid": "" - } - }, - "headers": { - "extra": { - } - }, - "username": "", - "subject": "foo" - }, - "extra": {}, - "client_id": "app-client", - "consent_challenge": "", - "exclude_not_before_claim": false, - "allowed_top_level_claims": [] - }, - "requester": { - "grant_types": [ - "refresh_token" - ], - "payload": {} - }, - "client_id": "app-client", - "granted_scopes": [ - "offline", - "openid", - "hydra.*" - ], - "granted_audience": [] -} diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=legacy.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=legacy.json new file mode 100644 index 00000000000..66fbfb5af98 --- /dev/null +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=legacy.json @@ -0,0 +1,53 @@ +{ + "subject": "foo", + "session": { + "id_token": { + "id_token_claims": { + "jti": "", + "iss": "http://localhost:4444/", + "sub": "foo", + "aud": [ + "app-client" + ], + "nonce": "", + "at_hash": "", + "acr": "1", + "amr": null, + "c_hash": "", + "ext": { + "sid": "" + } + }, + "headers": { + "extra": { + } + }, + "username": "", + "subject": "foo" + }, + "extra": {}, + "client_id": "app-client", + "consent_challenge": "", + "exclude_not_before_claim": false, + "allowed_top_level_claims": [] + }, + "requester": { + "client_id": "app-client", + "granted_scopes": [ + "offline", + "openid", + "hydra.*" + ], + "granted_audience": [], + "grant_types": [ + "refresh_token" + ] + }, + "client_id": "app-client", + "granted_scopes": [ + "offline", + "openid", + "hydra.*" + ], + "granted_audience": [] +} diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json new file mode 100644 index 00000000000..28176f90826 --- /dev/null +++ b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured-hook=new.json @@ -0,0 +1,48 @@ +{ + "session": { + "id_token": { + "id_token_claims": { + "jti": "", + "iss": "http://localhost:4444/", + "sub": "foo", + "aud": [ + "app-client" + ], + "nonce": "", + "at_hash": "", + "acr": "1", + "amr": null, + "c_hash": "", + "ext": { + "hooked": "legacy" + } + }, + "headers": { + "extra": { + } + }, + "username": "", + "subject": "foo" + }, + "extra": { + "hooked": "legacy" + }, + "client_id": "app-client", + "consent_challenge": "", + "exclude_not_before_claim": false, + "allowed_top_level_claims": [] + }, + "request": { + "client_id": "app-client", + "granted_scopes": [ + "offline", + "openid", + "hydra.*" + ], + "granted_audience": [], + "grant_types": [ + "refresh_token" + ], + "payload": {} + } +} diff --git a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json b/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json deleted file mode 100644 index e50624522ac..00000000000 --- a/oauth2/.snapshots/TestAuthCodeWithMockStrategy-strategy=opaque-case=5-description=should_pass_with_prompt=login_when_authentication_time_is_recent-should_call_refresh_token_hook_if_configured.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "subject": "foo", - "session": { - "id_token": { - "id_token_claims": { - "jti": "", - "iss": "http://localhost:4444/", - "sub": "foo", - "aud": [ - "app-client" - ], - "nonce": "", - "at_hash": "", - "acr": "1", - "amr": null, - "c_hash": "", - "ext": { - "sid": "" - } - }, - "headers": { - "extra": { - } - }, - "username": "", - "subject": "foo" - }, - "extra": {}, - "client_id": "app-client", - "consent_challenge": "", - "exclude_not_before_claim": false, - "allowed_top_level_claims": [] - }, - "requester": { - "grant_types": [ - "refresh_token" - ], - "payload": {} - }, - "client_id": "app-client", - "granted_scopes": [ - "offline", - "openid", - "hydra.*" - ], - "granted_audience": [] -} diff --git a/oauth2/hook.go b/oauth2/hook.go deleted file mode 100644 index 69e47e58941..00000000000 --- a/oauth2/hook.go +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright © 2022 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -package oauth2 - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "net/url" - - "github.com/hashicorp/go-retryablehttp" - - "github.com/ory/hydra/v2/x" - - "github.com/ory/fosite" - "github.com/ory/hydra/v2/consent" - "github.com/ory/hydra/v2/driver/config" - "github.com/ory/x/errorsx" -) - -// AccessRequestHook is called when an access token is being refreshed. -type AccessRequestHook func(ctx context.Context, requester fosite.AccessRequester) error - -// Requester is a token endpoint's request context. -// -// swagger:ignore -type Requester struct { - // GrantTypes is the requests grant types. - GrantTypes []string `json:"grant_types"` - // Payload is the requests payload. - Payload map[string][]string `json:"payload"` -} - -// TokenHookRequest is the request body sent to token hooks. -// -// swagger:ignore -type TokenHookRequest struct { - // Subject is the identifier of the authenticated end-user. - Subject string `json:"subject"` - // Session is the request's session.. - Session *Session `json:"session"` - // Requester is a token endpoint's request context. - Requester Requester `json:"requester"` - // ClientID is the identifier of the OAuth 2.0 client. - ClientID string `json:"client_id"` - // GrantedScopes is the list of scopes granted to the OAuth 2.0 client. - GrantedScopes []string `json:"granted_scopes"` - // GrantedAudience is the list of audiences granted to the OAuth 2.0 client. - GrantedAudience []string `json:"granted_audience"` -} - -// TokenHookResponse is the response body received from token hooks. -// -// swagger:ignore -type TokenHookResponse struct { - // Session is the session data returned by the hook. - Session consent.AcceptOAuth2ConsentRequestSession `json:"session"` -} - -// RefreshTokenHook is an AccessRequestHook called for `refresh_token` grant type. -func RefreshTokenHook(reg interface { - config.Provider - x.HTTPClientProvider -}, -) AccessRequestHook { - return func(ctx context.Context, requester fosite.AccessRequester) error { - hookURL := reg.Config().TokenRefreshHookURL(ctx) - if hookURL == nil { - return nil - } - return callHook(ctx, reg, requester, "refresh_token", hookURL) - } -} - -// AuthorizationCodeHook is an AccessRequestHook called for `authorization_code` grant type. -func AuthorizationCodeHook(reg interface { - config.Provider - x.HTTPClientProvider -}, -) AccessRequestHook { - return func(ctx context.Context, requester fosite.AccessRequester) error { - hookURL := reg.Config().AuthorizationCodeHookURL(ctx) - if hookURL == nil { - return nil - } - return callHook(ctx, reg, requester, "authorization_code", hookURL) - } -} - -// ClientCredentialsHook is an AccessRequestHook called for `client_credentials` grant type. -func ClientCredentialsHook(reg interface { - config.Provider - x.HTTPClientProvider -}, -) AccessRequestHook { - return func(ctx context.Context, requester fosite.AccessRequester) error { - hookURL := reg.Config().ClientCredentialsHookURL(ctx) - if hookURL == nil { - return nil - } - return callHook(ctx, reg, requester, "client_credentials", hookURL) - } -} - -// JWTBearerHook is an AccessRequestHook called for `urn:ietf:params:oauth:grant-type:jwt-bearer` grant type. -func JWTBearerHook(reg interface { - config.Provider - x.HTTPClientProvider -}, -) AccessRequestHook { - return func(ctx context.Context, requester fosite.AccessRequester) error { - hookURL := reg.Config().JWTBearerRefreshHookURL(ctx) - if hookURL == nil { - return nil - } - return callHook(ctx, reg, requester, "urn:ietf:params:oauth:grant-type:jwt-bearer", hookURL) - } -} - -func getSafePayload(requester fosite.AccessRequester) url.Values { - payload := requester.GetRequestForm() - payload.Del("client_secret") - - return payload -} - -func callHook(ctx context.Context, reg x.HTTPClientProvider, requester fosite.AccessRequester, grantType string, hookURL *url.URL) error { - if !requester.GetGrantTypes().ExactOne(grantType) { - return nil - } - - session, ok := requester.GetSession().(*Session) - if !ok { - return nil - } - - payload := map[string][]string{} - - if grantType == "urn:ietf:params:oauth:grant-type:jwt-bearer" { - payload = getSafePayload(requester) - } - - requesterInfo := Requester{ - GrantTypes: requester.GetGrantTypes(), - Payload: payload, - } - - reqBody := TokenHookRequest{ - Session: session, - Requester: requesterInfo, - Subject: session.GetSubject(), - ClientID: requester.GetClient().GetID(), - GrantedScopes: requester.GetGrantedScopes(), - GrantedAudience: requester.GetGrantedAudience(), - } - reqBodyBytes, err := json.Marshal(&reqBody) - if err != nil { - return errorsx.WithStack( - fosite.ErrServerError. - WithWrap(err). - WithDescription(fmt.Sprintf("An error occurred while encoding the %s hook.", grantType)). - WithDebugf("Unable to encode the %v hook body: %s", grantType, err), - ) - } - - req, err := retryablehttp.NewRequestWithContext(ctx, http.MethodPost, hookURL.String(), bytes.NewReader(reqBodyBytes)) - if err != nil { - return errorsx.WithStack( - fosite.ErrServerError. - WithWrap(err). - WithDescription(fmt.Sprintf("An error occurred while preparing the %s hook.", grantType)). - WithDebugf("Unable to prepare the HTTP Request: %s", err), - ) - } - req.Header.Set("Content-Type", "application/json; charset=UTF-8") - - resp, err := reg.HTTPClient(ctx).Do(req) - if err != nil { - return errorsx.WithStack( - fosite.ErrServerError. - WithWrap(err). - WithDescription(fmt.Sprintf("An error occurred while executing the %s hook.", grantType)). - WithDebugf("Unable to execute HTTP Request: %s", err), - ) - } - defer resp.Body.Close() - - switch resp.StatusCode { - case http.StatusOK: - // Token permitted with new session data - case http.StatusNoContent: - // Token is permitted without overriding session data - return nil - case http.StatusForbidden: - return errorsx.WithStack( - fosite.ErrAccessDenied. - WithDescription(fmt.Sprintf("The %s hook target responded with an error.", grantType)). - WithDebugf(fmt.Sprintf("%s hook responded with HTTP status code: %s", grantType, resp.Status)), - ) - default: - return errorsx.WithStack( - fosite.ErrServerError. - WithDescription(fmt.Sprintf("The %s hook target responded with an error.", grantType)). - WithDebugf(fmt.Sprintf("%s hook responded with HTTP status code: %s", grantType, resp.Status)), - ) - } - - var respBody TokenHookResponse - if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - return errorsx.WithStack( - fosite.ErrServerError. - WithWrap(err). - WithDescription(fmt.Sprintf("The %s hook target responded with an error.", grantType)). - WithDebugf(fmt.Sprintf("Response from %s hook could not be decoded: %s", grantType, err)), - ) - } - - // Overwrite existing session data (extra claims). - session.Extra = respBody.Session.AccessToken - idTokenClaims := session.IDTokenClaims() - idTokenClaims.Extra = respBody.Session.IDToken - return nil -} diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index 85008bb0c5c..8c6312f428a 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -235,7 +235,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { assert.True(t, i.Get("active").Bool(), "%s", i) assert.EqualValues(t, conf.ClientID, i.Get("client_id").String(), "%s", i) assert.EqualValues(t, expectedSubject, i.Get("sub").String(), "%s", i) - assert.EqualValues(t, `bar`, i.Get("ext.foo").String(), "%s", i) + assert.EqualValues(t, `{"foo":"bar"}`, i.Get("ext").Raw, "%s", i) return i } @@ -260,7 +260,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { assert.True(t, time.Now().After(time.Unix(i.Get("nbf").Int(), 0)), "%s", i) assert.True(t, time.Now().Before(time.Unix(i.Get("exp").Int(), 0)), "%s", i) requirex.EqualTime(t, expectedExp, time.Unix(i.Get("exp").Int(), 0), time.Second) - assert.EqualValues(t, `bar`, i.Get("ext.foo").String(), "%s", i) + assert.EqualValues(t, `{"foo":"bar"}`, i.Get("ext").Raw, "%s", i) assert.EqualValues(t, `["hydra","offline","openid"]`, i.Get("scp").Raw, "%s", i) return i } @@ -682,197 +682,6 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { assert.Empty(t, uiClaims.Get(f).Raw, "%s: %s", f, uiClaims) } }) - - t.Run("case=add ext claims from hook if configured", func(t *testing.T) { - run := func(strategy string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8") - - var hookReq hydraoauth2.TokenHookRequest - require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) - require.ElementsMatch(t, hookReq.GrantedAudience, []string{}) - require.NotEmpty(t, hookReq.Session) - require.Equal(t, hookReq.Session.Extra, map[string]interface{}{"foo": "bar"}) - require.NotEmpty(t, hookReq.Requester) - require.Equal(t, hookReq.Requester.Payload, map[string][]string{}) - - claims := map[string]interface{}{ - "hooked": true, - } - - hookResp := hydraoauth2.TokenHookResponse{ - Session: consent.AcceptOAuth2ConsentRequestSession{ - AccessToken: claims, - IDToken: claims, - }, - } - - w.WriteHeader(http.StatusOK) - require.NoError(t, json.NewEncoder(w).Encode(&hookResp)) - })) - defer hs.Close() - - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyAuthorizationCodeHookURL, hs.URL) - - defer reg.Config().MustSet(ctx, config.KeyAuthorizationCodeHookURL, nil) - - expectAud := "https://api.ory.sh/" - c, conf := newOAuth2Client(t, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - assert.False(t, r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - return nil - }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { - assert.False(t, *r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - })) - - code, _ := getAuthorizeCode(t, conf, nil, - oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), - oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) - - token, err := conf.Exchange(context.Background(), code) - require.NoError(t, err) - - assertJWTAccessToken(t, strategy, conf, token, subject, time.Now().Add(reg.Config().GetAccessTokenLifespan(ctx))) - - // NOTE: using introspect to cover both jwt and opaque strategies - accessTokenClaims := introspectAccessToken(t, conf, token, subject) - require.True(t, accessTokenClaims.Get("ext.hooked").Bool()) - - idTokenClaims := assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) - require.True(t, idTokenClaims.Get("hooked").Bool()) - } - } - - t.Run("strategy=opaque", run("opaque")) - t.Run("strategy=jwt", run("jwt")) - }) - - t.Run("case=fail token exchange if hook fails", func(t *testing.T) { - run := func(strategy string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - })) - defer hs.Close() - - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyAuthorizationCodeHookURL, hs.URL) - - defer reg.Config().MustSet(ctx, config.KeyAuthorizationCodeHookURL, nil) - - expectAud := "https://api.ory.sh/" - c, conf := newOAuth2Client(t, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - assert.False(t, r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - return nil - }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { - assert.False(t, *r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - })) - - code, _ := getAuthorizeCode(t, conf, nil, - oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), - oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) - - _, err := conf.Exchange(context.Background(), code) - require.Error(t, err) - } - } - - t.Run("strategy=opaque", run("opaque")) - t.Run("strategy=jwt", run("jwt")) - }) - - t.Run("case=fail token exchange if hook denies the request", func(t *testing.T) { - run := func(strategy string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusForbidden) - })) - defer hs.Close() - - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyAuthorizationCodeHookURL, hs.URL) - - defer reg.Config().MustSet(ctx, config.KeyAuthorizationCodeHookURL, nil) - - expectAud := "https://api.ory.sh/" - c, conf := newOAuth2Client(t, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - assert.False(t, r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - return nil - }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { - assert.False(t, *r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - })) - - code, _ := getAuthorizeCode(t, conf, nil, - oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), - oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) - - _, err := conf.Exchange(context.Background(), code) - require.Error(t, err) - } - } - - t.Run("strategy=opaque", run("opaque")) - t.Run("strategy=jwt", run("jwt")) - }) - - t.Run("case=fail token exchange if hook response is malformed", func(t *testing.T) { - run := func(strategy string) func(t *testing.T) { - return func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - defer hs.Close() - - reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyAuthorizationCodeHookURL, hs.URL) - - defer reg.Config().MustSet(ctx, config.KeyAuthorizationCodeHookURL, nil) - - expectAud := "https://api.ory.sh/" - c, conf := newOAuth2Client(t, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) - testhelpers.NewLoginConsentUI(t, reg.Config(), - acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { - assert.False(t, r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - return nil - }), - acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { - assert.False(t, *r.Skip) - assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) - })) - - code, _ := getAuthorizeCode(t, conf, nil, - oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), - oauth2.SetAuthURLParam("nonce", nonce)) - require.NotEmpty(t, code) - - _, err := conf.Exchange(context.Background(), code) - require.Error(t, err) - } - } - - t.Run("strategy=opaque", run("opaque")) - t.Run("strategy=jwt", run("jwt")) - }) } // TestAuthCodeWithMockStrategy runs the authorization_code flow against various ConsentStrategy scenarios. @@ -1210,156 +1019,226 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { }) t.Run("should call refresh token hook if configured", func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8") - - expectedGrantedScopes := []string{"openid", "offline", "hydra.*"} - expectedSubject := "foo" - - var hookReq hydraoauth2.TokenHookRequest - require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) - require.Equal(t, hookReq.Subject, expectedSubject) - require.ElementsMatch(t, hookReq.GrantedScopes, expectedGrantedScopes) - require.ElementsMatch(t, hookReq.GrantedAudience, []string{}) - require.Equal(t, hookReq.ClientID, oauthConfig.ClientID) - require.NotEmpty(t, hookReq.Session) - require.Equal(t, hookReq.Session.Subject, expectedSubject) - require.Equal(t, hookReq.Session.ClientID, oauthConfig.ClientID) - require.Equal(t, hookReq.Session.Extra, map[string]interface{}{}) - require.NotEmpty(t, hookReq.Requester) - require.Equal(t, hookReq.Requester.Payload, map[string][]string{}) - - except := []string{ - "session.kid", - "session.id_token.expires_at", - "session.id_token.headers.extra.kid", - "session.id_token.id_token_claims.iat", - "session.id_token.id_token_claims.exp", - "session.id_token.id_token_claims.rat", - "session.id_token.id_token_claims.auth_time", - } - snapshotx.SnapshotTExcept(t, hookReq, except) - - claims := map[string]interface{}{ - "hooked": true, - } - - hookResp := hydraoauth2.TokenHookResponse{ - Session: consent.AcceptOAuth2ConsentRequestSession{ - AccessToken: claims, - IDToken: claims, - }, + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8") + + expectedGrantedScopes := []string{"openid", "offline", "hydra.*"} + expectedSubject := "foo" + + exceptKeys := []string{ + "session.kid", + "session.id_token.expires_at", + "session.id_token.headers.extra.kid", + "session.id_token.id_token_claims.iat", + "session.id_token.id_token_claims.exp", + "session.id_token.id_token_claims.rat", + "session.id_token.id_token_claims.auth_time", + } + + if hookType == "legacy" { + var hookReq hydraoauth2.RefreshTokenHookRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) + require.Equal(t, hookReq.Subject, expectedSubject) + require.ElementsMatch(t, hookReq.GrantedScopes, expectedGrantedScopes) + require.ElementsMatch(t, hookReq.GrantedAudience, []string{}) + require.Equal(t, hookReq.ClientID, oauthConfig.ClientID) + require.NotEmpty(t, hookReq.Session) + require.Equal(t, hookReq.Session.Subject, expectedSubject) + require.Equal(t, hookReq.Session.ClientID, oauthConfig.ClientID) + require.NotEmpty(t, hookReq.Requester) + require.Equal(t, hookReq.Requester.ClientID, oauthConfig.ClientID) + require.ElementsMatch(t, hookReq.Requester.GrantedScopes, expectedGrantedScopes) + + snapshotx.SnapshotTExcept(t, hookReq, exceptKeys) + } else { + var hookReq hydraoauth2.TokenHookRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) + require.NotEmpty(t, hookReq.Session) + require.Equal(t, hookReq.Session.Subject, expectedSubject) + require.Equal(t, hookReq.Session.ClientID, oauthConfig.ClientID) + require.NotEmpty(t, hookReq.Request) + require.Equal(t, hookReq.Request.ClientID, oauthConfig.ClientID) + require.ElementsMatch(t, hookReq.Request.GrantedScopes, expectedGrantedScopes) + require.ElementsMatch(t, hookReq.Request.GrantedAudience, []string{}) + + snapshotx.SnapshotTExcept(t, hookReq, exceptKeys) + } + + claims := map[string]interface{}{ + "hooked": hookType, + } + + hookResp := hydraoauth2.TokenHookResponse{ + Session: consent.AcceptOAuth2ConsentRequestSession{ + AccessToken: claims, + IDToken: claims, + }, + } + + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(&hookResp)) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL) + defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil) + } else { + conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL) + defer conf.MustSet(ctx, config.KeyTokenHookURL, nil) + } + + res, err := testRefresh(t, &refreshedToken, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(body, &refreshedToken)) + + accessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) + require.Equal(t, accessTokenClaims.Get("ext.hooked").String(), hookType) + + idTokenBody, err := x.DecodeSegment( + strings.Split( + gjson.GetBytes(body, "id_token").String(), + ".", + )[1], + ) + require.NoError(t, err) + + require.Equal(t, gjson.GetBytes(idTokenBody, "hooked").String(), hookType) } - - w.WriteHeader(http.StatusOK) - require.NoError(t, json.NewEncoder(w).Encode(&hookResp)) - })) - defer hs.Close() - - conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil) - - res, err := testRefresh(t, &refreshedToken, ts.URL, false) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) - - body, err := io.ReadAll(res.Body) - require.NoError(t, err) - require.NoError(t, json.Unmarshal(body, &refreshedToken)) - - accessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) - require.True(t, accessTokenClaims.Get("ext.hooked").Bool()) - - idTokenBody, err := x.DecodeSegment( - strings.Split( - gjson.GetBytes(body, "id_token").String(), - ".", - )[1], - ) - require.NoError(t, err) - - require.True(t, gjson.GetBytes(idTokenBody, "hooked").Bool()) + } + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) }) t.Run("should not override session data if token refresh hook returns no content", func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNoContent) - })) - defer hs.Close() - - conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil) - - origAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) - - res, err := testRefresh(t, &refreshedToken, ts.URL, false) - require.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) - - body, err = io.ReadAll(res.Body) - require.NoError(t, err) - - require.NoError(t, json.Unmarshal(body, &refreshedToken)) - - refreshedAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) - assertx.EqualAsJSONExcept(t, json.RawMessage(origAccessTokenClaims.Raw), json.RawMessage(refreshedAccessTokenClaims.Raw), []string{"exp", "iat", "nbf"}) + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL) + defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil) + } else { + conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL) + defer conf.MustSet(ctx, config.KeyTokenHookURL, nil) + } + + origAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) + + res, err := testRefresh(t, &refreshedToken, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) + + body, err = io.ReadAll(res.Body) + require.NoError(t, err) + + require.NoError(t, json.Unmarshal(body, &refreshedToken)) + + refreshedAccessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, refreshedToken.AccessToken, ts) + assertx.EqualAsJSONExcept(t, json.RawMessage(origAccessTokenClaims.Raw), json.RawMessage(refreshedAccessTokenClaims.Raw), []string{"exp", "iat", "nbf"}) + } + } + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) }) - t.Run("should fail token refresh with `server_error` if hook fails", func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - })) - defer hs.Close() - - conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil) - - res, err := testRefresh(t, &refreshedToken, ts.URL, false) - require.NoError(t, err) - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) - - var errBody fosite.RFC6749ErrorJson - require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) - require.Equal(t, fosite.ErrServerError.Error(), errBody.Name) - require.Equal(t, "An error occurred while executing the refresh_token hook.", errBody.Description) + t.Run("should fail token refresh with `server_error` if refresh hook fails", func(t *testing.T) { + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL) + defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil) + } else { + conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL) + defer conf.MustSet(ctx, config.KeyTokenHookURL, nil) + } + + res, err := testRefresh(t, &refreshedToken, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + + var errBody fosite.RFC6749ErrorJson + require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) + require.Equal(t, fosite.ErrServerError.Error(), errBody.Name) + require.Equal(t, "An error occurred while executing the token hook.", errBody.Description) + } + } + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) }) - t.Run("should fail token refresh with `access_denied` if hook denied the request", func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusForbidden) - })) - defer hs.Close() - - conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil) - - res, err := testRefresh(t, &refreshedToken, ts.URL, false) - require.NoError(t, err) - assert.Equal(t, http.StatusForbidden, res.StatusCode) - - var errBody fosite.RFC6749ErrorJson - require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) - require.Equal(t, fosite.ErrAccessDenied.Error(), errBody.Name) - require.Equal(t, "The refresh_token hook target responded with an error. Make sure that the request you are making is valid. Maybe the credential or request parameters you are using are limited in scope or otherwise restricted.", errBody.Description) + t.Run("should fail token refresh with `access_denied` if legacy refresh hook denied the request", func(t *testing.T) { + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL) + defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil) + } else { + conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL) + defer conf.MustSet(ctx, config.KeyTokenHookURL, nil) + } + + res, err := testRefresh(t, &refreshedToken, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusForbidden, res.StatusCode) + + var errBody fosite.RFC6749ErrorJson + require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) + require.Equal(t, fosite.ErrAccessDenied.Error(), errBody.Name) + require.Equal(t, "The token hook target responded with an error. Make sure that the request you are making is valid. Maybe the credential or request parameters you are using are limited in scope or otherwise restricted.", errBody.Description) + } + } + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) }) - t.Run("should fail token refresh with `server_error` if hook response is malformed", func(t *testing.T) { - hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - defer hs.Close() - - conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL) - defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil) - - res, err := testRefresh(t, &refreshedToken, ts.URL, false) - require.NoError(t, err) - assert.Equal(t, http.StatusInternalServerError, res.StatusCode) - - var errBody fosite.RFC6749ErrorJson - require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) - require.Equal(t, fosite.ErrServerError.Error(), errBody.Name) - require.Equal(t, "The refresh_token hook target responded with an error.", errBody.Description) + t.Run("should fail token refresh with `server_error` if refresh hook response is malformed", func(t *testing.T) { + run := func(hookType string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer hs.Close() + + if hookType == "legacy" { + conf.MustSet(ctx, config.KeyRefreshTokenHookURL, hs.URL) + defer conf.MustSet(ctx, config.KeyRefreshTokenHookURL, nil) + } else { + conf.MustSet(ctx, config.KeyTokenHookURL, hs.URL) + defer conf.MustSet(ctx, config.KeyTokenHookURL, nil) + } + + res, err := testRefresh(t, &refreshedToken, ts.URL, false) + require.NoError(t, err) + assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + + var errBody fosite.RFC6749ErrorJson + require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody)) + require.Equal(t, fosite.ErrServerError.Error(), errBody.Name) + require.Equal(t, "The token hook target responded with an error.", errBody.Description) + } + } + t.Run("hook=legacy", run("legacy")) + t.Run("hook=new", run("new")) }) t.Run("refreshing old token should no longer work", func(t *testing.T) { diff --git a/oauth2/oauth2_client_credentials_test.go b/oauth2/oauth2_client_credentials_test.go index 9c62306e15d..1703bda9c38 100644 --- a/oauth2/oauth2_client_credentials_test.go +++ b/oauth2/oauth2_client_credentials_test.go @@ -264,12 +264,12 @@ func TestClientCredentials(t *testing.T) { var hookReq hydraoauth2.TokenHookRequest require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) - require.ElementsMatch(t, hookReq.GrantedScopes, expectedGrantedScopes) - require.ElementsMatch(t, hookReq.GrantedAudience, expectedGrantedAudience) require.NotEmpty(t, hookReq.Session) require.Equal(t, hookReq.Session.Extra, map[string]interface{}{}) - require.NotEmpty(t, hookReq.Requester) - require.Equal(t, hookReq.Requester.Payload, map[string][]string{}) + require.NotEmpty(t, hookReq.Request) + require.ElementsMatch(t, hookReq.Request.GrantedScopes, expectedGrantedScopes) + require.ElementsMatch(t, hookReq.Request.GrantedAudience, expectedGrantedAudience) + require.Equal(t, hookReq.Request.Payload, map[string][]string{}) claims := map[string]interface{}{ "hooked": true, @@ -288,9 +288,9 @@ func TestClientCredentials(t *testing.T) { defer hs.Close() reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyClientCredentialsHookURL, hs.URL) + reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL) - defer reg.Config().MustSet(ctx, config.KeyClientCredentialsHookURL, nil) + defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil) secret := uuid.New().String() cl, conf := newCustomClient(t, &hc.Client{ @@ -318,9 +318,9 @@ func TestClientCredentials(t *testing.T) { defer hs.Close() reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyClientCredentialsHookURL, hs.URL) + reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL) - defer reg.Config().MustSet(ctx, config.KeyClientCredentialsHookURL, nil) + defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil) _, conf := newClient(t) @@ -342,9 +342,9 @@ func TestClientCredentials(t *testing.T) { defer hs.Close() reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyClientCredentialsHookURL, hs.URL) + reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL) - defer reg.Config().MustSet(ctx, config.KeyClientCredentialsHookURL, nil) + defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil) _, conf := newClient(t) @@ -366,9 +366,9 @@ func TestClientCredentials(t *testing.T) { defer hs.Close() reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyClientCredentialsHookURL, hs.URL) + reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL) - defer reg.Config().MustSet(ctx, config.KeyClientCredentialsHookURL, nil) + defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil) _, conf := newClient(t) diff --git a/oauth2/oauth2_jwt_bearer_test.go b/oauth2/oauth2_jwt_bearer_test.go index 79b2e682308..1aa1f8179ff 100644 --- a/oauth2/oauth2_jwt_bearer_test.go +++ b/oauth2/oauth2_jwt_bearer_test.go @@ -326,16 +326,16 @@ func TestJWTBearer(t *testing.T) { expectedGrantedScopes := []string{client.Scope} expectedGrantedAudience := []string{audience} - expectedPayload := map[string][]string(map[string][]string{"grant_type": {grantType}, "assertion": {token}, "scope": {client.Scope}}) + expectedPayload := map[string][]string(map[string][]string{"assertion": {token}}) var hookReq hydraoauth2.TokenHookRequest require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) - require.ElementsMatch(t, hookReq.GrantedScopes, expectedGrantedScopes) - require.ElementsMatch(t, hookReq.GrantedAudience, expectedGrantedAudience) require.NotEmpty(t, hookReq.Session) require.Equal(t, hookReq.Session.Extra, map[string]interface{}{}) - require.NotEmpty(t, hookReq.Requester) - require.Equal(t, hookReq.Requester.Payload, expectedPayload) + require.NotEmpty(t, hookReq.Request) + require.ElementsMatch(t, hookReq.Request.GrantedScopes, expectedGrantedScopes) + require.ElementsMatch(t, hookReq.Request.GrantedAudience, expectedGrantedAudience) + require.Equal(t, hookReq.Request.Payload, expectedPayload) claims := map[string]interface{}{ "hooked": true, @@ -354,9 +354,9 @@ func TestJWTBearer(t *testing.T) { defer hs.Close() reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, hs.URL) + reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL) - defer reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, nil) + defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil) conf := newConf(client) conf.EndpointParams = url.Values{"grant_type": {grantType}, "assertion": {token}} @@ -401,16 +401,16 @@ func TestJWTBearer(t *testing.T) { expectedGrantedScopes := []string{client.Scope} expectedGrantedAudience := []string{audience} - expectedPayload := map[string][]string(map[string][]string{"client_id": {client.GetID()}, "grant_type": {grantType}, "assertion": {token}, "scope": {client.Scope}}) + expectedPayload := map[string][]string(map[string][]string{"assertion": {token}}) var hookReq hydraoauth2.TokenHookRequest require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) - require.ElementsMatch(t, hookReq.GrantedScopes, expectedGrantedScopes) - require.ElementsMatch(t, hookReq.GrantedAudience, expectedGrantedAudience) require.NotEmpty(t, hookReq.Session) require.Equal(t, hookReq.Session.Extra, map[string]interface{}{}) - require.NotEmpty(t, hookReq.Requester) - require.Equal(t, hookReq.Requester.Payload, expectedPayload) + require.NotEmpty(t, hookReq.Request) + require.ElementsMatch(t, hookReq.Request.GrantedScopes, expectedGrantedScopes) + require.ElementsMatch(t, hookReq.Request.GrantedAudience, expectedGrantedAudience) + require.Equal(t, hookReq.Request.Payload, expectedPayload) claims := map[string]interface{}{ "hooked": true, @@ -429,9 +429,9 @@ func TestJWTBearer(t *testing.T) { defer hs.Close() reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, hs.URL) + reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL) - defer reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, nil) + defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil) conf := newConf(client) conf.AuthStyle = goauth2.AuthStyleInParams @@ -457,9 +457,9 @@ func TestJWTBearer(t *testing.T) { defer hs.Close() reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, hs.URL) + reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL) - defer reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, nil) + defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil) token, _, err := signer.Generate(ctx, jwt.MapClaims{ "jti": uuid.NewString(), @@ -492,9 +492,9 @@ func TestJWTBearer(t *testing.T) { defer hs.Close() reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, hs.URL) + reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL) - defer reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, nil) + defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil) token, _, err := signer.Generate(ctx, jwt.MapClaims{ "jti": uuid.NewString(), @@ -527,9 +527,9 @@ func TestJWTBearer(t *testing.T) { defer hs.Close() reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) - reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, hs.URL) + reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL) - defer reg.Config().MustSet(ctx, config.KeyJWTBearerHookURL, nil) + defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil) token, _, err := signer.Generate(ctx, jwt.MapClaims{ "jti": uuid.NewString(), diff --git a/oauth2/refresh_hook.go b/oauth2/refresh_hook.go new file mode 100644 index 00000000000..62997a2f0f5 --- /dev/null +++ b/oauth2/refresh_hook.go @@ -0,0 +1,102 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package oauth2 + +import ( + "context" + "encoding/json" + + "github.com/ory/hydra/v2/x" + "github.com/ory/x/errorsx" + + "github.com/ory/fosite" + "github.com/ory/hydra/v2/driver/config" +) + +// Requester is a token endpoint's request context. +// +// swagger:ignore +type Requester struct { + // ClientID is the identifier of the OAuth 2.0 client. + ClientID string `json:"client_id"` + // GrantedScopes is the list of scopes granted to the OAuth 2.0 client. + GrantedScopes []string `json:"granted_scopes"` + // GrantedAudience is the list of audiences granted to the OAuth 2.0 client. + GrantedAudience []string `json:"granted_audience"` + // GrantTypes is the requests grant types. + GrantTypes []string `json:"grant_types"` +} + +// RefreshTokenHookRequest is the request body sent to the refresh token hook. +// +// swagger:ignore +type RefreshTokenHookRequest struct { + // Subject is the identifier of the authenticated end-user. + Subject string `json:"subject"` + // Session is the request's session.. + Session *Session `json:"session"` + // Requester is a token endpoint's request context. + Requester Requester `json:"requester"` + // ClientID is the identifier of the OAuth 2.0 client. + ClientID string `json:"client_id"` + // GrantedScopes is the list of scopes granted to the OAuth 2.0 client. + GrantedScopes []string `json:"granted_scopes"` + // GrantedAudience is the list of audiences granted to the OAuth 2.0 client. + GrantedAudience []string `json:"granted_audience"` +} + +// RefreshTokenHook is an AccessRequestHook called for `refresh_token` grant type. +func RefreshTokenHook(reg interface { + config.Provider + x.HTTPClientProvider +}) AccessRequestHook { + return func(ctx context.Context, requester fosite.AccessRequester) error { + hookURL := reg.Config().TokenRefreshHookURL(ctx) + if hookURL == nil { + return nil + } + + if !requester.GetGrantTypes().ExactOne("refresh_token") { + return nil + } + + session, ok := requester.GetSession().(*Session) + if !ok { + return nil + } + + requesterInfo := Requester{ + ClientID: requester.GetClient().GetID(), + GrantedScopes: requester.GetGrantedScopes(), + GrantedAudience: requester.GetGrantedAudience(), + GrantTypes: requester.GetGrantTypes(), + } + + reqBody := RefreshTokenHookRequest{ + Session: session, + Requester: requesterInfo, + Subject: session.GetSubject(), + ClientID: requester.GetClient().GetID(), + GrantedScopes: requester.GetGrantedScopes(), + GrantedAudience: requester.GetGrantedAudience(), + } + + reqBodyBytes, err := json.Marshal(&reqBody) + if err != nil { + return errorsx.WithStack( + fosite.ErrServerError. + WithWrap(err). + WithDescription("An error occurred while encoding the token hook."). + WithDebugf("Unable to encode the token hook body: %s", err), + ) + } + + err = executeHookAndUpdateSession(ctx, reg, hookURL, reqBodyBytes, session) + if err != nil { + return err + } + + return nil + } +} diff --git a/oauth2/token_hook.go b/oauth2/token_hook.go new file mode 100644 index 00000000000..f7ca4416a71 --- /dev/null +++ b/oauth2/token_hook.go @@ -0,0 +1,166 @@ +// Copyright © 2022 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package oauth2 + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/url" + + "github.com/hashicorp/go-retryablehttp" + + "github.com/ory/hydra/v2/x" + + "github.com/ory/fosite" + "github.com/ory/hydra/v2/consent" + "github.com/ory/hydra/v2/driver/config" + "github.com/ory/x/errorsx" +) + +// AccessRequestHook is called when an access token request is performed. +type AccessRequestHook func(ctx context.Context, requester fosite.AccessRequester) error + +// Request is a token endpoint's request context. +// +// swagger:ignore +type Request struct { + // ClientID is the identifier of the OAuth 2.0 client. + ClientID string `json:"client_id"` + // GrantedScopes is the list of scopes granted to the OAuth 2.0 client. + GrantedScopes []string `json:"granted_scopes"` + // GrantedAudience is the list of audiences granted to the OAuth 2.0 client. + GrantedAudience []string `json:"granted_audience"` + // GrantTypes is the requests grant types. + GrantTypes []string `json:"grant_types"` + // Payload is the requests payload. + Payload map[string][]string `json:"payload"` +} + +// TokenHookRequest is the request body sent to the token hook. +// +// swagger:ignore +type TokenHookRequest struct { + // Session is the request's session.. + Session *Session `json:"session"` + // Requester is a token endpoint's request context. + Request Request `json:"request"` +} + +// TokenHookResponse is the response body received from the token hook. +// +// swagger:ignore +type TokenHookResponse struct { + // Session is the session data returned by the hook. + Session consent.AcceptOAuth2ConsentRequestSession `json:"session"` +} + +func executeHookAndUpdateSession(ctx context.Context, reg x.HTTPClientProvider, hookURL *url.URL, reqBodyBytes []byte, session *Session) error { + req, err := retryablehttp.NewRequestWithContext(ctx, http.MethodPost, hookURL.String(), bytes.NewReader(reqBodyBytes)) + if err != nil { + return errorsx.WithStack( + fosite.ErrServerError. + WithWrap(err). + WithDescription("An error occurred while preparing the token hook."). + WithDebugf("Unable to prepare the HTTP Request: %s", err), + ) + } + req.Header.Set("Content-Type", "application/json; charset=UTF-8") + + resp, err := reg.HTTPClient(ctx).Do(req) + if err != nil { + return errorsx.WithStack( + fosite.ErrServerError. + WithWrap(err). + WithDescription("An error occurred while executing the token hook."). + WithDebugf("Unable to execute HTTP Request: %s", err), + ) + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusOK: + // Token permitted with new session data + case http.StatusNoContent: + // Token is permitted without overriding session data + return nil + case http.StatusForbidden: + return errorsx.WithStack( + fosite.ErrAccessDenied. + WithDescription("The token hook target responded with an error."). + WithDebugf("Token hook responded with HTTP status code: %s", resp.Status), + ) + default: + return errorsx.WithStack( + fosite.ErrServerError. + WithDescription("The token hook target responded with an error."). + WithDebugf("Token hook responded with HTTP status code: %s", resp.Status), + ) + } + + var respBody TokenHookResponse + if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { + return errorsx.WithStack( + fosite.ErrServerError. + WithWrap(err). + WithDescription("The token hook target responded with an error."). + WithDebugf("Response from token hook could not be decoded: %s", err), + ) + } + + // Overwrite existing session data (extra claims). + session.Extra = respBody.Session.AccessToken + idTokenClaims := session.IDTokenClaims() + idTokenClaims.Extra = respBody.Session.IDToken + return nil +} + +// TokenHook is an AccessRequestHook called for all grant types. +func TokenHook(reg interface { + config.Provider + x.HTTPClientProvider +}) AccessRequestHook { + return func(ctx context.Context, requester fosite.AccessRequester) error { + hookURL := reg.Config().TokenHookURL(ctx) + if hookURL == nil { + return nil + } + + session, ok := requester.GetSession().(*Session) + if !ok { + return nil + } + + request := Request{ + ClientID: requester.GetClient().GetID(), + GrantedScopes: requester.GetGrantedScopes(), + GrantedAudience: requester.GetGrantedAudience(), + GrantTypes: requester.GetGrantTypes(), + Payload: requester.Sanitize([]string{"assertion"}).GetRequestForm(), + } + + reqBody := TokenHookRequest{ + Session: session, + Request: request, + } + + reqBodyBytes, err := json.Marshal(&reqBody) + if err != nil { + return errorsx.WithStack( + fosite.ErrServerError. + WithWrap(err). + WithDescription("An error occurred while encoding the token hook."). + WithDebugf("Unable to encode the token hook body: %s", err), + ) + } + + err = executeHookAndUpdateSession(ctx, reg, hookURL, reqBodyBytes, session) + if err != nil { + return err + } + + return nil + } +} diff --git a/spec/config.json b/spec/config.json index d0db9412fa7..34c22e26d34 100644 --- a/spec/config.json +++ b/spec/config.json @@ -976,23 +976,11 @@ "format": "uri", "examples": ["https://my-example.app/token-refresh-hook"] }, - "authorization_code_hook": { + "token_hook": { "type": "string", - "description": "Sets the authorization code hook endpoint. If set it will be called while providing token to customize claims.", + "description": "Sets the token hook endpoint for all grant types. If set it will be called while providing token to customize claims.", "format": "uri", - "examples": ["https://my-example.app/token-authorization-code-hook"] - }, - "client_credentials_hook": { - "type": "string", - "description": "Sets the client credentials hook endpoint. If set it will be called while providing token to customize claims.", - "format": "uri", - "examples": ["https://my-example.app/token-client-credentials-hook"] - }, - "jwt_bearer_hook": { - "type": "string", - "description": "Sets the jwt bearer hook endpoint. If set it will be called while providing token to customize claims.", - "format": "uri", - "examples": ["https://my-example.app/token-jwt-bearer-hook"] + "examples": ["https://my-example.app/token-hook"] } } }, From d06cdf0ade293329f771071e9113aa556f55b816 Mon Sep 17 00:00:00 2001 From: sgal Date: Fri, 24 Mar 2023 16:42:51 +0100 Subject: [PATCH 9/9] fix:add hook tests for authorization code grant --- oauth2/oauth2_auth_code_test.go | 196 +++++++++++++++++++++++++++++++- 1 file changed, 194 insertions(+), 2 deletions(-) diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index 8c6312f428a..b4811fead10 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -235,7 +235,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { assert.True(t, i.Get("active").Bool(), "%s", i) assert.EqualValues(t, conf.ClientID, i.Get("client_id").String(), "%s", i) assert.EqualValues(t, expectedSubject, i.Get("sub").String(), "%s", i) - assert.EqualValues(t, `{"foo":"bar"}`, i.Get("ext").Raw, "%s", i) + assert.EqualValues(t, `bar`, i.Get("ext.foo").String(), "%s", i) return i } @@ -260,7 +260,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { assert.True(t, time.Now().After(time.Unix(i.Get("nbf").Int(), 0)), "%s", i) assert.True(t, time.Now().Before(time.Unix(i.Get("exp").Int(), 0)), "%s", i) requirex.EqualTime(t, expectedExp, time.Unix(i.Get("exp").Int(), 0), time.Second) - assert.EqualValues(t, `{"foo":"bar"}`, i.Get("ext").Raw, "%s", i) + assert.EqualValues(t, `bar`, i.Get("ext.foo").String(), "%s", i) assert.EqualValues(t, `["hydra","offline","openid"]`, i.Get("scp").Raw, "%s", i) return i } @@ -682,6 +682,197 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { assert.Empty(t, uiClaims.Get(f).Raw, "%s: %s", f, uiClaims) } }) + + t.Run("case=add ext claims from hook if configured", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8") + + var hookReq hydraoauth2.TokenHookRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) + require.NotEmpty(t, hookReq.Session) + require.Equal(t, hookReq.Session.Extra, map[string]interface{}{"foo": "bar"}) + require.NotEmpty(t, hookReq.Request) + require.ElementsMatch(t, hookReq.Request.GrantedAudience, []string{}) + require.Equal(t, hookReq.Request.Payload, map[string][]string{}) + + claims := map[string]interface{}{ + "hooked": true, + } + + hookResp := hydraoauth2.TokenHookResponse{ + Session: consent.AcceptOAuth2ConsentRequestSession{ + AccessToken: claims, + IDToken: claims, + }, + } + + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(&hookResp)) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil) + + expectAud := "https://api.ory.sh/" + c, conf := newOAuth2Client(t, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + assert.False(t, r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + return nil + }), + acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + assert.False(t, *r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + })) + + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + + token, err := conf.Exchange(context.Background(), code) + require.NoError(t, err) + + assertJWTAccessToken(t, strategy, conf, token, subject, time.Now().Add(reg.Config().GetAccessTokenLifespan(ctx))) + + // NOTE: using introspect to cover both jwt and opaque strategies + accessTokenClaims := introspectAccessToken(t, conf, token, subject) + require.True(t, accessTokenClaims.Get("ext.hooked").Bool()) + + idTokenClaims := assertIDToken(t, token, conf, subject, nonce, time.Now().Add(reg.Config().GetIDTokenLifespan(ctx))) + require.True(t, idTokenClaims.Get("hooked").Bool()) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("case=fail token exchange if hook fails", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil) + + expectAud := "https://api.ory.sh/" + c, conf := newOAuth2Client(t, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + assert.False(t, r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + return nil + }), + acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + assert.False(t, *r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + })) + + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + + _, err := conf.Exchange(context.Background(), code) + require.Error(t, err) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("case=fail token exchange if hook denies the request", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil) + + expectAud := "https://api.ory.sh/" + c, conf := newOAuth2Client(t, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + assert.False(t, r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + return nil + }), + acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + assert.False(t, *r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + })) + + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + + _, err := conf.Exchange(context.Background(), code) + require.Error(t, err) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("case=fail token exchange if hook response is malformed", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyTokenHookURL, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyTokenHookURL, nil) + + expectAud := "https://api.ory.sh/" + c, conf := newOAuth2Client(t, testhelpers.NewCallbackURL(t, "callback", testhelpers.HTTPServerNotImplementedHandler)) + testhelpers.NewLoginConsentUI(t, reg.Config(), + acceptLoginHandler(t, c, subject, func(r *hydra.OAuth2LoginRequest) *hydra.AcceptOAuth2LoginRequest { + assert.False(t, r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + return nil + }), + acceptConsentHandler(t, c, subject, func(r *hydra.OAuth2ConsentRequest) { + assert.False(t, *r.Skip) + assert.EqualValues(t, []string{expectAud}, r.RequestedAccessTokenAudience) + })) + + code, _ := getAuthorizeCode(t, conf, nil, + oauth2.SetAuthURLParam("audience", "https://api.ory.sh/"), + oauth2.SetAuthURLParam("nonce", nonce)) + require.NotEmpty(t, code) + + _, err := conf.Exchange(context.Background(), code) + require.Error(t, err) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) } // TestAuthCodeWithMockStrategy runs the authorization_code flow against various ConsentStrategy scenarios. @@ -1062,6 +1253,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) { require.Equal(t, hookReq.Request.ClientID, oauthConfig.ClientID) require.ElementsMatch(t, hookReq.Request.GrantedScopes, expectedGrantedScopes) require.ElementsMatch(t, hookReq.Request.GrantedAudience, []string{}) + require.Equal(t, hookReq.Request.Payload, map[string][]string{}) snapshotx.SnapshotTExcept(t, hookReq, exceptKeys) }