Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for authentication plugins. #552

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Aaron Hopkins <go-sql-driver at die.net>
Arne Hormann <arnehormann at gmail.com>
Carlos Nieto <jose.carlos at menteslibres.net>
Chris Moos <chris at tech9computers.com>
Craig Wilson <craiggwilson@gmail.com>
Daniel Nichter <nil at codenode.com>
Daniël van Eeden <git at myname.nl>
DisposaBoy <disposaboy at dby.me>
Expand Down
139 changes: 139 additions & 0 deletions auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package mysql

import "bytes"

const (
mysqlClearPassword = "mysql_clear_password"
mysqlNativePassword = "mysql_native_password"
mysqlOldPassword = "mysql_old_password"
defaultAuthPluginName = mysqlNativePassword
)

var authPluginFactories map[string]func(*Config) AuthPlugin

func init() {
authPluginFactories = make(map[string]func(*Config) AuthPlugin)
authPluginFactories[mysqlClearPassword] = func(cfg *Config) AuthPlugin {
return &clearTextPlugin{cfg}
}
authPluginFactories[mysqlNativePassword] = func(cfg *Config) AuthPlugin {
return &nativePasswordPlugin{cfg}
}
authPluginFactories[mysqlOldPassword] = func(cfg *Config) AuthPlugin {
return &oldPasswordPlugin{cfg}
}
}

// RegisterAuthPlugin registers an authentication plugin to be used during
// negotiation with the server. If a plugin with the given name already exists,
// it will be overwritten.
func RegisterAuthPlugin(name string, factory func(*Config) AuthPlugin) {
authPluginFactories[name] = factory
}

// AuthPlugin handles authenticating a user.
type AuthPlugin interface {
// Next takes a server's challenge and returns
// the bytes to send back or an error.
Next(challenge []byte) ([]byte, error)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this called Next and not e.g. Auth (which I believe is what this does)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is called next because it might be called more than once. All the current plugins, clear, native, and old all are single step. Next is called once and they are complete. However, not all plugins will be this way, so calling it Auth and having it called multiple times is weird (to me). Precedent was pulled from Java's driver, which calls it Next as well. However, I'm certainly not tied to that name.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. Maybe also add a short comment to the code that it might be called more than once.


// Close cleans up the resources of the plugin.
Close()
}

type clearTextPlugin struct {
cfg *Config
}

func (p *clearTextPlugin) Next(challenge []byte) ([]byte, error) {
if !p.cfg.AllowCleartextPasswords {
return nil, ErrCleartextPassword
}

// NUL-terminated
return append([]byte(p.cfg.Passwd), 0), nil
}

func (p *clearTextPlugin) Close() {}

type nativePasswordPlugin struct {
cfg *Config
}

func (p *nativePasswordPlugin) Next(challenge []byte) ([]byte, error) {
// NOTE: this seems to always be disabled...
// if !p.cfg.AllowNativePasswords {
// return nil, ErrNativePassword
// }

return scramblePassword(challenge, []byte(p.cfg.Passwd)), nil
}

func (p *nativePasswordPlugin) Close() {}

type oldPasswordPlugin struct {
cfg *Config
}

func (p *oldPasswordPlugin) Next(challenge []byte) ([]byte, error) {
if !p.cfg.AllowOldPasswords {
return nil, ErrOldPassword
}

// NUL-terminated
return append(scrambleOldPassword(challenge, []byte(p.cfg.Passwd)), 0), nil
}

func (p *oldPasswordPlugin) Close() {}

func handleAuthResult(mc *mysqlConn, oldCipher []byte) error {
data, err := mc.readPacket()
if err != nil {
return err
}

var authData []byte

// packet indicator
switch data[0] {
case iOK:
return mc.handleOkPacket(data)

case iEOF: // auth switch
mc.authPlugin.Close()
if len(data) > 1 {
pluginEndIndex := bytes.IndexByte(data, 0x00)
pluginName := string(data[1:pluginEndIndex])
if apf, ok := authPluginFactories[pluginName]; ok {
mc.authPlugin = apf(mc.cfg)
} else {
return ErrUnknownPlugin
}

if len(data) > pluginEndIndex+1 {
authData = data[pluginEndIndex+1 : len(data)-1]
}
} else {
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
mc.authPlugin = authPluginFactories[mysqlOldPassword](mc.cfg)
authData = oldCipher
}
case iAuthContinue:
// continue packet for a plugin.
authData = data[1:] // strip off the continue flag
default: // Error otherwise
return mc.handleErrorPacket(data)
}

authData, err = mc.authPlugin.Next(authData)
if err != nil {
return err
}

err = mc.writeAuthDataPacket(authData)
if err != nil {
return err
}

return handleAuthResult(mc, authData)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like recursion. Please loop outside of this function.

}
72 changes: 72 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package mysql

import "testing"
import "bytes"

func TestAuthPlugin_Cleartext(t *testing.T) {
cfg := &Config{
Passwd: "funny",
}

plugin := authPluginFactories[mysqlClearPassword](cfg)

_, err := plugin.Next(nil)
if err == nil {
t.Fatalf("expected error when AllowCleartextPasswords is false")
}

cfg.AllowCleartextPasswords = true

actual, err := plugin.Next(nil)
if err != nil {
t.Fatalf("expected no error but got: %s", err)
}

expected := append([]byte("funny"), 0)
if bytes.Compare(actual, expected) != 0 {
t.Fatalf("expected data to be %v, but got: %v", expected, actual)
}
}

