Skip to content

Commit

Permalink
*: Set collation to uint16
Browse files Browse the repository at this point in the history
  • Loading branch information
dveeden committed Apr 29, 2024
1 parent c60f97d commit 16f13bf
Show file tree
Hide file tree
Showing 24 changed files with 60 additions and 60 deletions.
10 changes: 5 additions & 5 deletions pkg/parser/charset/charset.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ type Charset struct {
// Collation is a collation.
// Now we only support MySQL.
type Collation struct {
ID int
ID uint16
CharsetName string
Name string
IsDefault bool
}

var collationsIDMap = make(map[int]*Collation)
var collationsIDMap = make(map[uint16]*Collation)
var collationsNameMap = make(map[string]*Collation)
var supportedCollations = make([]*Collation, 0, len(supportedCollationNames))

Expand Down Expand Up @@ -166,7 +166,7 @@ func GetCharsetInfo(cs string) (*Charset, error) {
}

// GetCharsetInfoByID returns charset and collation for id as cs_number.
func GetCharsetInfoByID(coID int) (charsetStr string, collateStr string, err error) {
func GetCharsetInfoByID(coID uint16) (charsetStr string, collateStr string, err error) {
if coID == mysql.DefaultCollationID {
return mysql.DefaultCharset, mysql.DefaultCollationName, nil
}
Expand All @@ -176,7 +176,7 @@ func GetCharsetInfoByID(coID int) (charsetStr string, collateStr string, err err

log.Warn(
"unable to get collation name from collation ID, return default charset and collation instead",
zap.Int("ID", coID),
zap.Uint16("ID", coID),
zap.Stack("stack"))
return mysql.DefaultCharset, mysql.DefaultCollationName, errors.Errorf("Unknown collation id %d", coID)
}
Expand Down Expand Up @@ -205,7 +205,7 @@ func GetCollationByName(name string) (*Collation, error) {
}

// GetCollationByID returns collations by given id.
func GetCollationByID(id int) (*Collation, error) {
func GetCollationByID(id uint16) (*Collation, error) {
collation, ok := collationsIDMap[id]
if !ok {
return nil, errors.Errorf("Unknown collation id %d", id)
Expand Down
2 changes: 1 addition & 1 deletion pkg/parser/charset/charset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func TestGetCollationByName(t *testing.T) {
func TestValidCustomCharset(t *testing.T) {
AddCharset(&Charset{"custom", "custom_collation", make(map[string]*Collation), "Custom", 4})
defer RemoveCharset("custom")
AddCollation(&Collation{99999, "custom", "custom_collation", true})
AddCollation(&Collation{9999, "custom", "custom_collation", true})

tests := []struct {
cs string
Expand Down
16 changes: 8 additions & 8 deletions pkg/parser/mysql/charset.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ package mysql
import "unicode"

// CharsetNameToID maps charset name to its default collation ID.
func CharsetNameToID(charset string) uint8 {
func CharsetNameToID(charset string) uint16 {
// Use quick path for TiDB to avoid access CharsetIDs map
// "SHOW CHARACTER SET;" to see all the supported character sets.
if charset == "utf8mb4" {
Expand All @@ -34,7 +34,7 @@ func CharsetNameToID(charset string) uint8 {
}

// CharsetIDs maps charset name to its default collation ID.
var CharsetIDs = map[string]uint8{
var CharsetIDs = map[string]uint16{
"big5": 1,
"dec8": 3,
"cp850": 4,
Expand Down Expand Up @@ -533,12 +533,12 @@ const (
UTF8MB4Charset = "utf8mb4"
DefaultCharset = UTF8MB4Charset
// DefaultCollationID is utf8mb4_bin(46)
DefaultCollationID = 46
Latin1DefaultCollationID = 47
ASCIIDefaultCollationID = 65
UTF8DefaultCollationID = 83
UTF8MB4DefaultCollationID = 46
BinaryDefaultCollationID = 63
DefaultCollationID = uint16(46)
Latin1DefaultCollationID = uint16(47)
ASCIIDefaultCollationID = uint16(65)
UTF8DefaultCollationID = uint16(83)
UTF8MB4DefaultCollationID = uint16(46)
BinaryDefaultCollationID = uint16(63)
UTF8MB4DefaultCollation = "utf8mb4_bin"
DefaultCollationName = UTF8MB4DefaultCollation

Expand Down
2 changes: 1 addition & 1 deletion pkg/planner/property/physical_property.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func GetCollateIDByNameForPartition(coll string) int32 {
// GetCollateNameByIDForPartition returns collate id by collation name
func GetCollateNameByIDForPartition(collateID int32) string {
collateID = collate.RestoreCollationIDIfNeeded(collateID)
return collate.CollationID2Name(collateID)
return collate.CollationID2Name(uint16(collateID))
}

// cteProducerStatus indicates whether we can let the current CTE consumer/reader be executed on the MPP nodes.
Expand Down
8 changes: 4 additions & 4 deletions pkg/server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ type clientConn struct {
peerPort string // peer port
status int32 // dispatching/reading/shutdown/waitshutdown
lastCode uint16 // last error code
collation uint8 // collation used by client, may be different from the collation used by database.
collation uint16 // collation used by client, may be different from the collation used by database.
lastActive time.Time // last active time
authPlugin string // default authentication plugin
isUnixSocket bool // connection is Unix Socket file
Expand Down Expand Up @@ -220,7 +220,7 @@ func (cc *clientConn) String() string {
// MySQL converts a collation from u32 to char in the protocol, so the value could be wrong. It works fine for the
// default parameters (and libmysql seems not to provide any way to specify the collation other than the default
// one), so it's not a big problem.
collationStr := mysql.Collations[uint16(cc.collation)]
collationStr := mysql.Collations[cc.collation]
return fmt.Sprintf("id:%d, addr:%s status:%b, collation:%s, user:%s",
cc.connectionID, cc.bufReadConn.RemoteAddr(), cc.ctx.Status(), collationStr, cc.user,
)
Expand Down Expand Up @@ -428,9 +428,9 @@ func (cc *clientConn) writeInitialHandshake(ctx context.Context) error {
data = append(data, byte(cc.server.capability), byte(cc.server.capability>>8))
// charset
if cc.collation == 0 {
cc.collation = uint8(mysql.DefaultCollationID)
cc.collation = mysql.DefaultCollationID
}
data = append(data, cc.collation)
data = append(data, byte(cc.collation))
// status
data = dump.Uint16(data, mysql.ServerStatusAutocommit)
// below 13 byte may not be used
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/conn_stmt_params_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ func expectedDatetimeExecuteResult(t *testing.T, c *mockConn, time types.Time, w
Name: "t",
Table: "",
Type: mysql.TypeDatetime,
Charset: uint16(mysql.CharsetNameToID(charset.CharsetBin)),
Charset: mysql.CharsetNameToID(charset.CharsetBin),
Flag: uint16(mysql.NotNullFlag | mysql.BinaryFlag),
Decimal: 6,
ColumnLength: 26,
Expand Down
4 changes: 2 additions & 2 deletions pkg/server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ func TestInitialHandshake(t *testing.T) {
expected.Write([]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x00}) // Salt
err = binary.Write(expected, binary.LittleEndian, uint16(defaultCapability&0xFFFF)) // Server Capability
require.NoError(t, err)
expected.WriteByte(uint8(mysql.DefaultCollationID)) // Server Language
expected.WriteByte(mysql.DefaultCollationID) // Server Language
err = binary.Write(expected, binary.LittleEndian, mysql.ServerStatusAutocommit) // Server Status
require.NoError(t, err)
err = binary.Write(expected, binary.LittleEndian, uint16((defaultCapability>>16)&0xFFFF)) // Extended Server Capability
Expand Down Expand Up @@ -1594,7 +1594,7 @@ func TestAuthSessionTokenPlugin(t *testing.T) {
tk.MustExec("CREATE USER auth_session_token")
tk.MustExec("CREATE USER another_user")

tc, err := drv.OpenCtx(uint64(0), 0, uint8(mysql.DefaultCollationID), "", nil, nil)
tc, err := drv.OpenCtx(uint64(0), 0, mysql.DefaultCollationID, "", nil, nil)
require.NoError(t, err)
cc := &clientConn{
connectionID: 1,
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
// IDriver opens IContext.
type IDriver interface {
// OpenCtx opens an IContext with connection id, client capability, collation, dbname and optionally the tls state.
OpenCtx(connID uint64, capability uint32, collation uint8, dbname string, tlsState *tls.ConnectionState, extensions *extension.SessionExtensions) (*TiDBContext, error)
OpenCtx(connID uint64, capability uint32, collation uint16, dbname string, tlsState *tls.ConnectionState, extensions *extension.SessionExtensions) (*TiDBContext, error)
}

// PreparedStatement is the interface to use a prepared statement.
Expand Down
4 changes: 2 additions & 2 deletions pkg/server/driver_tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,14 @@ func (ts *TiDBStatement) GetRowContainer() *chunk.RowContainer {
}

// OpenCtx implements IDriver.
func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint8, _ string,
func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint16, _ string,
tlsState *tls.ConnectionState, extensions *extension.SessionExtensions) (*TiDBContext, error) {
se, err := session.CreateSession(qd.store)
if err != nil {
return nil, err
}
se.SetTLSState(tlsState)
err = se.SetCollation(int(collation))
err = se.SetCollation(collation)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/driver_tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func createColumnByTypeAndLen(tp byte, cl uint32) *column.Info {
Name: "a",
OrgName: "a",
ColumnLength: cl,
Charset: uint16(mysql.CharsetNameToID(charset.CharsetUTF8)),
Charset: mysql.CharsetNameToID(charset.CharsetUTF8),
Flag: uint16(mysql.UnsignedFlag),
Decimal: uint8(0),
Type: tp,
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/internal/column/column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func TestDumpTextValue(t *testing.T) {
require.NoError(t, err)
require.Equal(t, []byte{0xd2, 0xbb}, []byte(mustDecodeStr(t, bs)))

columns[0].Charset = uint16(mysql.CharsetNameToID("gbk"))
columns[0].Charset = mysql.CharsetNameToID("gbk")
dp = NewResultEncoder("binary")
bs, err = DumpTextRow(nil, columns, chunk.MutRowFromDatums(dt).ToRow(), dp)
require.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/internal/column/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func ConvertColumnInfo(fld *ast.ResultField) (ci *Info) {
Table: fld.TableAsName.O,
Schema: fld.DBName.O,
Flag: uint16(fld.Column.GetFlag()),
Charset: uint16(mysql.CharsetNameToID(fld.Column.GetCharset())),
Charset: mysql.CharsetNameToID(fld.Column.GetCharset()),
Type: fld.Column.GetType(),
DefaultValue: fld.Column.GetDefaultValue(),
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/server/internal/column/result_encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (d *ResultEncoder) Clean() {

// UpdateDataEncoding updates the data encoding.
func (d *ResultEncoder) UpdateDataEncoding(chsID uint16) {
chs, _, err := charset.GetCharsetInfoByID(int(chsID))
chs, _, err := charset.GetCharsetInfoByID(chsID)
if err != nil {
logutil.BgLogger().Warn("unknown charset ID", zap.Error(err))
}
Expand All @@ -98,7 +98,7 @@ func (d *ResultEncoder) ColumnTypeInfoCharsetID(info *Info) uint16 {
if info.Charset == mysql.BinaryDefaultCollationID {
return mysql.BinaryDefaultCollationID
}
return uint16(mysql.CharsetNameToID(d.chsName))
return mysql.CharsetNameToID(d.chsName)
}

// EncodeMeta encodes bytes for meta info like column names.
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/internal/handshake/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ type Response41 struct {
Auth []byte
ZstdLevel int
Capability uint32
Collation uint8
Collation uint16
}
8 changes: 4 additions & 4 deletions pkg/server/internal/parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ func HandshakeResponseHeader(ctx context.Context, packet *handshake.Response41,
// skip max packet size
offset += 4
// charset, skip, if you want to use another charset, use set names
packet.Collation = data[offset]
offset++
// skip reserved 23[00]
offset += 23
packet.Collation = binary.LittleEndian.Uint16(data[offset : offset+2])
offset += 2
// skip reserved 22[00]
offset += 22

return offset, nil
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/mock_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func CreateMockConn(t *testing.T, server *Server) MockConn {
require.NoError(t, err)

connID := rand.Uint64()
tc, err := server.driver.OpenCtx(connID, 0, uint8(tmysql.DefaultCollationID), "", nil, extensions.NewSessionExtensions())
tc, err := server.driver.OpenCtx(connID, 0, tmysql.DefaultCollationID, "", nil, extensions.NewSessionExtensions())
require.NoError(t, err)

cc := &clientConn{
Expand Down
16 changes: 8 additions & 8 deletions pkg/server/tests/commontest/tidb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ func TestCreateTableFlen(t *testing.T) {
ts := servertestkit.CreateTidbTestSuite(t)

// issue #4540
qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil, nil)
qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, tmysql.DefaultCollationID, "test", nil, nil)
require.NoError(t, err)
_, err = Execute(context.Background(), qctx, "use test;")
require.NoError(t, err)
Expand Down Expand Up @@ -670,7 +670,7 @@ func Execute(ctx context.Context, qc *server2.TiDBContext, sql string) (resultse
func TestShowTablesFlen(t *testing.T) {
ts := servertestkit.CreateTidbTestSuite(t)

qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil, nil)
qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, tmysql.DefaultCollationID, "test", nil, nil)
require.NoError(t, err)
ctx := context.Background()
_, err = Execute(ctx, qctx, "use test;")
Expand Down Expand Up @@ -700,7 +700,7 @@ func checkColNames(t *testing.T, columns []*column.Info, names ...string) {
func TestFieldList(t *testing.T) {
ts := servertestkit.CreateTidbTestSuite(t)

qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil, nil)
qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, tmysql.DefaultCollationID, "test", nil, nil)
require.NoError(t, err)
_, err = Execute(context.Background(), qctx, "use test;")
require.NoError(t, err)
Expand Down Expand Up @@ -798,7 +798,7 @@ func TestSumAvg(t *testing.T) {
func TestNullFlag(t *testing.T) {
ts := servertestkit.CreateTidbTestSuite(t)

qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil, nil)
qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, tmysql.DefaultCollationID, "test", nil, nil)
require.NoError(t, err)

ctx := context.Background()
Expand Down Expand Up @@ -872,7 +872,7 @@ func TestNO_DEFAULT_VALUEFlag(t *testing.T) {
ts := servertestkit.CreateTidbTestSuite(t)

// issue #21465
qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil, nil)
qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, tmysql.DefaultCollationID, "test", nil, nil)
require.NoError(t, err)

ctx := context.Background()
Expand Down Expand Up @@ -935,7 +935,7 @@ func TestGracefulShutdown(t *testing.T) {
func TestPessimisticInsertSelectForUpdate(t *testing.T) {
ts := servertestkit.CreateTidbTestSuite(t)

qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil, nil)
qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, tmysql.DefaultCollationID, "test", nil, nil)
require.NoError(t, err)
defer qctx.Close()
ctx := context.Background()
Expand Down Expand Up @@ -2471,7 +2471,7 @@ func TestExtensionConnEvent(t *testing.T) {

func TestSandBoxMode(t *testing.T) {
ts := servertestkit.CreateTidbTestSuite(t)
qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil, nil)
qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, tmysql.DefaultCollationID, "test", nil, nil)
require.NoError(t, err)
_, err = Execute(context.Background(), qctx, "create user testuser;")
require.NoError(t, err)
Expand Down Expand Up @@ -3053,7 +3053,7 @@ func TestConnectionWillNotLeak(t *testing.T) {
func TestPrepareCount(t *testing.T) {
ts := servertestkit.CreateTidbTestSuite(t)

qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil, nil)
qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, tmysql.DefaultCollationID, "test", nil, nil)
require.NoError(t, err)
prepareCnt := atomic.LoadInt64(&variable.PreparedStmtCount)
ctx := context.Background()
Expand Down
4 changes: 2 additions & 2 deletions pkg/server/tests/tidb_serial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func TestLoadDataListPartition(t *testing.T) {
func TestPrepareExecute(t *testing.T) {
ts := servertestkit.CreateTidbTestSuite(t)

qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, uint8(tmysql.DefaultCollationID), "test", nil, nil)
qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, tmysql.DefaultCollationID, "test", nil, nil)
require.NoError(t, err)

ctx := context.Background()
Expand Down Expand Up @@ -150,7 +150,7 @@ func TestDefaultCharacterAndCollation(t *testing.T) {

// issue #21194
// 255 is the collation id of mysql client 8 default collation_connection
qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, uint8(255), "test", nil, nil)
qctx, err := ts.Tidbdrv.OpenCtx(uint64(0), 0, uint16(255), "test", nil, nil)
require.NoError(t, err)
testCase := []struct {
variable string
Expand Down
2 changes: 1 addition & 1 deletion pkg/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ func (s *session) SetCommandValue(command byte) {
atomic.StoreUint32(&s.sessionVars.CommandValue, uint32(command))
}

func (s *session) SetCollation(coID int) error {
func (s *session) SetCollation(coID uint16) error {
cs, co, err := charset.GetCharsetInfoByID(coID)
if err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion pkg/session/types/sesson_interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ type Session interface {
SetCompressionLevel(int)
SetProcessInfo(string, time.Time, byte, uint64)
SetTLSState(*tls.ConnectionState)
SetCollation(coID int) error
SetCollation(coID uint16) error
SetSessionManager(util.SessionManager)
Close()
Auth(user *auth.UserIdentity, auth, salt []byte, authConn conn.AuthConn) error
Expand Down
2 changes: 1 addition & 1 deletion pkg/store/mockstore/mockcopr/cop_handler_dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ func extractOffsetsInExpr(expr *tipb.Expr, columns []*tipb.ColumnInfo, collector

// fieldTypeFromPBColumn creates a types.FieldType from tipb.ColumnInfo.
func fieldTypeFromPBColumn(col *tipb.ColumnInfo) *types.FieldType {
charsetStr, collationStr, _ := charset.GetCharsetInfoByID(int(collate.RestoreCollationIDIfNeeded(col.GetCollation())))
charsetStr, collationStr, _ := charset.GetCharsetInfoByID(uint16(collate.RestoreCollationIDIfNeeded(col.GetCollation())))
ft := &types.FieldType{}
ft.SetType(byte(col.GetTp()))
ft.SetFlag(uint(col.GetFlag()))
Expand Down
2 changes: 1 addition & 1 deletion pkg/store/mockstore/unistore/cophandler/cop_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ func appendRow(chunks []tipb.Chunk, data []byte, rowCnt int) []tipb.Chunk {

// fieldTypeFromPBColumn creates a types.FieldType from tipb.ColumnInfo.
func fieldTypeFromPBColumn(col *tipb.ColumnInfo) *types.FieldType {
charsetStr, collationStr, _ := charset.GetCharsetInfoByID(int(collate.RestoreCollationIDIfNeeded(col.GetCollation())))
charsetStr, collationStr, _ := charset.GetCharsetInfoByID(uint16(collate.RestoreCollationIDIfNeeded(col.GetCollation())))
ft := &types.FieldType{}
ft.SetType(byte(col.GetTp()))
ft.SetFlag(uint(col.GetFlag()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func prepareTestTableData(keyNumber int, tableID int64) (*data, error) {
colInfos[i] = &tipb.ColumnInfo{
ColumnId: colIds[i],
Tp: int32(colTypes[i].GetType()),
Collation: -mysql.DefaultCollationID,
Collation: int32(mysql.DefaultCollationID),
}
colTypeMap[colIds[i]] = colTypes[i]
}
Expand Down
Loading

0 comments on commit 16f13bf

Please sign in to comment.