Skip to content

Commit

Permalink
Fix races in login_test.go units tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cfryanr committed Feb 27, 2024
1 parent d888833 commit e0d92a5
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
4 changes: 4 additions & 0 deletions pkg/oidcclient/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,9 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
}
}

// promptForWebLogin prints a login URL to the screen, if needed. It will also print the "paste yor authorization code"
// prompt to the screen and wait for user input, if needed. It can be cancelled by the context provided.
// It returns a function which should be invoked by the caller to perform some cleanup.
func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, printAuthorizeURL bool) func() {
if printAuthorizeURL {
_, _ = fmt.Fprintf(h.out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL)
Expand Down Expand Up @@ -714,6 +717,7 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin
}

// promptForValue interactively prompts the user for a plaintext value and reads their input.
// If the context is canceled, it will return an error immediately.
// This can be replaced by a mock implementation for unit tests.
func promptForValue(ctx context.Context, promptLabel string, out io.Writer) (string, error) {
if !term.IsTerminal(stdin()) {
Expand Down
44 changes: 36 additions & 8 deletions pkg/oidcclient/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1159,6 +1159,25 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
h.stdinIsTTY = func() bool { return true }
h.stderrIsTTY = func() bool { return true }

// Because response_mode=form_post, the Login function is going to prompt the user
// to paste their authcode. This test needs to handle that prompt.
h.promptForValue = func(ctx context.Context, promptLabel string, _ io.Writer) (string, error) {
assert.Equal(t, " Optionally, paste your authorization code: ", promptLabel)
// This test does not want to simulate the user entering their authcode at the prompt,
// nor does it want to simulate a prompt error, so this function should hang as if
// we are waiting for user input. Otherwise, promptForWebLogin would be racing to
// write the result of this function to the callback chan (versus this test trying
// to write its own callbackResult to the same chan).
// The context passed into this function should be cancelled by the caller when it
// has received the authcode callback because the caller is no longer interested in
// waiting for the prompt anymore at that point, so this function can finish when
// the context is cancelled.
select {
case <-ctx.Done():
return "", errors.New("this error should be ignored by the caller because the context is already cancelled")
}
}

cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{
Issuer: formPostSuccessServer.URL,
Expand Down Expand Up @@ -2467,6 +2486,8 @@ func withOutWriter(t *testing.T, out io.Writer) Option {
func TestHandlePasteCallback(t *testing.T) {
const testRedirectURI = "http://127.0.0.1:12324/callback"
const testAuthURL = "https://test-authorize-url/"
const cancelledAuthcodePromptOutput = "[...]\n"
const newlineAfterEveryAuthcodePromptOutput = "\n"

expectedAuthURLOutput := func(expectedAuthURL string) string {
return fmt.Sprintf("Log in by visiting this link:\n\n %s\n\n", expectedAuthURL)
Expand Down Expand Up @@ -2522,7 +2543,7 @@ func TestHandlePasteCallback(t *testing.T) {
},
authorizeURL: testAuthURL,
printAuthorizeURL: true,
wantStderr: expectedAuthURLOutput(testAuthURL),
wantStderr: expectedAuthURLOutput(testAuthURL) + cancelledAuthcodePromptOutput + newlineAfterEveryAuthcodePromptOutput,
wantCallback: &callbackResult{
err: fmt.Errorf("failed to prompt for manual authorization code: some prompt error"),
},
Expand All @@ -2549,7 +2570,7 @@ func TestHandlePasteCallback(t *testing.T) {
},
authorizeURL: testAuthURL,
printAuthorizeURL: true,
wantStderr: expectedAuthURLOutput(testAuthURL),
wantStderr: expectedAuthURLOutput(testAuthURL) + newlineAfterEveryAuthcodePromptOutput,
wantCallback: &callbackResult{
err: fmt.Errorf("some exchange error"),
},
Expand All @@ -2576,7 +2597,7 @@ func TestHandlePasteCallback(t *testing.T) {
},
authorizeURL: testAuthURL,
printAuthorizeURL: true,
wantStderr: expectedAuthURLOutput(testAuthURL),
wantStderr: expectedAuthURLOutput(testAuthURL) + newlineAfterEveryAuthcodePromptOutput,
wantCallback: &callbackResult{
token: &oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}},
},
Expand All @@ -2602,8 +2623,8 @@ func TestHandlePasteCallback(t *testing.T) {
}
},
authorizeURL: testAuthURL,
printAuthorizeURL: false, // do not want to print auth URL
wantStderr: "", // auth URL was not printed to stdout
printAuthorizeURL: false, // do not want to print auth URL
wantStderr: newlineAfterEveryAuthcodePromptOutput, // auth URL was not printed to stdout
wantCallback: &callbackResult{
token: &oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}},
},
Expand All @@ -2628,9 +2649,7 @@ func TestHandlePasteCallback(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()

h.promptForWebLogin(ctx, tt.authorizeURL, tt.printAuthorizeURL)

require.Equal(t, tt.wantStderr, buf.String())
cleanupPrompt := h.promptForWebLogin(ctx, tt.authorizeURL, tt.printAuthorizeURL)

if tt.wantCallback != nil {
select {
Expand All @@ -2640,6 +2659,15 @@ func TestHandlePasteCallback(t *testing.T) {
require.Equal(t, *tt.wantCallback, result)
}
}

// Reading buf before the goroutine inside of promptForWebLogin finishes is a data race,
// because that goroutine will also try to write to buf.
// Avoid this by shutting down its goroutine by cancelling its context,
// and clean it up with its cleanup function (which waits for it to be done).
// Then it should always be safe to read buf.
cancel()
cleanupPrompt()
require.Equal(t, tt.wantStderr, buf.String())
})
}
}
Expand Down

0 comments on commit e0d92a5

Please sign in to comment.