Skip to content

Commit

Permalink
*: Set collation to uint16
Browse files Browse the repository at this point in the history
  • Loading branch information
dveeden committed Apr 29, 2024
1 parent c60f97d commit b1649e3
Show file tree
Hide file tree
Showing 17 changed files with 35 additions and 35 deletions.
10 changes: 5 additions & 5 deletions pkg/parser/charset/charset.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ type Charset struct {
// Collation is a collation.
// Now we only support MySQL.
type Collation struct {
ID int
ID uint16
CharsetName string
Name string
IsDefault bool
}

var collationsIDMap = make(map[int]*Collation)
var collationsIDMap = make(map[uint16]*Collation)
var collationsNameMap = make(map[string]*Collation)
var supportedCollations = make([]*Collation, 0, len(supportedCollationNames))

Expand Down Expand Up @@ -166,7 +166,7 @@ func GetCharsetInfo(cs string) (*Charset, error) {
}

// GetCharsetInfoByID returns charset and collation for id as cs_number.
func GetCharsetInfoByID(coID int) (charsetStr string, collateStr string, err error) {
func GetCharsetInfoByID(coID uint16) (charsetStr string, collateStr string, err error) {
if coID == mysql.DefaultCollationID {
return mysql.DefaultCharset, mysql.DefaultCollationName, nil
}
Expand All @@ -176,7 +176,7 @@ func GetCharsetInfoByID(coID int) (charsetStr string, collateStr string, err err

log.Warn(
"unable to get collation name from collation ID, return default charset and collation instead",
zap.Int("ID", coID),
zap.Uint16("ID", coID),
zap.Stack("stack"))
return mysql.DefaultCharset, mysql.DefaultCollationName, errors.Errorf("Unknown collation id %d", coID)
}
Expand Down Expand Up @@ -205,7 +205,7 @@ func GetCollationByName(name string) (*Collation, error) {
}

// GetCollationByID returns collations by given id.
func GetCollationByID(id int) (*Collation, error) {
func GetCollationByID(id uint16) (*Collation, error) {
collation, ok := collationsIDMap[id]
if !ok {
return nil, errors.Errorf("Unknown collation id %d", id)
Expand Down
2 changes: 1 addition & 1 deletion pkg/parser/charset/charset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func TestGetCollationByName(t *testing.T) {
func TestValidCustomCharset(t *testing.T) {
AddCharset(&Charset{"custom", "custom_collation", make(map[string]*Collation), "Custom", 4})
defer RemoveCharset("custom")
AddCollation(&Collation{99999, "custom", "custom_collation", true})
AddCollation(&Collation{9999, "custom", "custom_collation", true})

tests := []struct {
cs string
Expand Down
4 changes: 2 additions & 2 deletions pkg/parser/mysql/charset.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ package mysql
import "unicode"

// CharsetNameToID maps charset name to its default collation ID.
func CharsetNameToID(charset string) uint8 {
func CharsetNameToID(charset string) uint16 {
// Use quick path for TiDB to avoid access CharsetIDs map
// "SHOW CHARACTER SET;" to see all the supported character sets.
if charset == "utf8mb4" {
Expand All @@ -34,7 +34,7 @@ func CharsetNameToID(charset string) uint8 {
}

// CharsetIDs maps charset name to its default collation ID.
var CharsetIDs = map[string]uint8{
var CharsetIDs = map[string]uint16{
"big5": 1,
"dec8": 3,
"cp850": 4,
Expand Down
2 changes: 1 addition & 1 deletion pkg/planner/property/physical_property.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func GetCollateIDByNameForPartition(coll string) int32 {
// GetCollateNameByIDForPartition returns collate id by collation name
func GetCollateNameByIDForPartition(collateID int32) string {
collateID = collate.RestoreCollationIDIfNeeded(collateID)
return collate.CollationID2Name(collateID)
return collate.CollationID2Name(uint16(collateID))
}

// cteProducerStatus indicates whether we can let the current CTE consumer/reader be executed on the MPP nodes.
Expand Down
6 changes: 3 additions & 3 deletions pkg/server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ type clientConn struct {
peerPort string // peer port
status int32 // dispatching/reading/shutdown/waitshutdown
lastCode uint16 // last error code
collation uint8 // collation used by client, may be different from the collation used by database.
collation uint16 // collation used by client, may be different from the collation used by database.
lastActive time.Time // last active time
authPlugin string // default authentication plugin
isUnixSocket bool // connection is Unix Socket file
Expand Down Expand Up @@ -428,9 +428,9 @@ func (cc *clientConn) writeInitialHandshake(ctx context.Context) error {
data = append(data, byte(cc.server.capability), byte(cc.server.capability>>8))
// charset
if cc.collation == 0 {
cc.collation = uint8(mysql.DefaultCollationID)
cc.collation = uint16(mysql.DefaultCollationID)
}
data = append(data, cc.collation)
data = append(data, byte(cc.collation))
// status
data = dump.Uint16(data, mysql.ServerStatusAutocommit)
// below 13 byte may not be used
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1594,7 +1594,7 @@ func TestAuthSessionTokenPlugin(t *testing.T) {
tk.MustExec("CREATE USER auth_session_token")
tk.MustExec("CREATE USER another_user")

tc, err := drv.OpenCtx(uint64(0), 0, uint8(mysql.DefaultCollationID), "", nil, nil)
tc, err := drv.OpenCtx(uint64(0), 0, uint16(mysql.DefaultCollationID), "", nil, nil)
require.NoError(t, err)
cc := &clientConn{
connectionID: 1,
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
// IDriver opens IContext.
type IDriver interface {
// OpenCtx opens an IContext with connection id, client capability, collation, dbname and optionally the tls state.
OpenCtx(connID uint64, capability uint32, collation uint8, dbname string, tlsState *tls.ConnectionState, extensions *extension.SessionExtensions) (*TiDBContext, error)
OpenCtx(connID uint64, capability uint32, collation uint16, dbname string, tlsState *tls.ConnectionState, extensions *extension.SessionExtensions) (*TiDBContext, error)
}

// PreparedStatement is the interface to use a prepared statement.
Expand Down
4 changes: 2 additions & 2 deletions pkg/server/driver_tidb.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,14 @@ func (ts *TiDBStatement) GetRowContainer() *chunk.RowContainer {
}

// OpenCtx implements IDriver.
func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint8, _ string,
func (qd *TiDBDriver) OpenCtx(connID uint64, capability uint32, collation uint16, _ string,
tlsState *tls.ConnectionState, extensions *extension.SessionExtensions) (*TiDBContext, error) {
se, err := session.CreateSession(qd.store)
if err != nil {
return nil, err
}
se.SetTLSState(tlsState)
err = se.SetCollation(int(collation))
err = se.SetCollation(collation)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/internal/column/result_encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (d *ResultEncoder) Clean() {

// UpdateDataEncoding updates the data encoding.
func (d *ResultEncoder) UpdateDataEncoding(chsID uint16) {
chs, _, err := charset.GetCharsetInfoByID(int(chsID))
chs, _, err := charset.GetCharsetInfoByID(chsID)
if err != nil {
logutil.BgLogger().Warn("unknown charset ID", zap.Error(err))
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/internal/handshake/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ type Response41 struct {
Auth []byte
ZstdLevel int
Capability uint32
Collation uint8
Collation uint16
}
8 changes: 4 additions & 4 deletions pkg/server/internal/parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ func HandshakeResponseHeader(ctx context.Context, packet *handshake.Response41,
// skip max packet size
offset += 4
// charset, skip, if you want to use another charset, use set names
packet.Collation = data[offset]
offset++
// skip reserved 23[00]
offset += 23
packet.Collation = binary.LittleEndian.Uint16(data[offset : offset+2])
offset += 2
// skip reserved 22[00]
offset += 22

return offset, nil
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/mock_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func CreateMockConn(t *testing.T, server *Server) MockConn {
require.NoError(t, err)

connID := rand.Uint64()
tc, err := server.driver.OpenCtx(connID, 0, uint8(tmysql.DefaultCollationID), "", nil, extensions.NewSessionExtensions())
tc, err := server.driver.OpenCtx(connID, 0, uint16(tmysql.DefaultCollationID), "", nil, extensions.NewSessionExtensions())
require.NoError(t, err)

cc := &clientConn{
Expand Down
2 changes: 1 addition & 1 deletion pkg/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ func (s *session) SetCommandValue(command byte) {
atomic.StoreUint32(&s.sessionVars.CommandValue, uint32(command))
}

func (s *session) SetCollation(coID int) error {
func (s *session) SetCollation(coID uint16) error {
cs, co, err := charset.GetCharsetInfoByID(coID)
if err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion pkg/session/types/sesson_interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ type Session interface {
SetCompressionLevel(int)
SetProcessInfo(string, time.Time, byte, uint64)
SetTLSState(*tls.ConnectionState)
SetCollation(coID int) error
SetCollation(coID uint16) error
SetSessionManager(util.SessionManager)
Close()
Auth(user *auth.UserIdentity, auth, salt []byte, authConn conn.AuthConn) error
Expand Down
2 changes: 1 addition & 1 deletion pkg/store/mockstore/mockcopr/cop_handler_dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ func extractOffsetsInExpr(expr *tipb.Expr, columns []*tipb.ColumnInfo, collector

// fieldTypeFromPBColumn creates a types.FieldType from tipb.ColumnInfo.
func fieldTypeFromPBColumn(col *tipb.ColumnInfo) *types.FieldType {
charsetStr, collationStr, _ := charset.GetCharsetInfoByID(int(collate.RestoreCollationIDIfNeeded(col.GetCollation())))
charsetStr, collationStr, _ := charset.GetCharsetInfoByID(uint16(collate.RestoreCollationIDIfNeeded(col.GetCollation())))
ft := &types.FieldType{}
ft.SetType(byte(col.GetTp()))
ft.SetFlag(uint(col.GetFlag()))
Expand Down
2 changes: 1 addition & 1 deletion pkg/store/mockstore/unistore/cophandler/cop_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ func appendRow(chunks []tipb.Chunk, data []byte, rowCnt int) []tipb.Chunk {

// fieldTypeFromPBColumn creates a types.FieldType from tipb.ColumnInfo.
func fieldTypeFromPBColumn(col *tipb.ColumnInfo) *types.FieldType {
charsetStr, collationStr, _ := charset.GetCharsetInfoByID(int(collate.RestoreCollationIDIfNeeded(col.GetCollation())))
charsetStr, collationStr, _ := charset.GetCharsetInfoByID(uint16(collate.RestoreCollationIDIfNeeded(col.GetCollation())))
ft := &types.FieldType{}
ft.SetType(byte(col.GetTp()))
ft.SetFlag(uint(col.GetFlag()))
Expand Down
16 changes: 8 additions & 8 deletions pkg/util/collate/collate.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (

var (
newCollatorMap map[string]Collator
newCollatorIDMap map[int]Collator
newCollatorIDMap map[uint16]Collator
newCollationEnabled int32

// binCollatorInstance is a singleton used for all collations when newCollationEnabled is false.
Expand Down Expand Up @@ -172,13 +172,13 @@ func GetBinaryCollatorSlice(n int) []Collator {
}

// GetCollatorByID get the collator according to id, it will return the binary collator if the corresponding collator doesn't exist.
func GetCollatorByID(id int) Collator {
func GetCollatorByID(id uint16) Collator {
if atomic.LoadInt32(&newCollationEnabled) == 1 {
ctor, ok := newCollatorIDMap[id]
if !ok {
logutil.BgLogger().Warn(
"Unable to get collator by ID, use binCollator instead.",
zap.Int("ID", id),
zap.Uint16("ID", id),
zap.Stack("stack"))
return newCollatorMap["utf8mb4_bin"]
}
Expand All @@ -189,8 +189,8 @@ func GetCollatorByID(id int) Collator {

// CollationID2Name return the collation name by the given id.
// If the id is not found in the map, the default collation is returned.
func CollationID2Name(id int32) string {
collation, err := charset.GetCollationByID(int(id))
func CollationID2Name(id uint16) string {
collation, err := charset.GetCollationByID(uint16(id))
if err != nil {
// TODO(bb7133): fix repeating logs when the following code is uncommented.
// logutil.BgLogger().Warn(
Expand All @@ -204,7 +204,7 @@ func CollationID2Name(id int32) string {

// CollationName2ID return the collation id by the given name.
// If the name is not found in the map, the default collation id is returned
func CollationName2ID(name string) int {
func CollationName2ID(name string) uint16 {
if coll, err := charset.GetCollationByName(name); err == nil {
return coll.ID
}
Expand Down Expand Up @@ -380,7 +380,7 @@ func CollationToProto(c string) int32 {

// ProtoToCollation converts collation from int32(used by protocol) to string.
func ProtoToCollation(c int32) string {
coll, err := charset.GetCollationByID(int(RestoreCollationIDIfNeeded(c)))
coll, err := charset.GetCollationByID(uint16(RestoreCollationIDIfNeeded(c)))
if err == nil {
return coll.Name
}
Expand All @@ -398,7 +398,7 @@ func init() {
newCollationEnabled = 1

newCollatorMap = make(map[string]Collator)
newCollatorIDMap = make(map[int]Collator)
newCollatorIDMap = make(map[uint16]Collator)

newCollatorMap["binary"] = &binCollator{}
newCollatorIDMap[CollationName2ID("binary")] = &binCollator{}
Expand Down

0 comments on commit b1649e3

Please sign in to comment.