diff --git a/ssh/client_auth.go b/ssh/client_auth.go index b93961010d..b86dde151d 100644 --- a/ssh/client_auth.go +++ b/ssh/client_auth.go @@ -555,6 +555,7 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe } gotMsgExtInfo := false + gotUserAuthInfoRequest := false for { packet, err := c.readPacket() if err != nil { @@ -585,6 +586,9 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe if msg.PartialSuccess { return authPartialSuccess, msg.Methods, nil } + if !gotUserAuthInfoRequest { + return authFailure, msg.Methods, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) + } return authFailure, msg.Methods, nil case msgUserAuthSuccess: return authSuccess, nil, nil @@ -596,6 +600,7 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe if err := Unmarshal(packet, &msg); err != nil { return authFailure, nil, err } + gotUserAuthInfoRequest = true // Manually unpack the prompt/echo pairs. rest := msg.Prompts diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go index e981cc49a6..f11eeb590b 100644 --- a/ssh/client_auth_test.go +++ b/ssh/client_auth_test.go @@ -1293,3 +1293,92 @@ func TestCertAuthOpenSSHCompat(t *testing.T) { t.Fatalf("unable to dial remote side: %s", err) } } + +func TestKeyboardInteractiveAuthEarlyFail(t *testing.T) { + const maxAuthTries = 2 + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + // Start testserver + serverConfig := &ServerConfig{ + MaxAuthTries: maxAuthTries, + KeyboardInteractiveCallback: func(c ConnMetadata, + client KeyboardInteractiveChallenge) (*Permissions, error) { + // Fail keyboard-interactive authentication early before + // any prompt is sent to client. + return nil, errors.New("keyboard-interactive auth failed") + }, + PasswordCallback: func(c ConnMetadata, + pass []byte) (*Permissions, error) { + if string(pass) == clientPassword { + return nil, nil + } + return nil, errors.New("password auth failed") + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + serverDone := make(chan struct{}) + go func() { + defer func() { serverDone <- struct{}{} }() + conn, chans, reqs, err := NewServerConn(c2, serverConfig) + if err != nil { + return + } + _ = conn.Close() + + discarderDone := make(chan struct{}) + go func() { + defer func() { discarderDone <- struct{}{} }() + DiscardRequests(reqs) + }() + for newChannel := range chans { + newChannel.Reject(Prohibited, + "testserver not accepting requests") + } + + <-discarderDone + }() + + // Connect to testserver, expect KeyboardInteractive() to be not called, + // PasswordCallback() to be called and connection to succeed. + passwordCallbackCalled := false + clientConfig := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + RetryableAuthMethod(KeyboardInteractive(func(name, + instruction string, questions []string, + echos []bool) ([]string, error) { + t.Errorf("unexpected call to KeyboardInteractive()") + return []string{clientPassword}, nil + }), maxAuthTries), + RetryableAuthMethod(PasswordCallback(func() (secret string, + err error) { + t.Logf("PasswordCallback()") + passwordCallbackCalled = true + return clientPassword, nil + }), maxAuthTries), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + conn, _, _, err := NewClientConn(c1, "", clientConfig) + if err != nil { + t.Errorf("unexpected NewClientConn() error: %v", err) + } + if conn != nil { + conn.Close() + } + + // Wait for server to finish. + <-serverDone + + if !passwordCallbackCalled { + t.Errorf("expected PasswordCallback() to be called") + } +}