diff --git a/auth.go b/auth.go new file mode 100644 index 000000000..de2e6eba7 --- /dev/null +++ b/auth.go @@ -0,0 +1,127 @@ +package mysql + +import "bytes" + +const mysqlClearPassword = "mysql_clear_password" +const mysqlNativePassword = "mysql_native_password" +const mysqlOldPassword = "mysql_old_password" +const defaultAuthPluginName = mysqlNativePassword + +var authPluginFactories map[string]func(*Config) AuthPlugin + +func init() { + authPluginFactories = make(map[string]func(*Config) AuthPlugin) + authPluginFactories[mysqlClearPassword] = func(cfg *Config) AuthPlugin { + return &clearTextPlugin{cfg} + } + authPluginFactories[mysqlNativePassword] = func(cfg *Config) AuthPlugin { + return &nativePasswordPlugin{cfg} + } + authPluginFactories[mysqlOldPassword] = func(cfg *Config) AuthPlugin { + return &oldPasswordPlugin{cfg} + } +} + +// RegisterAuthPlugin registers an authentication plugin to be used during +// negotiation with the server. If a plugin with the given name already exists, +// it will be overwritten. +func RegisterAuthPlugin(name string, factory func(*Config) AuthPlugin) { + authPluginFactories[name] = factory +} + +// AuthPlugin handles authenticating a user. +type AuthPlugin interface { + // Next takes a server's challenge and returns + // the bytes to send back or an error. + Next(challenge []byte) ([]byte, error) +} + +type clearTextPlugin struct { + cfg *Config +} + +func (p *clearTextPlugin) Next(challenge []byte) ([]byte, error) { + if !p.cfg.AllowCleartextPasswords { + return nil, ErrCleartextPassword + } + + // \0-terminated + return append([]byte(p.cfg.Passwd), 0), nil +} + +type nativePasswordPlugin struct { + cfg *Config +} + +func (p *nativePasswordPlugin) Next(challenge []byte) ([]byte, error) { + // NOTE: this seems to always be disabled... + // if !p.cfg.AllowNativePasswords { + // return nil, ErrNativePassword + // } + + return scramblePassword(challenge, []byte(p.cfg.Passwd)), nil +} + +type oldPasswordPlugin struct { + cfg *Config +} + +func (p *oldPasswordPlugin) Next(challenge []byte) ([]byte, error) { + if !p.cfg.AllowOldPasswords { + return nil, ErrOldPassword + } + + // \0-terminated + return append(scrambleOldPassword(challenge, []byte(p.cfg.Passwd)), 0), nil +} + +func handleAuthResult(mc *mysqlConn, plugin AuthPlugin, oldCipher []byte) error { + data, err := mc.readPacket() + if err != nil { + return err + } + + var authData []byte + + // packet indicator + switch data[0] { + case iOK: + return mc.handleOkPacket(data) + + case iEOF: // auth switch + if len(data) > 1 { + pluginEndIndex := bytes.IndexByte(data, 0x00) + pluginName := string(data[1:pluginEndIndex]) + if apf, ok := authPluginFactories[pluginName]; ok { + plugin = apf(mc.cfg) + } else { + return ErrUnknownPlugin + } + + if len(data) > pluginEndIndex+1 { + authData = data[pluginEndIndex+1 : len(data)-1] + } + } else { + // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest + plugin = authPluginFactories[mysqlOldPassword](mc.cfg) + authData = oldCipher + } + case iAuthContinue: + // continue packet for a plugin. + authData = data[1:] // strip off the continue flag + default: // Error otherwise + return mc.handleErrorPacket(data) + } + + authData, err = plugin.Next(authData) + if err != nil { + return err + } + + err = mc.writeAuthDataPacket(authData) + if err != nil { + return err + } + + return handleAuthResult(mc, plugin, authData) +} diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 000000000..560abf9b2 --- /dev/null +++ b/auth_test.go @@ -0,0 +1,72 @@ +package mysql + +import "testing" +import "bytes" + +func TestAuthPlugin_Cleartext(t *testing.T) { + cfg := &Config{ + Passwd: "funny", + } + + plugin := authPluginFactories[mysqlClearPassword](cfg) + + _, err := plugin.Next(nil) + if err == nil { + t.Fatalf("expected error when AllowCleartextPasswords is false") + } + + cfg.AllowCleartextPasswords = true + + actual, err := plugin.Next(nil) + if err != nil { + t.Fatalf("expected no error but got: %s", err) + } + + expected := append([]byte("funny"), 0) + if bytes.Compare(actual, expected) != 0 { + t.Fatalf("expected data to be %v, but got: %v", expected, actual) + } +} + +func TestAuthPlugin_NativePassword(t *testing.T) { + cfg := &Config{ + Passwd: "pass ", + } + + plugin := authPluginFactories[mysqlNativePassword](cfg) + + actual, err := plugin.Next([]byte{9, 8, 7, 6, 5, 4, 3, 2}) + if err != nil { + t.Fatalf("expected no error but got: %s", err) + } + + expected := []byte{195, 146, 3, 213, 111, 95, 252, 192, 97, 226, 173, 176, 91, 175, 131, 138, 89, 45, 75, 179} + if bytes.Compare(actual, expected) != 0 { + t.Fatalf("expected data to be %v, but got: %v", expected, actual) + } +} + +func TestAuthPlugin_OldPassword(t *testing.T) { + cfg := &Config{ + Passwd: "pass ", + } + + plugin := authPluginFactories[mysqlOldPassword](cfg) + + _, err := plugin.Next(nil) + if err == nil { + t.Fatalf("expected error when AllowOldPasswords is false") + } + + cfg.AllowOldPasswords = true + + actual, err := plugin.Next([]byte{9, 8, 7, 6, 5, 4, 3, 2}) + if err != nil { + t.Fatalf("expected no error but got: %s", err) + } + + expected := []byte{71, 87, 92, 90, 67, 91, 66, 81, 0} + if bytes.Compare(actual, expected) != 0 { + t.Fatalf("expected data to be %v, but got: %v", expected, actual) + } +} diff --git a/const.go b/const.go index 88cfff3fd..7a9353d8a 100644 --- a/const.go +++ b/const.go @@ -18,10 +18,11 @@ const ( // http://dev.mysql.com/doc/internals/en/client-server-protocol.html const ( - iOK byte = 0x00 - iLocalInFile byte = 0xfb - iEOF byte = 0xfe - iERR byte = 0xff + iOK byte = 0x00 + iAuthContinue byte = 0x01 + iLocalInFile byte = 0xfb + iEOF byte = 0xfe + iERR byte = 0xff ) // https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags diff --git a/driver.go b/driver.go index 0022d1f1e..9b5f01d6b 100644 --- a/driver.go +++ b/driver.go @@ -88,20 +88,50 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.writeTimeout = mc.cfg.WriteTimeout // Reading Handshake Initialization Packet - cipher, err := mc.readInitPacket() + authPluginName, authData, err := mc.readInitPacket() if err != nil { mc.cleanup() return nil, err } + // save the old auth data in case the server + // needs to use the old password scheme. + oldCipher := make([]byte, len(authData)) + copy(oldCipher, authData) + + // Handle pluggable authentication + if authPluginName == "" { + // assume that without a name, we are using + // the default. + authPluginName = defaultAuthPluginName + } + + var authPlugin AuthPlugin + if apf, ok := authPluginFactories[authPluginName]; ok { + authPlugin = apf(mc.cfg) + authData, err = authPlugin.Next(authData) + if err != nil { + return nil, err + } + } else { + // we'll tell the server in response that we are switching to our + // default plugin because we didn't recognize the one they sent us. + authPluginName = defaultAuthPluginName + authPlugin = authPluginFactories[authPluginName](mc.cfg) + + // zero-out the authData because the current authData was for + // a plugin we don't know about. + authData = make([]byte, 0) + } + // Send Client Authentication Packet - if err = mc.writeAuthPacket(cipher); err != nil { + if err = mc.writeAuthPacket(authPluginName, authData); err != nil { mc.cleanup() return nil, err } // Handle response to auth packet, switch methods if possible - if err = handleAuthResult(mc, cipher); err != nil { + if err = handleAuthResult(mc, authPlugin, oldCipher); err != nil { // Authentication failed and MySQL has already closed the connection // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). // Do not send COM_QUIT, just cleanup and return the error. @@ -134,50 +164,6 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { return mc, nil } -func handleAuthResult(mc *mysqlConn, oldCipher []byte) error { - // Read Result Packet - cipher, err := mc.readResultOK() - if err == nil { - return nil // auth successful - } - - if mc.cfg == nil { - return err // auth failed and retry not possible - } - - // Retry auth if configured to do so. - if mc.cfg.AllowOldPasswords && err == ErrOldPassword { - // Retry with old authentication method. Note: there are edge cases - // where this should work but doesn't; this is currently "wontfix": - // https://github.com/go-sql-driver/mysql/issues/184 - - // If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is - // sent and we have to keep using the cipher sent in the init packet. - if cipher == nil { - cipher = oldCipher - } - - if err = mc.writeOldAuthPacket(cipher); err != nil { - return err - } - _, err = mc.readResultOK() - } else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword { - // Retry with clear text password for - // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html - // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html - if err = mc.writeClearAuthPacket(); err != nil { - return err - } - _, err = mc.readResultOK() - } else if mc.cfg.AllowNativePasswords && err == ErrNativePassword { - if err = mc.writeNativeAuthPacket(cipher); err != nil { - return err - } - _, err = mc.readResultOK() - } - return err -} - func init() { sql.Register("mysql", &MySQLDriver{}) } diff --git a/packets.go b/packets.go index aafe9793e..2c73573e9 100644 --- a/packets.go +++ b/packets.go @@ -139,19 +139,19 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Handshake Initialization Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake -func (mc *mysqlConn) readInitPacket() ([]byte, error) { +func (mc *mysqlConn) readInitPacket() (string, []byte, error) { data, err := mc.readPacket() if err != nil { - return nil, err + return "", nil, err } if data[0] == iERR { - return nil, mc.handleErrorPacket(data) + return "", nil, mc.handleErrorPacket(data) } // protocol version [1 byte] if data[0] < minProtocolVersion { - return nil, fmt.Errorf( + return "", nil, fmt.Errorf( "unsupported protocol version %d. Version %d or higher is required", data[0], minProtocolVersion, @@ -162,8 +162,8 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // connection id [4 bytes] pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 - // first part of the password cipher [8 bytes] - cipher := data[pos : pos+8] + // first part of the auth data [8 bytes] + authData := data[pos : pos+8] // (filler) always 0x00 [1 byte] pos += 8 + 1 @@ -171,10 +171,10 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // capability flags (lower 2 bytes) [2 bytes] mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) if mc.flags&clientProtocol41 == 0 { - return nil, ErrOldProtocol + return "", nil, ErrOldProtocol } if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { - return nil, ErrNoTLS + return "", nil, ErrNoTLS } pos += 2 @@ -198,7 +198,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // // The official Python library uses the fixed length 12 // which seems to work but technically could have a hidden bug. - cipher = append(cipher, data[pos:pos+12]...) + authData = append(authData, data[pos:pos+12]...) // TODO: Verify string termination // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) @@ -209,21 +209,29 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { //} //return ErrMalformPkt - // make a memory safe copy of the cipher slice + pos += 13 + + var authPluginName string + if len(data) > pos { + // auth-plugin name (string[NUL]) + authPluginName = string(data[pos : len(data)-1]) + } + + // make a memory safe copy of the authData slice var b [20]byte - copy(b[:], cipher) - return b[:], nil + copy(b[:], authData) + return authPluginName, b[:], nil } - // make a memory safe copy of the cipher slice + // make a memory safe copy of the authData slice var b [8]byte - copy(b[:], cipher) - return b[:], nil + copy(b[:], authData) + return "", b[:], nil } // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { +func (mc *mysqlConn) writeAuthPacket(authPluginName string, authData []byte) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | @@ -247,10 +255,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientFlags |= clientMultiStatements } - // User Password - scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) - - pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1 + pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(authData) + 21 + 1 // To specify a db name if n := len(mc.cfg.DBName); n > 0 { @@ -318,9 +323,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { data[pos] = 0x00 pos++ - // ScrambleBuffer [length encoded integer] - data[pos] = byte(len(scrambleBuff)) - pos += 1 + copy(data[pos+1:], scrambleBuff) + // authData [length encoded integer] + data[pos] = byte(len(authData)) + pos += 1 + copy(data[pos+1:], authData) // Databasename [null terminated string] if len(mc.cfg.DBName) > 0 { @@ -329,72 +334,22 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { pos++ } - // Assume native client during response - pos += copy(data[pos:], "mysql_native_password") + pos += copy(data[pos:], authPluginName) data[pos] = 0x00 // Send Auth packet return mc.writePacket(data) } -// Client old authentication packet -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { - // User password - scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd)) - - // Calculate the packet length and add a tailing 0 - pktLen := len(scrambleBuff) + 1 - data := mc.buf.takeSmallBuffer(4 + pktLen) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn - } - - // Add the scrambled password [null terminated string] - copy(data[4:], scrambleBuff) - data[4+pktLen-1] = 0x00 - - return mc.writePacket(data) -} - -// Client clear text authentication packet -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeClearAuthPacket() error { - // Calculate the packet length and add a tailing 0 - pktLen := len(mc.cfg.Passwd) + 1 - data := mc.buf.takeSmallBuffer(4 + pktLen) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn - } - - // Add the clear password [null terminated string] - copy(data[4:], mc.cfg.Passwd) - data[4+pktLen-1] = 0x00 - - return mc.writePacket(data) -} - -// Native password authentication method -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { - scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) - - // Calculate the packet length and add a tailing 0 - pktLen := len(scrambleBuff) - data := mc.buf.takeSmallBuffer(4 + pktLen) +func (mc *mysqlConn) writeAuthDataPacket(authData []byte) error { + data := mc.buf.takeSmallBuffer(4 + len(authData)) if data == nil { - // can not take the buffer. Something must be wrong with the connection + // cannot take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) return driver.ErrBadConn } - // Add the scramble - copy(data[4:], scrambleBuff) - + copy(data[4:], authData) return mc.writePacket(data) } @@ -480,29 +435,6 @@ func (mc *mysqlConn) readResultOK() ([]byte, error) { case iOK: return nil, mc.handleOkPacket(data) - case iEOF: - if len(data) > 1 { - pluginEndIndex := bytes.IndexByte(data, 0x00) - plugin := string(data[1:pluginEndIndex]) - cipher := data[pluginEndIndex+1 : len(data)-1] - - if plugin == "mysql_old_password" { - // using old_passwords - return cipher, ErrOldPassword - } else if plugin == "mysql_clear_password" { - // using clear text password - return cipher, ErrCleartextPassword - } else if plugin == "mysql_native_password" { - // using mysql default authentication method - return cipher, ErrNativePassword - } else { - return cipher, ErrUnknownPlugin - } - } else { - // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest - return nil, ErrOldPassword - } - default: // Error otherwise return nil, mc.handleErrorPacket(data) }