Skip to content

Commit

Permalink
Add default connection attribute '_server_host' (#1506)
Browse files Browse the repository at this point in the history
The `_server_host` connection attribute is supported in MariaDB (Connector/C)
https://mariadb.com/kb/en/mysql_optionsv/#connection-attribute-options
  • Loading branch information
oblitorum authored Nov 23, 2023
1 parent a4c260b commit 98d7289
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 68 deletions.
2 changes: 2 additions & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ INADA Naoki <songofacandy at gmail.com>
Jacek Szwec <szwec.jacek at gmail.com>
James Harr <james.harr at gmail.com>
Janek Vedock <janekvedock at comcast.net>
Jason Ng <oblitorum at gmail.com>
Jean-Yves Pellé <jy at pelle.link>
Jeff Hodges <jeff at somethingsimilar.com>
Jeffrey Charles <jeffreycharles at gmail.com>
Expand Down Expand Up @@ -131,6 +132,7 @@ Multiplay Ltd.
Percona LLC
PingCAP Inc.
Pivotal Inc.
Shattered Silicon Ltd.
Stripe Inc.
Zendesk Inc.
Dolthub Inc.
21 changes: 11 additions & 10 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ package mysql
import (
"context"
"database/sql/driver"
"fmt"
"net"
"os"
"strconv"
Expand All @@ -23,8 +22,8 @@ type connector struct {
encodedAttributes string // Encoded connection attributes.
}

func encodeConnectionAttributes(textAttributes string) string {
connAttrsBuf := make([]byte, 0, 251)
func encodeConnectionAttributes(cfg *Config) string {
connAttrsBuf := make([]byte, 0)

// default connection attributes
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName)
Expand All @@ -35,9 +34,14 @@ func encodeConnectionAttributes(textAttributes string) string {
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid()))
serverHost, _, _ := net.SplitHostPort(cfg.Addr)
if serverHost != "" {
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrServerHost)
connAttrsBuf = appendLengthEncodedString(connAttrsBuf, serverHost)
}

// user-defined connection attributes
for _, connAttr := range strings.Split(textAttributes, ",") {
for _, connAttr := range strings.Split(cfg.ConnectionAttributes, ",") {
k, v, found := strings.Cut(connAttr, ":")
if !found {
continue
Expand All @@ -49,15 +53,12 @@ func encodeConnectionAttributes(textAttributes string) string {
return string(connAttrsBuf)
}

func newConnector(cfg *Config) (*connector, error) {
encodedAttributes := encodeConnectionAttributes(cfg.ConnectionAttributes)
if len(encodedAttributes) > 250 {
return nil, fmt.Errorf("connection attributes are longer than 250 bytes: %dbytes (%q)", len(encodedAttributes), cfg.ConnectionAttributes)
}
func newConnector(cfg *Config) *connector {
encodedAttributes := encodeConnectionAttributes(cfg)
return &connector{
cfg: cfg,
encodedAttributes: encodedAttributes,
}, nil
}
}

// Connect implements driver.Connector interface.
Expand Down
7 changes: 2 additions & 5 deletions connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,13 @@ import (
)

func TestConnectorReturnsTimeout(t *testing.T) {
connector, err := newConnector(&Config{
connector := newConnector(&Config{
Net: "tcp",
Addr: "1.1.1.1:1234",
Timeout: 10 * time.Millisecond,
})
if err != nil {
t.Fatal(err)
}

_, err = connector.Connect(context.Background())
_, err := connector.Connect(context.Background())
if err == nil {
t.Fatal("error expected")
}
Expand Down
1 change: 1 addition & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const (
connAttrPlatform = "_platform"
connAttrPlatformValue = runtime.GOARCH
connAttrPid = "_pid"
connAttrServerHost = "_server_host"
)

// MySQL constants documentation:
Expand Down
9 changes: 3 additions & 6 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
if err != nil {
return nil, err
}
c, err := newConnector(cfg)
if err != nil {
return nil, err
}
c := newConnector(cfg)
return c.Connect(context.Background())
}

Expand All @@ -108,7 +105,7 @@ func NewConnector(cfg *Config) (driver.Connector, error) {
if err := cfg.normalize(); err != nil {
return nil, err
}
return newConnector(cfg)
return newConnector(cfg), nil
}

// OpenConnector implements driver.DriverContext.
Expand All @@ -117,5 +114,5 @@ func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) {
if err != nil {
return nil, err
}
return newConnector(cfg)
return newConnector(cfg), nil
}
71 changes: 37 additions & 34 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"os"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -3377,12 +3378,30 @@ func TestConnectionAttributes(t *testing.T) {
t.Skipf("MySQL server not running on %s", netAddr)
}

attr1 := "attr1"
value1 := "value1"
attr2 := "fo/o"
value2 := "bo/o"
dsn += "&connectionAttributes=" + url.QueryEscape(fmt.Sprintf("%s:%s,%s:%s", attr1, value1, attr2, value2))
defaultAttrs := []string{
connAttrClientName,
connAttrOS,
connAttrPlatform,
connAttrPid,
connAttrServerHost,
}
host, _, _ := net.SplitHostPort(addr)
defaultAttrValues := []string{
connAttrClientNameValue,
connAttrOSValue,
connAttrPlatformValue,
strconv.Itoa(os.Getpid()),
host,
}

customAttrs := []string{"attr1", "fo/o"}
customAttrValues := []string{"value1", "bo/o"}

customAttrStrs := make([]string, len(customAttrs))
for i := range customAttrs {
customAttrStrs[i] = fmt.Sprintf("%s:%s", customAttrs[i], customAttrValues[i])
}
dsn += "&connectionAttributes=" + url.QueryEscape(strings.Join(customAttrStrs, ","))

var db *sql.DB
if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
Expand All @@ -3395,40 +3414,24 @@ func TestConnectionAttributes(t *testing.T) {

dbt := &DBTest{t, db}

var attrValue string
queryString := "SELECT ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID() and ATTR_NAME = ?"
rows := dbt.mustQuery(queryString, connAttrClientName)
if rows.Next() {
rows.Scan(&attrValue)
if attrValue != connAttrClientNameValue {
dbt.Errorf("expected %q, got %q", connAttrClientNameValue, attrValue)
}
} else {
dbt.Errorf("no data")
}
rows.Close()
queryString := "SELECT ATTR_NAME, ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID()"
rows := dbt.mustQuery(queryString)
defer rows.Close()

rows = dbt.mustQuery(queryString, attr1)
if rows.Next() {
rows.Scan(&attrValue)
if attrValue != value1 {
dbt.Errorf("expected %q, got %q", value1, attrValue)
}
} else {
dbt.Errorf("no data")
rowsMap := make(map[string]string)
for rows.Next() {
var attrName, attrValue string
rows.Scan(&attrName, &attrValue)
rowsMap[attrName] = attrValue
}
rows.Close()

rows = dbt.mustQuery(queryString, attr2)
if rows.Next() {
rows.Scan(&attrValue)
if attrValue != value2 {
dbt.Errorf("expected %q, got %q", value2, attrValue)
connAttrs := append(append([]string{}, defaultAttrs...), customAttrs...)
expectedAttrValues := append(append([]string{}, defaultAttrValues...), customAttrValues...)
for i := range connAttrs {
if gotValue := rowsMap[connAttrs[i]]; gotValue != expectedAttrValues[i] {
dbt.Errorf("expected %q, got %q", expectedAttrValues[i], gotValue)
}
} else {
dbt.Errorf("no data")
}
rows.Close()
}

func TestErrorInMultiResult(t *testing.T) {
Expand Down
16 changes: 7 additions & 9 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,15 +292,14 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
pktLen += n + 1
}

// 1 byte to store length of all key-values
// NOTE: Actually, this is length encoded integer.
// But we support only len(connAttrBuf) < 251 for now because takeSmallBuffer
// doesn't support buffer size more than 4096 bytes.
// TODO(methane): Rewrite buffer management.
pktLen += 1 + len(mc.connector.encodedAttributes)
// encode length of the connection attributes
var connAttrsLEIBuf [9]byte
connAttrsLen := len(mc.connector.encodedAttributes)
connAttrsLEI := appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen))
pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes)

// Calculate packet length and get buffer with that size
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.cfg.Logger.Print(err)
Expand Down Expand Up @@ -380,8 +379,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
pos++

// Connection Attributes
data[pos] = byte(len(mc.connector.encodedAttributes))
pos++
pos += copy(data[pos:], connAttrsLEI)
pos += copy(data[pos:], []byte(mc.connector.encodedAttributes))

// Send Auth packet
Expand Down
5 changes: 1 addition & 4 deletions packets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,7 @@ var _ net.Conn = new(mockConn)

func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
conn := new(mockConn)
connector, err := newConnector(NewConfig())
if err != nil {
panic(err)
}
connector := newConnector(NewConfig())
mc := &mysqlConn{
buf: newBuffer(conn),
cfg: connector.cfg,
Expand Down

0 comments on commit 98d7289

Please sign in to comment.