From 57552492638af05ada8fbccfe410ad3e685cc780 Mon Sep 17 00:00:00 2001 From: tiancaiamao Date: Tue, 6 Sep 2016 18:28:15 +0800 Subject: [PATCH] server: support mysql CLIENT_CONNECT_ATTRS capability (#1684) --- server/conn.go | 113 +++++++++++++++++++++++++++++++++++++------- server/conn_test.go | 86 +++++++++++++++++++++++++++++++++ 2 files changed, 181 insertions(+), 18 deletions(-) create mode 100644 server/conn_test.go diff --git a/server/conn.go b/server/conn.go index 4aa5e91b53fa9..ef421f82af58c 100644 --- a/server/conn.go +++ b/server/conn.go @@ -57,7 +57,8 @@ import ( var defaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag | mysql.ClientConnectWithDB | mysql.ClientProtocol41 | mysql.ClientTransactions | mysql.ClientSecureConnection | mysql.ClientFoundRows | - mysql.ClientMultiStatements | mysql.ClientMultiResults + mysql.ClientMultiStatements | mysql.ClientMultiResults | + mysql.ClientConnectAtts type clientConn struct { pkg *packetIO @@ -73,6 +74,7 @@ type clientConn struct { alloc arena.Allocator lastCmd string ctx IContext + attrs map[string]string } func (cc *clientConn) String() string { @@ -163,38 +165,113 @@ func (cc *clientConn) writePacket(data []byte) error { return cc.pkg.writePacket(data) } -func (cc *clientConn) readHandshakeResponse() error { - data, err := cc.readPacket() - if err != nil { - return errors.Trace(err) - } +type handshakeResponse41 struct { + Capability uint32 + Collation uint8 + User string + DBName string + Auth []byte + Attrs map[string]string +} +func handshakeResponseFromData(packet *handshakeResponse41, data []byte) error { pos := 0 // capability capability := binary.LittleEndian.Uint32(data[:4]) - cc.capability = defaultCapability & capability + packet.Capability = capability pos += 4 // skip max packet size pos += 4 // charset, skip, if you want to use another charset, use set names - cc.collation = data[pos] + packet.Collation = data[pos] pos++ // skip reserved 23[00] pos += 23 // user name - cc.user = string(data[pos : pos+bytes.IndexByte(data[pos:], 0)]) - pos += len(cc.user) + 1 - // auth length and auth - authLen := int(data[pos]) - pos++ - auth := data[pos : pos+authLen] - pos += authLen - if cc.capability&mysql.ClientConnectWithDB > 0 { + packet.User = string(data[pos : pos+bytes.IndexByte(data[pos:], 0)]) + pos += len(packet.User) + 1 + + if capability&mysql.ClientPluginAuthLenencClientData > 0 { + // TODO: Support mysql.ClientPluginAuthLenencClientData, skip it now + if num, null, off := parseLengthEncodedInt(data[pos:]); !null { + pos = pos + off + int(num) + } + } else if capability&mysql.ClientSecureConnection > 0 { + // auth length and auth + authLen := int(data[pos]) + pos++ + packet.Auth = data[pos : pos+authLen] + pos += authLen + } else { + packet.Auth = data[pos : pos+bytes.IndexByte(data[pos:], 0)] + pos += len(packet.Auth) + 1 + } + + if capability&mysql.ClientConnectWithDB > 0 { if len(data[pos:]) > 0 { idx := bytes.IndexByte(data[pos:], 0) - cc.dbname = string(data[pos : pos+idx]) + packet.DBName = string(data[pos : pos+idx]) + pos = pos + idx + 1 + } + } + + if capability&mysql.ClientPluginAuth > 0 { + // TODO: Support mysql.ClientPluginAuth, skip it now + idx := bytes.IndexByte(data[pos:], 0) + pos = pos + idx + 1 + } + + if capability&mysql.ClientConnectAtts > 0 { + if num, null, off := parseLengthEncodedInt(data[pos:]); !null { + pos += off + kv := data[pos : pos+int(num)] + attrs, err := parseAttrs(kv) + if err != nil { + return errors.Trace(err) + } + packet.Attrs = attrs + pos += int(num) } } + return nil +} + +func parseAttrs(data []byte) (map[string]string, error) { + attrs := make(map[string]string) + pos := 0 + for pos < len(data) { + key, _, off, err := parseLengthEncodedBytes(data[pos:]) + if err != nil { + return attrs, errors.Trace(err) + } + pos += off + value, _, off, err := parseLengthEncodedBytes(data[pos:]) + if err != nil { + return attrs, errors.Trace(err) + } + pos += off + + attrs[string(key)] = string(value) + } + return attrs, nil +} + +func (cc *clientConn) readHandshakeResponse() error { + data, err := cc.readPacket() + if err != nil { + return errors.Trace(err) + } + + var p handshakeResponse41 + if err = handshakeResponseFromData(&p, data); err != nil { + return errors.Trace(err) + } + cc.capability = p.Capability & defaultCapability + cc.user = p.User + cc.dbname = p.DBName + cc.collation = p.Collation + cc.attrs = p.Attrs + // Open session and do auth cc.ctx, err = cc.server.driver.OpenCtx(uint64(cc.connectionID), cc.capability, uint8(cc.collation), cc.dbname) if err != nil { @@ -209,7 +286,7 @@ func (cc *clientConn) readHandshakeResponse() error { return errors.Trace(mysql.NewErr(mysql.ErrAccessDenied, cc.user, addr, "Yes")) } user := fmt.Sprintf("%s@%s", cc.user, host) - if !cc.ctx.Auth(user, auth, cc.salt) { + if !cc.ctx.Auth(user, p.Auth, cc.salt) { return errors.Trace(mysql.NewErr(mysql.ErrAccessDenied, cc.user, host, "Yes")) } } diff --git a/server/conn_test.go b/server/conn_test.go new file mode 100644 index 0000000000000..c3201634cd5eb --- /dev/null +++ b/server/conn_test.go @@ -0,0 +1,86 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/tidb/mysql" +) + +type ConnTestSuite struct{} + +var _ = Suite(ConnTestSuite{}) + +func (ts ConnTestSuite) TestHandshakeResponseFromData(c *C) { + // test data from http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse41 + var p handshakeResponse41 + data := []byte{ + 0x85, 0xa2, 0x1e, 0x00, 0x00, 0x00, 0x00, 0x40, 0x08, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x72, 0x6f, 0x6f, 0x74, 0x00, 0x14, 0x22, 0x50, 0x79, 0xa2, 0x12, 0xd4, + 0xe8, 0x82, 0xe5, 0xb3, 0xf4, 0x1a, 0x97, 0x75, 0x6b, 0xc8, 0xbe, 0xdb, 0x9f, 0x80, 0x6d, 0x79, + 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, + 0x6f, 0x72, 0x64, 0x00, 0x61, 0x03, 0x5f, 0x6f, 0x73, 0x09, 0x64, 0x65, 0x62, 0x69, 0x61, 0x6e, + 0x36, 0x2e, 0x30, 0x0c, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x6e, 0x61, 0x6d, 0x65, + 0x08, 0x6c, 0x69, 0x62, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x04, 0x5f, 0x70, 0x69, 0x64, 0x05, 0x32, + 0x32, 0x33, 0x34, 0x34, 0x0f, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x76, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x08, 0x35, 0x2e, 0x36, 0x2e, 0x36, 0x2d, 0x6d, 0x39, 0x09, 0x5f, 0x70, + 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x06, 0x78, 0x38, 0x36, 0x5f, 0x36, 0x34, 0x03, 0x66, + 0x6f, 0x6f, 0x03, 0x62, 0x61, 0x72, + } + err := handshakeResponseFromData(&p, data) + c.Assert(err, IsNil) + c.Assert(p.Capability&mysql.ClientConnectAtts, Equals, mysql.ClientConnectAtts) + eq := mapIdentical(p.Attrs, map[string]string{ + "_client_version": "5.6.6-m9", + "_platform": "x86_64", + "foo": "bar", + "_os": "debian6.0", + "_client_name": "libmysql", + "_pid": "22344"}) + c.Assert(eq, IsTrue) + + data = []byte{ + 0x8d, 0xa6, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x70, 0x61, 0x6d, 0x00, 0x14, 0xab, 0x09, 0xee, 0xf6, 0xbc, 0xb1, 0x32, + 0x3e, 0x61, 0x14, 0x38, 0x65, 0xc0, 0x99, 0x1d, 0x95, 0x7d, 0x75, 0xd4, 0x47, 0x74, 0x65, 0x73, + 0x74, 0x00, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70, + 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x00, + } + p = handshakeResponse41{} + err = handshakeResponseFromData(&p, data) + c.Assert(err, IsNil) + capability := mysql.ClientProtocol41 | + mysql.ClientPluginAuth | + mysql.ClientSecureConnection | + mysql.ClientConnectWithDB + c.Assert(p.Capability&capability, Equals, capability) + c.Assert(p.User, Equals, "pam") + c.Assert(p.DBName, Equals, "test") +} + +func mapIdentical(m1, m2 map[string]string) bool { + return mapBelong(m1, m2) && mapBelong(m2, m1) +} + +func mapBelong(m1, m2 map[string]string) bool { + for k1, v1 := range m1 { + v2, ok := m2[k1] + if !ok && v1 != v2 { + return false + } + } + return true +}