func TestAuthPlugin_NativePassword(t *testing.T) {
cfg := &Config{
Passwd: "pass ",
}

plugin := authPluginFactories[mysqlNativePassword](cfg)

actual, err := plugin.Next([]byte{9, 8, 7, 6, 5, 4, 3, 2})
if err != nil {
t.Fatalf("expected no error but got: %s", err)
}

expected := []byte{195, 146, 3, 213, 111, 95, 252, 192, 97, 226, 173, 176, 91, 175, 131, 138, 89, 45, 75, 179}
if bytes.Compare(actual, expected) != 0 {
t.Fatalf("expected data to be %v, but got: %v", expected, actual)
}
}

func TestAuthPlugin_OldPassword(t *testing.T) {
cfg := &Config{
Passwd: "pass ",
}

plugin := authPluginFactories[mysqlOldPassword](cfg)

_, err := plugin.Next(nil)
if err == nil {
t.Fatalf("expected error when AllowOldPasswords is false")
}

cfg.AllowOldPasswords = true

actual, err := plugin.Next([]byte{9, 8, 7, 6, 5, 4, 3, 2})
if err != nil {
t.Fatalf("expected no error but got: %s", err)
}

expected := []byte{71, 87, 92, 90, 67, 91, 66, 81, 0}
if bytes.Compare(actual, expected) != 0 {
t.Fatalf("expected data to be %v, but got: %v", expected, actual)
}
}
5 changes: 5 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type mysqlConn struct {
sequence uint8
parseTime bool
strict bool
authPlugin AuthPlugin
}

// Handles parameters set in DSN after the connection is established
Expand Down Expand Up @@ -92,6 +93,10 @@ func (mc *mysqlConn) Close() (err error) {
// closed the network connection.
func (mc *mysqlConn) cleanup() {
// Makes cleanup idempotent
if mc.authPlugin != nil {
mc.authPlugin.Close()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The auth plugin is not used again after the connection is established and can therefore be closed in the init phase already, I believe. We also don't have to add it as a field to mc then.

mc.authPlugin = nil
}
if mc.netConn != nil {
if err := mc.netConn.Close(); err != nil {
errLog.Print(err)
Expand Down
9 changes: 5 additions & 4 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ const (
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html

const (
iOK byte = 0x00
iLocalInFile byte = 0xfb
iEOF byte = 0xfe
iERR byte = 0xff
iOK byte = 0x00
iAuthContinue byte = 0x01
iLocalInFile byte = 0xfb
iEOF byte = 0xfe
iERR byte = 0xff
)

// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
Expand Down
80 changes: 33 additions & 47 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,50 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
mc.writeTimeout = mc.cfg.WriteTimeout

// Reading Handshake Initialization Packet
cipher, err := mc.readInitPacket()
authPluginName, authData, err := mc.readInitPacket()
if err != nil {
mc.cleanup()
return nil, err
}

// save the old auth data in case the server
// needs to use the old password scheme.
oldCipher := make([]byte, len(authData))
copy(oldCipher, authData)

// Handle pluggable authentication
if authPluginName == "" {
// assume that without a name, we are using
// the default.
authPluginName = defaultAuthPluginName
}

if apf, ok := authPluginFactories[authPluginName]; ok {
mc.authPlugin = apf(mc.cfg)
authData, err = mc.authPlugin.Next(authData)
if err != nil {
mc.cleanup()
return nil, err
}
} else {
// we'll tell the server in response that we are switching to our
// default plugin because we didn't recognize the one they sent us.
authPluginName = defaultAuthPluginName
mc.authPlugin = authPluginFactories[authPluginName](mc.cfg)

// zero-out the authData because the current authData was for
// a plugin we don't know about.
authData = make([]byte, 0)
}

// Send Client Authentication Packet
if err = mc.writeAuthPacket(cipher); err != nil {
if err = mc.writeAuthPacket(authPluginName, authData); err != nil {
mc.cleanup()
return nil, err
}

// Handle response to auth packet, switch methods if possible
if err = handleAuthResult(mc, cipher); err != nil {
if err = handleAuthResult(mc, oldCipher); err != nil {
// Authentication failed and MySQL has already closed the connection
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
// Do not send COM_QUIT, just cleanup and return the error.
Expand Down Expand Up @@ -134,50 +164,6 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
return mc, nil
}

func handleAuthResult(mc *mysqlConn, oldCipher []byte) error {
// Read Result Packet
cipher, err := mc.readResultOK()
if err == nil {
return nil // auth successful
}

if mc.cfg == nil {
return err // auth failed and retry not possible
}

// Retry auth if configured to do so.
if mc.cfg.AllowOldPasswords && err == ErrOldPassword {
// Retry with old authentication method. Note: there are edge cases
// where this should work but doesn't; this is currently "wontfix":
// https://github.com/go-sql-driver/mysql/issues/184

// If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is
// sent and we have to keep using the cipher sent in the init packet.
if cipher == nil {
cipher = oldCipher
}

if err = mc.writeOldAuthPacket(cipher); err != nil {
return err
}
_, err = mc.readResultOK()
} else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword {
// Retry with clear text password for
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
if err = mc.writeClearAuthPacket(); err != nil {
return err
}
_, err = mc.readResultOK()
} else if mc.cfg.AllowNativePasswords && err == ErrNativePassword {
if err = mc.writeNativeAuthPacket(cipher); err != nil {
return err
}
_, err = mc.readResultOK()
}
return err
}

func init() {
sql.Register("mysql", &MySQLDriver{})
}
Loading