Skip to content

Commit

Permalink
server: support mysql CLIENT_CONNECT_ATTRS capability (pingcap#1684)
Browse files Browse the repository at this point in the history
  • Loading branch information
tiancaiamao authored Sep 6, 2016
1 parent 9257a41 commit 5755249
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 18 deletions.
113 changes: 95 additions & 18 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ import (
var defaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag |
mysql.ClientConnectWithDB | mysql.ClientProtocol41 |
mysql.ClientTransactions | mysql.ClientSecureConnection | mysql.ClientFoundRows |
mysql.ClientMultiStatements | mysql.ClientMultiResults
mysql.ClientMultiStatements | mysql.ClientMultiResults |
mysql.ClientConnectAtts

type clientConn struct {
pkg *packetIO
Expand All @@ -73,6 +74,7 @@ type clientConn struct {
alloc arena.Allocator
lastCmd string
ctx IContext
attrs map[string]string
}

func (cc *clientConn) String() string {
Expand Down Expand Up @@ -163,38 +165,113 @@ func (cc *clientConn) writePacket(data []byte) error {
return cc.pkg.writePacket(data)
}

func (cc *clientConn) readHandshakeResponse() error {
data, err := cc.readPacket()
if err != nil {
return errors.Trace(err)
}
type handshakeResponse41 struct {
Capability uint32
Collation uint8
User string
DBName string
Auth []byte
Attrs map[string]string
}

func handshakeResponseFromData(packet *handshakeResponse41, data []byte) error {
pos := 0
// capability
capability := binary.LittleEndian.Uint32(data[:4])
cc.capability = defaultCapability & capability
packet.Capability = capability
pos += 4
// skip max packet size
pos += 4
// charset, skip, if you want to use another charset, use set names
cc.collation = data[pos]
packet.Collation = data[pos]
pos++
// skip reserved 23[00]
pos += 23
// user name
cc.user = string(data[pos : pos+bytes.IndexByte(data[pos:], 0)])
pos += len(cc.user) + 1
// auth length and auth
authLen := int(data[pos])
pos++
auth := data[pos : pos+authLen]
pos += authLen
if cc.capability&mysql.ClientConnectWithDB > 0 {
packet.User = string(data[pos : pos+bytes.IndexByte(data[pos:], 0)])
pos += len(packet.User) + 1

if capability&mysql.ClientPluginAuthLenencClientData > 0 {
// TODO: Support mysql.ClientPluginAuthLenencClientData, skip it now
if num, null, off := parseLengthEncodedInt(data[pos:]); !null {
pos = pos + off + int(num)
}
} else if capability&mysql.ClientSecureConnection > 0 {
// auth length and auth
authLen := int(data[pos])
pos++
packet.Auth = data[pos : pos+authLen]
pos += authLen
} else {
packet.Auth = data[pos : pos+bytes.IndexByte(data[pos:], 0)]
pos += len(packet.Auth) + 1
}

if capability&mysql.ClientConnectWithDB > 0 {
if len(data[pos:]) > 0 {
idx := bytes.IndexByte(data[pos:], 0)
cc.dbname = string(data[pos : pos+idx])
packet.DBName = string(data[pos : pos+idx])
pos = pos + idx + 1
}
}

if capability&mysql.ClientPluginAuth > 0 {
// TODO: Support mysql.ClientPluginAuth, skip it now
idx := bytes.IndexByte(data[pos:], 0)
pos = pos + idx + 1
}

if capability&mysql.ClientConnectAtts > 0 {
if num, null, off := parseLengthEncodedInt(data[pos:]); !null {
pos += off
kv := data[pos : pos+int(num)]
attrs, err := parseAttrs(kv)
if err != nil {
return errors.Trace(err)
}
packet.Attrs = attrs
pos += int(num)
}
}
return nil
}

func parseAttrs(data []byte) (map[string]string, error) {
attrs := make(map[string]string)
pos := 0
for pos < len(data) {
key, _, off, err := parseLengthEncodedBytes(data[pos:])
if err != nil {
return attrs, errors.Trace(err)
}
pos += off
value, _, off, err := parseLengthEncodedBytes(data[pos:])
if err != nil {
return attrs, errors.Trace(err)
}
pos += off

attrs[string(key)] = string(value)
}
return attrs, nil
}

func (cc *clientConn) readHandshakeResponse() error {
data, err := cc.readPacket()
if err != nil {
return errors.Trace(err)
}

var p handshakeResponse41
if err = handshakeResponseFromData(&p, data); err != nil {
return errors.Trace(err)
}
cc.capability = p.Capability & defaultCapability
cc.user = p.User
cc.dbname = p.DBName
cc.collation = p.Collation
cc.attrs = p.Attrs

// Open session and do auth
cc.ctx, err = cc.server.driver.OpenCtx(uint64(cc.connectionID), cc.capability, uint8(cc.collation), cc.dbname)
if err != nil {
Expand All @@ -209,7 +286,7 @@ func (cc *clientConn) readHandshakeResponse() error {
return errors.Trace(mysql.NewErr(mysql.ErrAccessDenied, cc.user, addr, "Yes"))
}
user := fmt.Sprintf("%s@%s", cc.user, host)
if !cc.ctx.Auth(user, auth, cc.salt) {
if !cc.ctx.Auth(user, p.Auth, cc.salt) {
return errors.Trace(mysql.NewErr(mysql.ErrAccessDenied, cc.user, host, "Yes"))
}
}
Expand Down
86 changes: 86 additions & 0 deletions server/conn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright 2016 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package server

import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/mysql"
)

type ConnTestSuite struct{}

var _ = Suite(ConnTestSuite{})

func (ts ConnTestSuite) TestHandshakeResponseFromData(c *C) {
// test data from http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse41
var p handshakeResponse41
data := []byte{
0x85, 0xa2, 0x1e, 0x00, 0x00, 0x00, 0x00, 0x40, 0x08, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x72, 0x6f, 0x6f, 0x74, 0x00, 0x14, 0x22, 0x50, 0x79, 0xa2, 0x12, 0xd4,
0xe8, 0x82, 0xe5, 0xb3, 0xf4, 0x1a, 0x97, 0x75, 0x6b, 0xc8, 0xbe, 0xdb, 0x9f, 0x80, 0x6d, 0x79,
0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77,
0x6f, 0x72, 0x64, 0x00, 0x61, 0x03, 0x5f, 0x6f, 0x73, 0x09, 0x64, 0x65, 0x62, 0x69, 0x61, 0x6e,
0x36, 0x2e, 0x30, 0x0c, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x6e, 0x61, 0x6d, 0x65,
0x08, 0x6c, 0x69, 0x62, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x04, 0x5f, 0x70, 0x69, 0x64, 0x05, 0x32,
0x32, 0x33, 0x34, 0x34, 0x0f, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x76, 0x65, 0x72,
0x73, 0x69, 0x6f, 0x6e, 0x08, 0x35, 0x2e, 0x36, 0x2e, 0x36, 0x2d, 0x6d, 0x39, 0x09, 0x5f, 0x70,
0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x06, 0x78, 0x38, 0x36, 0x5f, 0x36, 0x34, 0x03, 0x66,
0x6f, 0x6f, 0x03, 0x62, 0x61, 0x72,
}
err := handshakeResponseFromData(&p, data)
c.Assert(err, IsNil)
c.Assert(p.Capability&mysql.ClientConnectAtts, Equals, mysql.ClientConnectAtts)
eq := mapIdentical(p.Attrs, map[string]string{
"_client_version": "5.6.6-m9",
"_platform": "x86_64",
"foo": "bar",
"_os": "debian6.0",
"_client_name": "libmysql",
"_pid": "22344"})
c.Assert(eq, IsTrue)

data = []byte{
0x8d, 0xa6, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x70, 0x61, 0x6d, 0x00, 0x14, 0xab, 0x09, 0xee, 0xf6, 0xbc, 0xb1, 0x32,
0x3e, 0x61, 0x14, 0x38, 0x65, 0xc0, 0x99, 0x1d, 0x95, 0x7d, 0x75, 0xd4, 0x47, 0x74, 0x65, 0x73,
0x74, 0x00, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00,
}
p = handshakeResponse41{}
err = handshakeResponseFromData(&p, data)
c.Assert(err, IsNil)
capability := mysql.ClientProtocol41 |
mysql.ClientPluginAuth |
mysql.ClientSecureConnection |
mysql.ClientConnectWithDB
c.Assert(p.Capability&capability, Equals, capability)
c.Assert(p.User, Equals, "pam")
c.Assert(p.DBName, Equals, "test")
}

func mapIdentical(m1, m2 map[string]string) bool {
return mapBelong(m1, m2) && mapBelong(m2, m1)
}

func mapBelong(m1, m2 map[string]string) bool {
for k1, v1 := range m1 {
v2, ok := m2[k1]
if !ok && v1 != v2 {
return false
}
}
return true
}

0 comments on commit 5755249

Please sign in to comment.