Skip to content

Commit

Permalink
Merge pull request #1694 from vmware-tanzu/cli_login_page_errors
Browse files Browse the repository at this point in the history
Same error messages shown in CLI's callback web page and in terminal
  • Loading branch information
benjaminapetersen authored Sep 26, 2023
2 parents e25ecea + cede640 commit d44882f
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 31 deletions.
48 changes: 28 additions & 20 deletions pkg/oidcclient/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,16 +499,9 @@ func (h *handlerState) cliBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (
// validations on the returned ID token.
tokenCtx, tokenCtxCancelFunc := context.WithTimeout(h.ctx, httpRequestTimeout)
defer tokenCtxCancelFunc()
token, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).
ExchangeAuthcodeAndValidateTokens(
tokenCtx,
authCode,
h.pkce,
h.nonce,
h.oauth2Config.RedirectURL,
)
token, err := h.redeemAuthCode(tokenCtx, authCode)
if err != nil {
return nil, fmt.Errorf("error during authorization code exchange: %w", err)
return nil, fmt.Errorf("could not complete authorization code exchange: %w", err)
}

return token, nil
Expand Down Expand Up @@ -642,7 +635,7 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin
return
}

// When a code is pasted, redeem it for a token and return that result on the callbacks channel.
// When a code is pasted, redeem it for a token and return the results on the callback channel.
token, err := h.redeemAuthCode(ctx, code)
h.callbacks <- callbackResult{token: token, err: err}
}()
Expand Down Expand Up @@ -849,11 +842,23 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype
return upstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfo(ctx, refreshed, "", true, false)
}

func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) {
// If we return an error, also report it back over the channel to the main CLI thread.
// handleAuthCodeCallback is used as an http handler, so it does not run in the CLI's main goroutine.
// Upon a callback redirect request from an identity provider, it uses a callback channel to communicate
// its results back to the main thread of the CLI. The result can contain either some tokens from the
// identity provider's token endpoint, or the result can contain an error. When the result is an error,
// the CLI's main goroutine is responsible for printing that error to the terminal. At the same time,
// this function serves a web response, and that web response is rendered in the user's browser. So the
// user has two places to look for error messages: in their browser and in the CLI's terminal. Ideally,
// these messages would be the same. Note that using httperr.Wrap will cause the details of the wrapped
// err to be printed by the CLI, but not printed in the browser due to the way that the httperr package
// works, so avoid using httperr.Wrap in this function.
func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (returnedErr error) {
defer func() {
if err != nil {
h.callbacks <- callbackResult{err: err}
// If we returned an error, then also report it back over the channel to the main CLI goroutine.
// Because returnedErr is the named return value, inside this defer returnedErr will hold the value
// returned by any explicit return statement.
if returnedErr != nil {
h.callbacks <- callbackResult{err: returnedErr}
}
}()

Expand All @@ -867,9 +872,10 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
}

// For POST and OPTIONS requests, calculate the allowed origin for CORS.
issuerURL, parseErr := url.Parse(h.issuer)
if parseErr != nil {
return httperr.Wrap(http.StatusInternalServerError, "invalid issuer url", parseErr)
issuerURL, err := url.Parse(h.issuer)
if err != nil {
// Avoid using httperr.Wrap because that would hide the details of err from the browser output.
return httperr.Newf(http.StatusInternalServerError, "invalid issuer url: %s", err.Error())
}
allowOrigin := issuerURL.Scheme + "://" + issuerURL.Host

Expand Down Expand Up @@ -902,8 +908,9 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
} // Otherwise, this is a POST request...

// Parse and pull the response parameters from an application/x-www-form-urlencoded request body.
if err := r.ParseForm(); err != nil {
return httperr.Wrap(http.StatusBadRequest, "invalid form", err)
if err = r.ParseForm(); err != nil {
// Avoid using httperr.Wrap because that would hide the details of err from the browser output.
return httperr.Newf(http.StatusBadRequest, "invalid form: %s", err.Error())
}
params = r.Form

Expand Down Expand Up @@ -943,7 +950,8 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
// validations on the returned ID token.
token, err := h.redeemAuthCode(r.Context(), params.Get("code"))
if err != nil {
return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err)
// Avoid using httperr.Wrap because that would hide the details of err from the browser output.
return httperr.Newf(http.StatusBadRequest, "could not complete authorization code exchange: %s", err.Error())
}

h.callbacks <- callbackResult{token: token}
Expand Down
34 changes: 23 additions & 11 deletions pkg/oidcclient/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1174,7 +1174,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: "error during authorization code exchange: some authcode exchange or token validation error",
wantErr: "could not complete authorization code exchange: some authcode exchange or token validation error",
},
{
name: "successful ldap login with prompts for username and password",
Expand Down Expand Up @@ -2236,7 +2236,7 @@ func TestHandleAuthCodeCallback(t *testing.T) {
{
name: "invalid code",
query: "state=test-state&code=invalid",
wantErr: "could not complete code exchange: some exchange error",
wantErr: "could not complete authorization code exchange: some exchange error",
wantHeaders: map[string][]string{},
wantHTTPStatus: http.StatusBadRequest,
opt: func(t *testing.T) Option {
Expand Down Expand Up @@ -2362,14 +2362,25 @@ func TestHandleAuthCodeCallback(t *testing.T) {
err = h.handleAuthCodeCallback(resp, req)
if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr)
if tt.wantHTTPStatus != 0 {
rec := httptest.NewRecorder()
err.(httperr.Responder).Respond(rec)
require.Equal(t, tt.wantHTTPStatus, rec.Code)
}
rec := httptest.NewRecorder()
err.(httperr.Responder).Respond(rec)
require.Equal(t, tt.wantHTTPStatus, rec.Code)
// The error message returned (to be shown by the CLI) and the error message shown in the resulting
// web page should always be the same.
require.Equal(t, http.StatusText(tt.wantHTTPStatus)+": "+tt.wantErr+"\n", rec.Body.String())
} else {
require.NoError(t, err)
require.Equal(t, tt.wantHTTPStatus, resp.Code)
switch {
case tt.wantNoCallbacks:
// When we return an error but keep listening, then we don't need a response body.
require.Empty(t, resp.Body)
case tt.wantHTTPStatus == http.StatusOK:
// When the login succeeds, the response body should show the success message.
require.Equal(t, "you have been logged in and may now close this tab", resp.Body.String())
default:
t.Fatal("test author made a mistake by expecting a non-200 response code without a wantErr")
}
}

if tt.wantHeaders != nil {
Expand All @@ -2385,11 +2396,12 @@ func TestHandleAuthCodeCallback(t *testing.T) {
case result := <-h.callbacks:
if tt.wantErr != "" {
require.EqualError(t, result.err, tt.wantErr)
return
require.Nil(t, result.token)
} else {
require.NoError(t, result.err)
require.NotNil(t, result.token)
require.Equal(t, result.token.IDToken.Token, "test-id-token")
}
require.NoError(t, result.err)
require.NotNil(t, result.token)
require.Equal(t, result.token.IDToken.Token, "test-id-token")
gotCallback = true
}
require.Equal(t, tt.wantNoCallbacks, !gotCallback)
Expand Down

0 comments on commit d44882f

Please sign in to comment.